An Approximation Algorithm for Optimal Subarchitecture Extraction

Tags: paper ml
State: None
Code: None

Summary #todo


Optimal Sub-architecture Extraction

Select the best non-trainable parameters for a NN such that it is optimal w.r.t. parameter size, inference speed, and error rate.

Class of networks that satisfy the following three conditions:

  1. Intermediate layers are more expensive, in terms of parameter size and number of operations, than the I/O functions
  2. Optimization problem is $L$-Lipschitz smooth with bounded stochastic gradients
  3. Training procedure uses SGD

Above assumptions are labelled as the $\text{AB}^nC$ property
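Condition 3 just means vanilla SGD. A minimal sketch of such a training loop on a strongly convex quadratic, with bounded gradient noise (illustrative only, not the paper's setup):

```python
import random

# SGD on f(w) = 0.5 * w^2: 1-strongly convex, 1-Lipschitz-smooth,
# with bounded noise added to the gradient to mimic stochasticity.
def sgd(w0, lr=0.1, steps=200, seed=0):
    rng = random.Random(seed)
    w = w0
    for _ in range(steps):
        grad = w + rng.uniform(-0.01, 0.01)  # stochastic gradient, bounded noise
        w -= lr * grad
    return w

print(abs(sgd(5.0)) < 0.1)  # converges near the minimizer w* = 0
```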

Assumption: if the optimization problem is $\mu$-strongly convex, then the algorithm is an FPTAS with approximation ratio $\rho \leq |1 - \epsilon|$

Can be seen as an Architecture Search Problem, where the architecture remains fixed but the non-trainable parameters do not.


"Optimal Sub-architecture Extraction"

  • Find set of non-trainable params for a deep NN $\mathbb{R}^p \to \mathbb{R}^q$
  • A layer is a non-linear function $l_i(x, W_i)$ that takes an input $x \in \mathbb{R}^{p_i}$ and a finite set of trainable weights $W_i = \{w_{i,1}, \dots, w_{i,k}\}$; every $w_{i,j} \in W_i$ is an $r$-dimensional array (vector)
  • Supervised dataset $D = \{(x_i, y_i)\}_{i=1,\dots,m}$ drawn from an unknown probability distribution. Search space is $S$ ($\Xi$ in the paper).
  • Set of possible weight assignments $W$, set of hyper-parameter combinations $\theta$, and architecture $f(x) = l_n(l_{n-1}(\dots l_1(x; W_1; \xi_1)\dots; W_{n-1}; \xi_{n-1}); W_n; \xi_n)$
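The nested composition $f$ above can be sketched as plain function composition. The layer here (an elementwise scaled ReLU) is a made-up stand-in, not a layer from the paper; `xi` plays the role of a non-trainable parameter:

```python
from functools import reduce

def make_layer(W, xi):
    # Hypothetical layer l_i(x; W_i; xi_i): elementwise scaled ReLU.
    def layer(x):
        return [max(0.0, xi * w * v) for w, v in zip(W, x)]
    return layer

def build_network(weights, xis):
    # f(x) = l_n(... l_1(x; W_1; xi_1) ...; W_n; xi_n)
    layers = [make_layer(W, xi) for W, xi in zip(weights, xis)]
    return lambda x: reduce(lambda acc, layer: layer(acc), layers, x)

f = build_network(weights=[[1.0, 2.0], [0.5, 0.5]], xis=[1.0, 2.0])
print(f([3.0, -1.0]))  # → [3.0, 0.0]
```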

Surrogates for the objective values

Inference Speed

Sum over each layer $i$ of:

  • Number of additions
  • Number of multiplications
  • Number of other operations in layer i
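This surrogate is just a total operation count. A sketch with made-up per-layer numbers (the paper derives the counts analytically per layer type):

```python
# Inference-speed surrogate: sum of additions, multiplications,
# and other operations over all layers.
def inference_cost(layers):
    """layers: list of (n_additions, n_multiplications, n_other_ops)."""
    return sum(adds + muls + other for adds, muls, other in layers)

# Toy two-layer network with illustrative op counts.
cost = inference_cost([(100, 120, 10), (50, 60, 5)])
print(cost)  # → 345
```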

Related Work

  • NAS
  • Weight Pruning
  • General neural network compression techniques


  • Behaves like an FPTAS
  • Runs in $O(|E| + |W_T^{*}|(1 + |B||E|) / \epsilon s^{3/2})$

Search Space

FPTAS search algorithm

Hparams it searches over:

  • Attention heads (A)
  • Encoder layers (D)
  • Hidden size (H)
  • Intermediate layer size (I)

BERT-base has D=12, A=12, H=768, I=3072
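For scale, a back-of-the-envelope parameter count for the encoder stack as a function of these hyperparameters (ignoring embeddings, biases, and LayerNorm; my own rough surrogate, not the paper's exact counting):

```python
def encoder_params(D, H, I):
    # Per encoder layer: Q, K, V, and output projections (4 * H * H)
    # plus the two feed-forward projections (H * I and I * H).
    attention = 4 * H * H
    ffn = 2 * H * I
    return D * (attention + ffn)

print(encoder_params(D=12, H=768, I=3072))  # → 84934656, ≈85M for BERT-base's encoders
```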

They search over:

| Hyperparameter | Symbol | Values |
|---|---|---|
| Number of attention heads | A | {4, 8, 12, 16} |
| Number of encoder layers | D | {2, 4, 6, 8, 10, 12} |
| Hidden size | H | {512, 768, 1024} (must be divisible by A) |
| Intermediate layer size | I | {256, 512, 768, 1024, 3072} |

Ignoring configurations where H is not divisible by A
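As a sanity check, the grid above can be enumerated directly with the divisibility filter applied (variable names are mine, not the paper's):

```python
from itertools import product

A_vals = [4, 8, 12, 16]
D_vals = [2, 4, 6, 8, 10, 12]
H_vals = [512, 768, 1024]
I_vals = [256, 512, 768, 1024, 3072]

# Keep only configurations where the hidden size is divisible
# by the number of attention heads.
search_space = [
    dict(A=A, D=D, H=H, I=I)
    for A, D, H, I in product(A_vals, D_vals, H_vals, I_vals)
    if H % A == 0
]
print(len(search_space))  # → 300 of the 360 raw combinations survive
```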