MobileNet-V4 (now in timm)

Community Article Published June 17, 2024

History

Five years ago, the MobileNet-V3 (https://arxiv.org/abs/1905.02244) and EfficientNet (https://arxiv.org/abs/1905.11946) vision models were introduced by Google researchers. At a very early stage in timm's development, I set out to reproduce these model architectures and port the originally released TensorFlow model weights to PyTorch.

Both of these model architectures were based on the Inverted Residual (IR) block, also called the Inverted Bottleneck, that was introduced in the earlier MobileNet-V2 model. The IR consists of a 1x1 pointwise (PW) expansion convolution, followed by a depthwise (DW) convolution (3x3 or 5x5), and finally a 1x1 pointwise linear (PWL, no activation) convolution in the residual path. Unlike most other residual blocks at the time, the wide part is in the middle of the block at the depthwise conv, instead of at the start/end, and the output of the block has no activation (linear output), hence 'Inverted'.
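To make the "wide in the middle" shape concrete, here is a minimal sketch that lists the conv weight shapes of one IR block. The function name, the default expansion ratio of 4, and the example channel counts are illustrative assumptions, not values from timm or the papers; norm layers and biases are ignored.

```python
import math


def inverted_residual_shapes(in_ch, out_ch, expand_ratio=4, dw_kernel=3):
    """Return (name, weight shape, weight count) for each conv in an IR block.

    Hypothetical helper for illustration only. Shapes use the
    (out_channels, in_channels // groups, kH, kW) layout of PyTorch convs.
    """
    mid_ch = in_ch * expand_ratio  # the wide part sits in the middle of the block
    layers = [
        ("pw_expand", (mid_ch, in_ch, 1, 1)),       # 1x1 pointwise expansion (with activation)
        ("dw", (mid_ch, 1, dw_kernel, dw_kernel)),  # depthwise: one kxk filter per channel
        ("pw_linear", (out_ch, mid_ch, 1, 1)),      # 1x1 pointwise projection, no activation
    ]
    return [(name, shape, math.prod(shape)) for name, shape in layers]


if __name__ == "__main__":
    for name, shape, count in inverted_residual_shapes(32, 32, expand_ratio=4, dw_kernel=3):
        print(f"{name:10s} {shape} -> {count} weights")
```

Note how cheap the depthwise conv is relative to the two pointwise convs, even though it operates on the widest tensor in the block.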

To this day, timm remains the most comprehensive collection of known EfficientNet and MobileNet-V3 architectures. It covers all of the officially released TensorFlow weights from various model papers (EfficientNet, EfficientNet-EdgeTPU, EfficientNet-V2, MobileNet-V2, MobileNet-V3), training techniques (RandAug/AutoAug, AdvProp, Noisy Student), and numerous other closely related architectures and weights such as MNasNet, FBNet v1/v2/v3, LCNet, TinyNet, MixNet. There are also many weights trained in timm, pure PyTorch, with PyTorch friendly convolution padding (TF weight ports use a 'SAME' padding emulation) that aren't in other collections.

New Models

Now, it's finally time for MobileNet-V4 (https://arxiv.org/abs/2404.10518). Reading through the paper, it's apparent the goal was to come up with a new set of NAS-searched models that are runtime optimal on today's mobile/edge hardware, from small DSP/CPU devices to modest edge accelerators (e.g. EdgeTPU) in current mobile phones.

This goal was achieved by introducing two new block types to the previous mix:

  • Universal Inverted Bottleneck (UIB)
  • Multi Query Attention (MQA)

Universal Inverted Bottleneck

A superset of the original Inverted Residual / Inverted Bottleneck block, the UIB adds flexibility to the search space with 2 extra depthwise convolution positions, at the start and end of the block, and by making the middle depthwise convolution optional. The extra final convolution isn't used (yet), but the new block variants in use now include:

  • 'ExtraDW' with a 3x3 or 5x5 DW convolution to start the block, in front of the existing 1x1 PW + kxk DW + 1x1 PWL pattern
  • 'FFN' with no DW convs enabled and just the 1x1 PW expansion + linear convs
  • 'ConvNeXt' with 3x3 or 5x5 DW convolution to start the block, no middle DW convolution, so a 1x1 + 1x1 FFN to end
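The variants above boil down to which of the two optional depthwise convs is switched on. A minimal sketch of that mapping follows; the function names and the "kernel size 0 means disabled" convention are assumptions made for illustration, and the unused final depthwise position is left out.

```python
def uib_variant(start_dw_kernel, mid_dw_kernel):
    """Name the UIB variant from the two optional DW kernel sizes (0 = disabled).

    Hypothetical helper, not the timm API.
    """
    if start_dw_kernel and mid_dw_kernel:
        return "ExtraDW"
    if start_dw_kernel:
        return "ConvNeXt"
    if mid_dw_kernel:
        return "IB"  # the original Inverted Bottleneck / Inverted Residual
    return "FFN"


def uib_layers(start_dw_kernel, mid_dw_kernel):
    """List the conv sequence of a UIB block for the given DW kernel sizes."""
    layers = []
    if start_dw_kernel:
        layers.append(f"{start_dw_kernel}x{start_dw_kernel} DW")
    layers.append("1x1 PW expand")
    if mid_dw_kernel:
        layers.append(f"{mid_dw_kernel}x{mid_dw_kernel} DW")
    layers.append("1x1 PWL")
    return layers


if __name__ == "__main__":
    for start, mid in [(0, 3), (3, 5), (0, 0), (5, 0)]:
        print(uib_variant(start, mid), "->", " + ".join(uib_layers(start, mid)))
```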

[Figure from the MobileNet-V4 paper: https://arxiv.org/abs/2404.10518]

Mobile MQA

The 'Hybrid' variants of MobileNet-V4 also add attention via a mobile-optimized Multi Query Attention module. Neither the key nor the value is split into heads; only the query has multiple heads (4 or 8). There is optional 2D spatial downsampling for the key-value and/or query.
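As a rough illustration of why this is cheaper, the sketch below compares projection weight counts for standard multi-head attention against multi-query attention, and shows how spatial downsampling shrinks the attention matrix. The function names, the example dimensions (256 channels, 64-dim heads, 16x16 feature map), and the choice to ignore biases are all assumptions for illustration; this is not the timm module.

```python
def attention_projection_params(dim, num_heads, head_dim, multi_query):
    """Count projection weights (no biases) for multi-head vs multi-query attention."""
    kv_heads = 1 if multi_query else num_heads  # MQA: one key/value shared by all query heads
    q = dim * num_heads * head_dim
    k = dim * kv_heads * head_dim
    v = dim * kv_heads * head_dim
    out = num_heads * head_dim * dim
    return {"q": q, "k": k, "v": v, "out": out, "total": q + k + v + out}


def attention_matrix_size(height, width, num_heads, q_stride=1, kv_stride=1):
    """Number of attention scores for a 2D feature map with optional downsampling."""
    n_q = (height // q_stride) * (width // q_stride)
    n_kv = (height // kv_stride) * (width // kv_stride)
    return num_heads * n_q * n_kv


if __name__ == "__main__":
    print("MHA:", attention_projection_params(256, 4, 64, multi_query=False))
    print("MQA:", attention_projection_params(256, 4, 64, multi_query=True))
    print("scores, no downsampling:", attention_matrix_size(16, 16, 4))
    print("scores, kv_stride=2:    ", attention_matrix_size(16, 16, 4, kv_stride=2))
```

With these example numbers, downsampling the key-value by 2 in each spatial dimension cuts the number of attention scores by 4x.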

PyTorch Implementation

I've recently implemented these models in timm in a bid to keep timm the best place to go for efficient image encoders. The implementation builds on top of the previous MobileNet-V3 one. I have trained a number of initial weights and am working on covering all of the models mentioned in the paper: https://huggingface.co/collections/timm/mobilenetv4-pretrained-weights-6669c22cda4db4244def9637

And in case you look at that PR and wonder, WTH is EfficientNet-X / EfficientNet-H? They are little known variants of those models w/ Space2Depth, tweaked for TPU or GPU use. That's there too but not the focus.

There is an official TensorFlow implementation of these models in the TensorFlow Model Garden (https://github.com/tensorflow/models), but no sign of weights yet.

The tables below compare the paper's ImageNet-1k training results with timm's. Note that the parameter counts in the paper assume the normalization params are folded into the convs, while the timm values are for the training state.
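For readers unfamiliar with folding, here is a minimal scalar sketch of merging a batch norm into the preceding conv. It treats a single channel of a 1x1 conv as one multiplication; the function names and example values are made up for illustration. In the training state, a bias-free conv followed by batch norm carries two learnable norm params per channel (scale and shift), whereas the folded conv carries a single bias per channel, which is why the two ways of counting differ slightly.

```python
import math


def conv_then_bn(x, weight, gamma, beta, mean, var, eps=1e-5):
    """Single-channel 1x1 conv (no bias) followed by batch norm in eval mode."""
    y = weight * x
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(weight, gamma, beta, mean, var, eps=1e-5):
    """Fold the batch norm statistics and affine params into conv weight and bias."""
    scale = gamma / math.sqrt(var + eps)
    return weight * scale, beta - mean * scale


if __name__ == "__main__":
    params = dict(weight=0.5, gamma=1.5, beta=0.1, mean=0.2, var=0.9)
    folded_w, folded_b = fold_bn(**params)
    x = 2.0
    # Both paths should produce the same output for any input x.
    print(conv_then_bn(x, **params), folded_w * x + folded_b)
```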

timm:

| model | top1 | top1_err | top5 | top5_err | param_count (M) | img_size |
|---|---|---|---|---|---|---|
| mobilenetv4_hybrid_large.e600_r384_in1k | 84.266 | 15.734 | 96.936 | 3.064 | 37.76 | 448 |
| mobilenetv4_hybrid_large.e600_r384_in1k | 83.800 | 16.200 | 96.770 | 3.230 | 37.76 | 384 |
| mobilenetv4_conv_large.e600_r384_in1k | 83.392 | 16.608 | 96.622 | 3.378 | 32.59 | 448 |
| mobilenetv4_conv_large.e600_r384_in1k | 82.952 | 17.048 | 96.266 | 3.734 | 32.59 | 384 |
| mobilenetv4_conv_large.e500_r256_in1k | 82.674 | 17.326 | 96.31 | 3.69 | 32.59 | 320 |
| mobilenetv4_conv_large.e500_r256_in1k | 81.862 | 18.138 | 95.69 | 4.31 | 32.59 | 256 |
| mobilenetv4_hybrid_medium.e500_r224_in1k | 81.276 | 18.724 | 95.742 | 4.258 | 11.07 | 256 |
| mobilenetv4_conv_medium.e500_r256_in1k | 80.858 | 19.142 | 95.768 | 4.232 | 9.72 | 320 |
| mobilenetv4_hybrid_medium.e500_r224_in1k | 80.442 | 19.558 | 95.38 | 4.62 | 11.07 | 224 |
| mobilenetv4_conv_blur_medium.e500_r224_in1k | 80.142 | 19.858 | 95.298 | 4.702 | 9.72 | 256 |
| mobilenetv4_conv_medium.e500_r256_in1k | 79.928 | 20.072 | 95.184 | 4.816 | 9.72 | 256 |
| mobilenetv4_conv_medium.e500_r224_in1k | 79.808 | 20.192 | 95.186 | 4.814 | 9.72 | 256 |
| mobilenetv4_conv_blur_medium.e500_r224_in1k | 79.438 | 20.562 | 94.932 | 5.068 | 9.72 | 224 |
| mobilenetv4_conv_medium.e500_r224_in1k | 79.094 | 20.906 | 94.77 | 5.23 | 9.72 | 224 |
| mobilenetv4_conv_small.e2400_r224_in1k | 74.616 | 25.384 | 92.072 | 7.928 | 3.77 | 256 |
| mobilenetv4_conv_small.e1200_r224_in1k | 74.292 | 25.708 | 92.116 | 7.884 | 3.77 | 256 |
| mobilenetv4_conv_small.e2400_r224_in1k | 73.756 | 26.244 | 91.422 | 8.578 | 3.77 | 224 |
| mobilenetv4_conv_small.e1200_r224_in1k | 73.454 | 26.546 | 91.34 | 8.66 | 3.77 | 224 |
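
The same weights appear more than once above because they were evaluated at different img_size values. The small sketch below splits a model name into its parts. It assumes the tag after the dot encodes training epochs (e), training resolution (r), and dataset, which is my reading of the naming pattern in the table rather than something stated in the text; the function name is hypothetical.

```python
import re


def parse_model_name(name):
    """Split 'arch.eN_rM_dataset' into its parts (assumed naming convention)."""
    arch, _, tag = name.partition(".")
    match = re.fullmatch(r"e(\d+)_r(\d+)_(\w+)", tag)
    if match is None:
        raise ValueError(f"unrecognized pretrained tag: {tag!r}")
    return {
        "arch": arch,
        "epochs": int(match.group(1)),     # assumed: number of training epochs
        "train_res": int(match.group(2)),  # assumed: training image size
        "dataset": match.group(3),
    }


if __name__ == "__main__":
    print(parse_model_name("mobilenetv4_conv_large.e600_r384_in1k"))
```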

Paper: [image of the results table from the MobileNet-V4 paper]