Tags: paper ml
State: None
Source: https://arxiv.org/abs/2010.08512
Code: None
Summary #todo
Problem
Optimal Sub-architecture Extraction
Select the best non-trainable (architectural) parameters for a neural network so that it is optimal with respect to parameter size, inference speed, and error rate.
It considers the class of networks that satisfy the following three conditions:
- The intermediate layers are more expensive, in parameter size and number of operations, than the input and output functions
- The optimization problem is L-Lipschitz smooth with bounded stochastic gradients
- The training procedure uses SGD
Together, the assumptions above are labelled as the property
If, in addition, the optimization problem is μ-strongly convex, the algorithm is an FPTAS with an approximation ratio of
The problem can be seen as a form of architecture search in which the architecture stays fixed and only its non-trainable parameters vary.
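The notes do not record how the paper trades the three metrics off against each other. As a generic illustration only (not the paper's objective), candidates scored on three quantities that are all minimized can be compared by Pareto dominance. `Candidate`, `dominates`, `pareto_front`, and the numbers in the usage lines are hypothetical.

```python
from typing import List, NamedTuple


class Candidate(NamedTuple):
    """Hypothetical scores for one assignment of non-trainable parameters (all minimized)."""
    name: str
    param_size: float
    inference_cost: float
    error_rate: float


def dominates(a: Candidate, b: Candidate) -> bool:
    """True if `a` is no worse than `b` on every metric and strictly better on at least one."""
    pa = (a.param_size, a.inference_cost, a.error_rate)
    pb = (b.param_size, b.inference_cost, b.error_rate)
    return all(x <= y for x, y in zip(pa, pb)) and any(x < y for x, y in zip(pa, pb))


def pareto_front(cands: List[Candidate]) -> List[Candidate]:
    """Keep the candidates that no other candidate dominates."""
    return [c for c in cands if not any(dominates(o, c) for o in cands)]


front = pareto_front([
    Candidate("a", 10, 5, 0.20),
    Candidate("b", 20, 8, 0.25),  # dominated by "a" on all three metrics
    Candidate("c", 5, 9, 0.30),   # smaller than "a", but slower and less accurate
])
print([c.name for c in front])  # ['a', 'c']
```

Dominance can leave several incomparable winners, which is why a single "optimal" answer needs some way of combining the metrics.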
OSE
"Optimal Sub-architecture Extraction"
- Find the set of non-trainable parameters for a deep neural network
- Each layer is a non-linear function that takes an input and a finite set of trainable weights; every weight is an r-dimensional array (vector)
- Data: a supervised dataset drawn from an unknown probability distribution. The search space of candidate architectures is denoted Ξ in the paper.
- Also given: the set of possible weight assignments, the set of hyperparameter combinations, and the architecture
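To make the layer definition concrete, here is a minimal sketch assuming a dense layer with a tanh non-linearity (an arbitrary choice for illustration; the paper's definition is general). The trainable set is {W, b}, a 2-dimensional and a 1-dimensional array, and `num_entries` counts the scalars across that set, one common way to measure parameter size. `dense_tanh_layer` and `num_entries` are hypothetical names.

```python
import math
from typing import List, Sequence


def dense_tanh_layer(x: Sequence[float], W: List[List[float]], b: Sequence[float]) -> List[float]:
    """l(x; {W, b}) = tanh(W x + b): non-linear in x, parametrized by the trainable set {W, b}."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x)) + bj) for row, bj in zip(W, b)]


def num_entries(array) -> int:
    """Number of scalar entries in a nested-list array of any dimension r."""
    if isinstance(array, (list, tuple)):
        return sum(num_entries(a) for a in array)
    return 1


W = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # 2 x 3 weight matrix (r = 2)
b = [0.0, 0.0]                           # bias vector (r = 1)
print(dense_tanh_layer([1.0, 0.0, 0.0], W, b))  # two outputs
print(sum(num_entries(w) for w in (W, b)))      # 8 trainable scalars
```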
Surrogates for the objective values
Inference Speed
Sum, over each layer i, of:
- Number of additions
- Number of multiplications
- Number of other operations in layer i
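The notes drop the exact expression being summed, so the sketch below assumes an unweighted count: the inference-speed surrogate is the total number of additions, multiplications, and other operations over all layers. The helper `dense_layer_ops` and its counting convention (one activation counted as one "other" operation per output) are illustrative assumptions, not taken from the paper.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class LayerOps:
    """Operation counts for one layer i."""
    additions: int
    multiplications: int
    other: int  # e.g. activations, divisions, exponentials


def inference_cost(layers: Iterable[LayerOps]) -> int:
    """Assumed surrogate: unweighted sum of all operations over every layer."""
    return sum(layer.additions + layer.multiplications + layer.other for layer in layers)


def dense_layer_ops(n_in: int, n_out: int, activation: bool = True) -> LayerOps:
    """Counts for y = act(W x + b) with W of shape (n_out, n_in).

    Each output takes n_in multiplications and n_in additions
    (n_in - 1 to sum the products, plus 1 for the bias).
    """
    return LayerOps(
        additions=n_in * n_out,
        multiplications=n_in * n_out,
        other=n_out if activation else 0,
    )


# A feed-forward block with H = 768, I = 3072 (768 -> 3072 -> 768), for a single token.
ffn = [dense_layer_ops(768, 3072), dense_layer_ops(3072, 768, activation=False)]
print(inference_cost(ffn))  # 9440256
```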
Related Work
- NAS
- Weight Pruning
- General neural network compression techniques
Complexity
- Behaves like an FPTAS
- Runs in
Search Space
FPTAS search algorithm
Hyperparameters it searches over:
- Attention heads (A)
- Encoder layers (D)
- Hidden size (H)
- Intermediate layer size (I)
BERT-base has D=12, A=12, H=768, I=3072
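A quick way to see what these values mean for parameter size is the standard count for a BERT-style encoder layer: Q, K, V and output projections with biases, a two-layer feed-forward block with biases, and two LayerNorms. This sketch excludes the embeddings and the pooler; `encoder_params` is a hypothetical helper. The head count A does not appear because the heads split the H-dimensional projections rather than adding weights.

```python
def encoder_params(D: int, H: int, I: int) -> int:
    """Trainable parameters in D BERT-style encoder layers (embeddings and pooler excluded)."""
    attention = 4 * (H * H + H)                # Q, K, V and output projections, each with a bias
    feed_forward = (H * I + I) + (I * H + H)   # H -> I and I -> H projections with biases
    layer_norms = 2 * (2 * H)                  # two LayerNorms per layer, each with gain and bias
    return D * (attention + feed_forward + layer_norms)


print(encoder_params(D=12, H=768, I=3072))  # 85054464 for the BERT-base encoder stack
```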
They search over:
Name | Variable | Values | Description |
---|---|---|---|
Number of Attention Heads | A | {4, 8, 12, 16} | |
Number of Encoder Layers | D | {2, 4, 6, 8, 10, 12} | |
Hidden Size | H | {512, 768, 1024} | H must be divisible by A |
Intermediate Layer Size | I | {256, 512, 768, 1024, 3072} | |
Ignoring configurations where H is not divisible by A
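The grid in the table, with the divisibility filter, can be enumerated as below (a sketch; `search_space` is a hypothetical helper). Of the 4 × 6 × 3 × 5 = 360 raw combinations, only A = 12 fails to divide some H (512 and 1024), which removes 2 × 6 × 5 = 60 and leaves 300 configurations; the BERT-base setting is one of them.

```python
from itertools import product

A_VALUES = [4, 8, 12, 16]          # attention heads
D_VALUES = [2, 4, 6, 8, 10, 12]    # encoder layers
H_VALUES = [512, 768, 1024]        # hidden size
I_VALUES = [256, 512, 768, 1024, 3072]  # intermediate layer size


def search_space():
    """Yield (A, D, H, I) tuples from the grid, skipping any H not divisible by A."""
    for a, d, h, i in product(A_VALUES, D_VALUES, H_VALUES, I_VALUES):
        if h % a == 0:
            yield a, d, h, i


configs = list(search_space())
print(len(configs))                     # 300
print((12, 12, 768, 3072) in configs)   # True: BERT-base is in the space
```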