A minimal ablation study of the proposed contributions in the latest High-Performance Large-Scale Image Recognition Without Normalization paper.

Over the past month or so, many in the computer vision community have been digesting a paper on NFNets that recently achieved state-of-the-art classification performance on ImageNet, "High-Performance Large-Scale Image Recognition Without Normalization". We wanted to take a moment to dig into the research by Brock, De, Smith, and Simonyan (2021) here, talk about the advantages (and disadvantages) of batch normalization, and mostly just roll up our sleeves and talk about this really fascinating research for a few hundred words.

Of course, if you'd like to dig into the source material (and we do recommend that!), you can find those links below. Without further ado:

Batch normalization (BN) was partially responsible for the growth of deep learning by enabling the training of deeper neural networks. Originally proposed in 2015, adding batch normalization helps normalize the hidden representations learned during training (i.e., the output of hidden layers) in order to address internal covariate shift.

Note, however, that the effectiveness of BN may have little to do with internal covariate shift. In "How Does Batch Normalization Help Optimization?", Santurkar et al. (2018) found that BN has a more fundamental impact on training: "it makes the optimization landscape significantly smoother. This smoothness induces a more predictive and stable behavior of the gradients, allowing for faster training."

📌 Note: The experimental results shown below were produced by training a custom convolutional-based neural network on the CIFAR-10 dataset. Check out the linked colab notebooks for implementation details.

No matter the actual reason why BN works, there are many practical benefits of training a model with BN. For starters:

- Models trained with BN converge quickly with better test accuracy.

- BN also enables model training with a large learning rate.

There are a few additional benefits to batch normalization. Namely:

- BN allows efficient large-batch training.
- BN eliminates mean shift. Activation functions like ReLU and GELU are not antisymmetric, so their outputs have a non-zero mean. This introduces a mean shift. Batch normalization ensures that the mean activation on each channel is zero across the current batch, eliminating the mean shift.
- BN has a regularization effect. (Source: Towards Understanding Regularization in Batch Normalization by Luo et al.)
- BN smooths the loss landscape. (Source: "How Does Batch Normalization Help Optimization?" by Santurkar et al.)
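The mean-shift point from the list above can be checked with a few lines of plain Python. This is only a sketch under simple assumptions: a single channel, zero-mean Gaussian pre-activations, and a hypothetical `batch_norm` helper without the learnable scale and shift that a real BN layer has.

```python
import random
import statistics


def relu(x):
    return max(0.0, x)


def batch_norm(values, eps=1e-5):
    """Normalize one channel across the batch (no learnable scale/shift)."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


random.seed(0)
# Zero-mean pre-activations for one channel across a batch of 1024 examples.
pre_activations = [random.gauss(0.0, 1.0) for _ in range(1024)]

activations = [relu(x) for x in pre_activations]
normalized = batch_norm(activations)

mean_before = statistics.fmean(activations)  # clearly above zero: the mean shift
mean_after = statistics.fmean(normalized)    # zero up to floating-point error

print(f"mean after ReLU: {mean_before:.3f}")
print(f"mean after BN:   {mean_after:.3e}")
```

ReLU discards the negative half of the distribution, so the batch mean moves above zero; subtracting the batch mean removes that shift for the current batch.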

Now, while batch normalization is a key component of most image classification models, it does come with some undesirable properties. The research community keeps finding ways around these, but in the long run, it would be better to find an alternative to batch normalization than to keep dealing with its idiosyncrasies and downsides.

In fact, let's cover a few of the disadvantages of batch normalization:

- It incurs memory overhead. (Source: In-Place Activated BatchNorm for Memory-Optimized Training of DNNs)
- Batch normalization increases the time needed to evaluate the gradient in some networks.
- It introduces a discrepancy between the model's behavior during training and at inference time if BN is not used carefully.
- BN can break the independence between training examples in the minibatch. Because of this particular issue:
  - It's hard to reproduce results on different hardware.
  - You can run into subtle implementation errors, especially in distributed training. For this reason, Synchronized Batch Normalization was proposed by Zhang et al. in Context Encoding for Semantic Segmentation.
  - Since batch statistics are computed during training, which can be seen as an interaction between training examples, networks can "cheat" certain loss functions. This is a major concern for sequence modeling tasks and has driven language models to adopt alternative normalizers.

- Moreover, networks can also degrade if the batch statistics have a large variance during training.

📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.

- Lastly, the performance of batch normalization is sensitive to the batch size. The parallel coordinate plot below shows a negative correlation between batch size and final test accuracy (in other words, larger batch sizes generally lead to lower test accuracy).

📌 Note: "[Although] large-batch training does not achieve higher test accuracies within a fixed epoch budget (Smith et al., 2020), it does achieve a given test accuracy in fewer parameter updates, significantly improving training speed when parallelized across multiple devices."
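The loss of independence between training examples, listed among the disadvantages above, is easy to see on a toy batch. The sketch below assumes a single channel and a hypothetical `batch_norm` helper that normalizes with the statistics of the current batch, as BN does in training mode; the numbers are made up for illustration.

```python
import statistics


def batch_norm(values, eps=1e-5):
    """Training-mode BN for one channel: uses the current batch's statistics."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


# The first example is identical in both batches; only the last one differs.
batch_a = [1.0, 2.0, 3.0, 4.0]
batch_b = [1.0, 2.0, 3.0, 40.0]

out_a = batch_norm(batch_a)
out_b = batch_norm(batch_b)

# Same input, different output, purely because of its batch neighbors.
print(f"first example in batch A: {out_a[0]:.4f}")
print(f"first example in batch B: {out_b[0]:.4f}")
```

The output for an example depends on which other examples happen to share its batch, which is the root of the reproducibility and distributed-training issues mentioned above.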

From the benefits of BN, we are aware of the good ingredients that are required for a high-performing neural network. A workable alternative to BN should bring us most (if not all) of these benefits while also mitigating the disadvantages we spelled out above.

Now, previous works have attempted to train deep ResNets to competitive accuracies without normalization by recovering just one or two benefits of BN. The key idea used in those works is to suppress the scale of the activations on the residual branch at initialization by introducing a small constant or learnable scalars.

Normalizer-Free ResNets (NF-ResNets) were first proposed in a paper titled "Characterizing Signal Propagation to close the performance gap in Unnormalized ResNets" by Brock et al. (2021). NF-ResNets are a class of pre-activation ResNets that can be trained to competitive training and test accuracies without normalization layers.

If you want to get up to speed with the ResNet architecture, here's a nice video summary of the paper by Yannic. But how is NF-ResNet different from the good old ResNet?

1. NF-ResNet employs a residual block of the form h_{i+1} = h_i + α f_i(h_i/β_i), where:

   - h_i and h_{i+1} are the input to the i^{th} residual block and the resulting output, respectively. h_{i+1} is also the input to the next residual block.
   - f_i is the function computed by the i^{th} residual branch, parameterized to be variance preserving at initialization, such that Var(f_i(z)) = Var(z).
   - α is a scalar that specifies the rate at which the variance of the activations increases after each residual block.
   - β_i is the standard deviation of the inputs to the i^{th} residual block.

2. NF-ResNet uses Scaled Weight Standardization. Weight Standardization reparameterizes the weights of a convolutional layer (W_{i,j} \rightarrow \hat W_{i,j}) such that

   \hat W_{i,j} = (W_{i,j} - μ_i) / σ_i, where:

   - \hat W and W are the reparameterized and original weights, respectively,
   - μ_i = (1/N) \sum_{j=1}^N W_{i,j}, and
   - σ_i^2 = (1/N) \sum_{j=1}^N (W_{i,j}^2 - μ_i^2), with N the fan-in of the layer.

   Scaled Weight Standardization is a minor modification of Weight Standardization. The variance is written as σ_i^2 = (1/N) \sum_{j=1}^N (W_{i,j} - μ_i)^2 (algebraically the same quantity as above), and the standardized weights are additionally divided by \sqrt{N}, giving \hat W_{i,j} = (W_{i,j} - μ_i) / (σ_i \sqrt{N}).

3. The activation functions are also scaled by a non-linearity specific scalar gain Î³.
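To make the first two ingredients more concrete, here is a minimal dependency-free sketch. It assumes the paper's formulation in which standardized weights are also divided by the square root of the fan-in N, and it assumes the residual branch and skip path are uncorrelated at initialization, so that Var(h_{i+1}) = Var(h_i) + α². The helper names `scaled_weight_standardize` and `expected_betas`, the `eps` term, and the value α = 0.2 are illustrative choices, not taken from the report's notebooks.

```python
import math
import random
import statistics


def scaled_weight_standardize(weights, eps=1e-8):
    """Standardize each row (one output unit) and divide by sqrt(fan-in).

    Assumption: W_hat[i][j] = (W[i][j] - mu_i) / (sigma_i * sqrt(N)),
    with a small eps added for numerical safety.
    """
    standardized = []
    for row in weights:
        fan_in = len(row)
        mean = statistics.fmean(row)
        var = statistics.pvariance(row, mu=mean)
        scale = math.sqrt(var * fan_in + eps)
        standardized.append([(w - mean) / scale for w in row])
    return standardized


def expected_betas(alpha, num_blocks, initial_var=1.0):
    """Analytic beta_i = sqrt(Var(h_i)), assuming Var(h_{i+1}) = Var(h_i) + alpha**2."""
    betas, var = [], initial_var
    for _ in range(num_blocks):
        betas.append(math.sqrt(var))
        var += alpha ** 2
    return betas


random.seed(0)
# 8 output units with fan-in 64, deliberately off-center and badly scaled.
raw = [[random.gauss(0.5, 2.0) for _ in range(64)] for _ in range(8)]
w_hat = scaled_weight_standardize(raw)

row_means = [statistics.fmean(row) for row in w_hat]       # each close to 0
row_sq_sums = [sum(w * w for w in row) for row in w_hat]   # each close to 1
betas = expected_betas(alpha=0.2, num_blocks=4)

print(row_means[0], row_sq_sums[0])
print(betas)
```

Each standardized row has zero mean and squared entries summing to about one, that is, a per-weight variance of about 1/N. Because the expected variance of h_i follows a simple recursion, β_i can be computed analytically ahead of time rather than estimated from batch statistics.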

📌 Fact: "With additional regularization (Dropout and Stochastic Depth), NF-ResNets match the test accuracies achieved by batch normalized pre-activation ResNets on ImageNet at batch size 1024. They also significantly outperform their batch normalized counterparts when the batch size is very small, but they perform worse than batch normalized networks for large batch sizes (4096 or higher)."

📌 Note: "NF-ResNets do not match the performance of state-of-the-art networks like EfficientNets which use Batch Normalization."

In the end, NF-ResNet-50 outperforms the good old ResNet-50 by a margin of roughly 15%. (I used Ross Wightman's timm package for both model definitions.)

One problem is that NF-ResNet could not scale to large batch sizes (4096 or higher) for training. The authors of the paper we're concerned with today (High-Performance Large-Scale Image Recognition Without Normalization) hypothesized that gradient clipping should help scale NF-ResNets to a larger batch. To this end, they proposed Adaptive Gradient Clipping (AGC).

A standard clipping algorithm clips the gradient before updating the parameter Î¸ such that:

G \rightarrow\left\{\begin{array}{ll}
\lambda \frac{G}{\|G\|} & \text { if }\|G\|>\lambda \\
G & \text { otherwise }
\end{array}\right.

Here, \lambda (clipping threshold) is the hyperparameter to be tuned.
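The rule above translates almost directly into code. The sketch below treats the gradient as a flat list of floats and uses a hypothetical helper name, `clip_by_norm`; deep learning frameworks ship their own utilities for this.

```python
import math


def clip_by_norm(grad, threshold):
    """Rescale `grad` so that its L2 norm is at most `threshold`."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm > threshold:
        return [threshold * g / norm for g in grad]
    return list(grad)


# Norm 5 exceeds the threshold, so the gradient is rescaled to norm 1.
large = clip_by_norm([3.0, 4.0], threshold=1.0)
# Norm 0.5 is below the threshold, so the gradient is left untouched.
small = clip_by_norm([0.3, 0.4], threshold=1.0)

print(large, small)
```

Note that the direction of the gradient is preserved; only its magnitude is capped, and the cap is the same regardless of how large the weights themselves are.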

Gradient clipping can help train at a higher learning rate but is quite sensitive to the clipping threshold (evident from the media panel below).

AGC tries to overcome this issue by introducing adaptive clipping instead of "hard" clipping.

📌 Note: "The AGC algorithm is motivated by the observation that the ratio of the norm of the gradients G^{\ell} to the norm of the weights W^{\ell} of layer \ell, \frac{\left\|G^{\ell}\right\|_{F}}{\left\|W^{\ell}\right\|_{F}}, provides a simple measure of how much a single gradient descent step will change the original weights W^{\ell}."

The authors use unit-wise ratios of gradient norms to parameter norms instead of layer-wise norm ratios, where G_i^{\ell} denotes the i^{th} row of G^{\ell}. Here, the Frobenius norm (\|\cdot\|_F) is used. AGC is defined as:

G_{i}^{\ell} \rightarrow\left\{\begin{array}{ll}
\lambda \frac{\left\|W_{i}^{\ell}\right\|_{F}^{\star}}{\left\|G_{i}^{\ell}\right\|_{F}} G_{i}^{\ell} & \text { if } \frac{\left\|G_{i}^{\ell}\right\|_{F}}{\left\|W_{i}^{\ell}\right\|_{F}^{\star}}>\lambda \\
G_{i}^{\ell} & \text { otherwise. }
\end{array}\right.

where \lambda is the clipping hyperparameter and \|W_i\|_F^{\star} = \max(\|W_i\|_F, ε). A small value of ε = 10^{-3} prevents zero-initialized parameters from always having their gradients clipped to zero.
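As a rough illustration of the unit-wise rule, the sketch below represents a layer as a list of rows, one row per output unit, and clips each row's gradient relative to the norm of that row's weights. The function name `adaptive_grad_clip` and the toy numbers are assumptions for this example; this is not the official JAX implementation.

```python
import math


def frobenius_norm(values):
    return math.sqrt(sum(v * v for v in values))


def adaptive_grad_clip(weights, grads, clip_factor=0.01, eps=1e-3):
    """Unit-wise AGC sketch: each row is one output unit of a layer."""
    clipped = []
    for w_row, g_row in zip(weights, grads):
        w_norm = max(frobenius_norm(w_row), eps)  # ||W_i||*_F = max(||W_i||_F, eps)
        g_norm = frobenius_norm(g_row)
        max_norm = clip_factor * w_norm
        if g_norm > max_norm:
            # Ratio ||G_i|| / ||W_i||* exceeds lambda: rescale to the boundary.
            g_row = [g * max_norm / g_norm for g in g_row]
        clipped.append(list(g_row))
    return clipped


weights = [[3.0, 4.0], [3.0, 4.0], [0.0, 0.0]]  # last unit is zero-initialized
grads = [[0.3, 0.4], [0.03, 0.0], [0.001, 0.0]]
clipped = adaptive_grad_clip(weights, grads, clip_factor=0.01)

for row in clipped:
    print(row)
```

The first unit's gradient is scaled down tenfold, the second is small enough relative to its weights to pass through unchanged, and the zero-initialized unit still receives a non-zero (if tiny) gradient thanks to ε.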

📌 Fact: "Using AGC, we can train NF-ResNets stably with larger batch sizes (up to 4096), as well as with very strong data augmentations like RandAugment."

To understand AGC better and see its effect on model training, I set up a handful of experiments. The intent of these experiments is not to draw any serious conclusions but rather to explore. They cover a tiny fraction of the possible configurations.

Here, AGC stabilizes NF-ResNets for a larger-batch configuration. But larger batch size does not lead to higher test accuracies. To this end, I wanted to investigate the effect of AGC and if it should be considered a viable alternative to good old gradient clipping. So how does AGC do?

First, a note about the three models I trained:

- A ResNet-20 architecture on the CIFAR-10 dataset with Batch Normalization (Colab Notebook).
- A ResNet-20 architecture without Batch Normalization layers, with the same configuration (Colab Notebook).
- A ResNet-20 architecture without Batch Normalization and with AGC (Colab Notebook).

📌 Note: The results are from training with a batch size of 1024 and a clipping factor of 0.01.

- AGC could not produce a model comparable to the baseline.
- The test accuracy with AGC is approximately the same as that of one trained without Batch Normalization.
- The result is correlated with the choice of the clipping factor. That said, we probably shouldn't conclude a lot from this experiment, and I highly suggest investigating a wider configuration space before making any broad claims here.

In my opinion, a clipping factor of 0.01 might be too tight. I encourage readers to experiment with larger clipping factors.

The clipping factor for regular gradient clipping is sensitive to batch size, model depth, learning rate, etc. I wanted to investigate the relationship between batch size and clipping factor and their correlation with the final test accuracy.

Using Weights & Biases Sweeps, I was able to quickly set up my ablation study.
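For readers who have not used sweeps before, a grid search over the two hyperparameters can be described with a plain dictionary. The sketch below is hypothetical: the value lists, the metric name `test_accuracy`, and the project name are placeholders, not the settings used for the results in this report.

```python
# Hypothetical sweep definition; values and names are placeholders.
sweep_config = {
    "method": "grid",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "batch_size": {"values": [128, 256, 512, 1024]},
        "clip_factor": {"values": [0.01, 0.02, 0.08, 0.16]},
    },
}


def count_grid_runs(config):
    """Number of runs a grid sweep launches: product of all value-list sizes."""
    total = 1
    for spec in config["parameters"].values():
        total *= len(spec["values"])
    return total


num_runs = count_grid_runs(sweep_config)
print(f"grid sweep would launch {num_runs} runs")

# With the wandb package installed and a `train` function that reads its
# hyperparameters from wandb.config, the sweep could be started roughly like:
#   sweep_id = wandb.sweep(sweep_config, project="my-project")
#   wandb.agent(sweep_id, function=train)
```

A grid keeps the ablation easy to read in a parallel coordinates plot, at the cost of the number of runs growing multiplicatively with every added hyperparameter.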

📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.

- Batch size has a negative correlation with test accuracy. This is consistent with the observation that increasing the batch size might not lead to better test accuracy.
- On the other hand, the clipping factor has a positive correlation with test accuracy. Thus, in this configuration space, increasing the clipping factor should lead to higher test accuracy.
- However, it's hard to know exactly what to make of the relationship between the two hyperparameters. Some patterns I noticed in this experiment, which might not hold in all cases:
  - Small batch sizes require a bigger clipping factor.
  - Bigger batch sizes tend to work better with smaller clipping factors.
- Of the two hyperparameters, batch size matters more for reaching higher test accuracy.

The authors performed many ablation studies of their own to showcase the effectiveness of AGC on NF-ResNets. These are the key findings:

- The benefits of using AGC are smaller when the batch size is small.
- Smaller clipping thresholds are necessary for stability at higher batch sizes. This is consistent with what I saw in my experiments.
- It's better to not clip the final linear layer or the output layer.
- It's possible to train NF-ResNet without clipping the gradients for the initial convolutional layers.

(Source)

🎉 SOTA Warning: "Our NFNet-F1 model achieves comparable accuracy to an EffNet-B7 while being 8.7X faster to train. Our NFNet-F5 model has similar training latency to EffNet-B7, but achieves a state-of-the-art 86.0% top-1 accuracy on ImageNet."

📌 Note: NF-ResNets and NF-Nets are two different architectures.

Neural network architecture design depends on the choice of metric to optimize. These metrics can be:

- FLOPs count,
- Inference latency on a target device like edge/mobile devices,
- Training latency on an accelerator.

The authors of this paper decided to optimize training latency on existing accelerators. They explored the model space by manually searching through the trends that improved top-1 accuracy on ImageNet against actual training latency on the device. The aim was to maximize on both fronts.

The verdict? The authors achieved the state-of-the-art result of 86.5% top-1 test accuracy on ImageNet by training NFNet-F6 with the recently proposed Sharpness Aware Minimization (SAM) technique. The official implementation of the architecture is in the JAX framework. You can find the official GitHub repo here.

Additionally, Yannic Kilcher did an amazing explanation of NFNets in his YouTube video. Check it out here:

I have provided a Colab notebook to train an NF-Net model using PyTorch Lightning on the Caltech-101 dataset. The NF-Net implementation is based on the timm package. Feel free to change model variants and other hyperparameters.

The experiments were conducted on the CIFAR-10 dataset, which might not be quite enough for a technique to show its full effect. I limited my experimental setup to Colab Notebooks and Kaggle Kernels for ease of use.

I want to give shoutouts to these works which enabled me to compile this report with various code examples and experimental results:

- Adaptive Gradient Clipping in TensorFlow by Sayak Paul.
- Another notable implementation of normalizer-free networks: nfnets-pytorch
- Sharpness Aware Minimization in TensorFlow by Sayak Paul.

Finally, congrats to the authors on this amazing work.

I would also like to thank Sayak Paul and Morgan McGuire for their feedback, which helped me improve this report. Additional thanks to Justin Tenuto for his editorial magic.