author: | zayne-sprague |
score: | 7 / 10 |
The Big Idea
By splitting images into chunks and putting the chunks in a sequence, we can utilize transformers for image classification without heavy attention costs.
Why can’t we use transformers directly on images?
- Self-attention over raw pixels would require every pixel to attend to every other pixel, so the cost grows quadratically with the number of pixels. This is too expensive for even moderately sized images.
How can we avoid the large cost?
- The main intuition is that by slicing an image into chunks, we can create a sequence of small image patches to be used in the transformer.
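To make the cost argument concrete, here is a rough back-of-the-envelope sketch. The 224x224 image size and the helper name `attention_pairs` are assumptions picked for illustration; the sketch only counts query-key pairs in a single attention map and ignores heads, layers, and embedding width.

```python
def attention_pairs(num_tokens: int) -> int:
    """Number of query-key pairs in one full self-attention map."""
    return num_tokens * num_tokens


# Illustrative sizes (assumed): a 224x224 image cut into 16x16 patches.
height, width, patch = 224, 224, 16

pixel_tokens = height * width                        # one token per pixel
patch_tokens = (height // patch) * (width // patch)  # one token per patch

pixel_pairs = attention_pairs(pixel_tokens)
patch_pairs = attention_pairs(patch_tokens)

print(f"pixel tokens: {pixel_tokens:,} -> {pixel_pairs:,} attention pairs")
print(f"patch tokens: {patch_tokens:,} -> {patch_pairs:,} attention pairs")
print(f"reduction factor: {pixel_pairs // patch_pairs:,}x")
```

Because the cost is quadratic in sequence length, shrinking the sequence by a factor of P^2 shrinks the attention map by a factor of P^4.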
How is it realized?
We can cut an image into small chunks (16x16 pixels, for example) and feed each chunk to the transformer as one element of a sequence
Each chunk is flattened into a vector (pixels and channels together) and passed through a linear projection (the patch embedding)
Each patch embedding also receives a position encoding so the model knows where in the image the patch came from
From here, we are able to use the transformer much like BERT and others (as shown below)
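Before the transformer itself, the patch-to-token steps above can be sketched in code. This is a minimal, dependency-free sketch using plain Python lists, not the paper's implementation: the function names (`patchify`, `linear`, `embed_patches`), the toy sizes, and the random stand-in weights are all assumptions made for illustration. A real model would use an array library and learn the projection and position embeddings during training.

```python
import random


def patchify(image, patch):
    """Split an H x W x C nested-list image into flattened P*P*C patch vectors.

    Patches are emitted in row-major order (left to right, top to bottom).
    """
    height, width = len(image), len(image[0])
    if height % patch or width % patch:
        raise ValueError("image sides must be divisible by the patch size")
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            vector = []
            for row in range(top, top + patch):
                for col in range(left, left + patch):
                    vector.extend(image[row][col])  # append all C channel values
            patches.append(vector)
    return patches


def linear(vector, weights, bias):
    """Project one vector: weights has shape (len(vector), D), bias has length D."""
    return [
        sum(v * w_row[j] for v, w_row in zip(vector, weights)) + bias[j]
        for j in range(len(bias))
    ]


def embed_patches(image, patch, weights, bias, pos_embed):
    """Patch embedding plus an additive position embedding per patch index."""
    tokens = []
    for index, vector in enumerate(patchify(image, patch)):
        projected = linear(vector, weights, bias)
        tokens.append([p + e for p, e in zip(projected, pos_embed[index])])
    return tokens


# Toy sizes chosen for illustration only (assumed, not from the paper).
H, W, C, P, D = 8, 8, 3, 4, 6
rng = random.Random(0)
image = [[[rng.random() for _ in range(C)] for _ in range(W)] for _ in range(H)]
num_patches = (H // P) * (W // P)

# Randomly initialised stand-ins for parameters that would normally be learned.
weights = [[rng.gauss(0, 0.02) for _ in range(D)] for _ in range(P * P * C)]
bias = [0.0] * D
pos_embed = [[rng.gauss(0, 0.02) for _ in range(D)] for _ in range(num_patches)]

tokens = embed_patches(image, P, weights, bias, pos_embed)
print(len(tokens), len(tokens[0]))  # sequence length N, embedding size D
```

The resulting list of N vectors plays the same role as a sequence of word embeddings in a text transformer.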
Math on Image Patches
Image Definition
- \(x \in \mathbb{R}^{H \times W \times C}\), where H = height, W = width, C = number of channels of the input images, and \(x\) is a sampled image
Patch Definition
- \(x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\), where P = side length of each square patch in pixels and N = \(HW/P^2\) is the number of patches (it depends on both the patch size and the size of the image)
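As a quick worked example of these definitions, assume a 224 x 224 RGB image (H = W = 224, C = 3) and the 16 x 16 patch size mentioned earlier (P = 16). The image size is an assumption picked for illustration.

\[
N = \frac{HW}{P^2} = \frac{224 \cdot 224}{16^2} = 196, \qquad P^2 \cdot C = 16^2 \cdot 3 = 768, \qquad x_p \in \mathbb{R}^{196 \times 768}
\]

So the transformer sees a sequence of 196 patch vectors, each with 768 raw values, before the linear projection is applied.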
Does it work?
Yep! With sufficiently large models (and, as discussed further down, enough training data), ViT outperforms CNN-based models significantly
Multiple variants were tested, mostly varying the size of the model to be comparable with others (BiT) but also to see how good the model could get.
The tables below show comparisons of classification accuracies on various benchmarks across models, as well as outline the other variants tested.
At 2% of the training time, ViT-L achieves performance competitive with BiT
At 7%, ViT-L outperforms BiT-L in nearly every category
Note: JFT-300M is a dataset Google made with 300 million images. (This was, and I think still is, privately held)
Why does it do so well?
- ViT learns the basic CNN building blocks, but better!
Far left is the linear projection layer; its learned filters look like CNN convolution kernels
Middle are the position embeddings; not only do they have high weight at their corresponding location, they also relate to nearby patches
Far right, the Transformer learns to utilize large receptive fields (like CNNs do) as the layers increase
So why not use ViT everywhere, where’s the catch?
It comes down to the inductive bias of CNNs vs Transformers
CNNs are built to find structural information (their inductive bias), whereas a transformer does not have this “intuition”
This means a transformer must be trained to pick up on structural information within images, which in turn means ViT fails to outperform CNNs until it receives sufficient data (it is data hungry and requires large datasets to train)
Smaller datasets can actually cause the transformer to overfit, so it might be worth using a CNN in this case.
Why not combine CNNs & Transformers (other variants)?
You can process images with a CNN and get a feature map rather than slicing images into patches. These models were called Hybrids in the paper.
Once you have the feature maps, you flatten them into feature vectors and apply the patch embedding to them (everything after this is the same)
For smaller models with less training time, hybrid models outperform Transformers and ResNets, but as the models get larger, that advantage fades and pure Transformers match or outperform them.
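The flattening step for hybrids can be sketched as follows. This is a minimal illustration under stated assumptions: the CNN is not modelled at all, `feature_map` is a hypothetical stand-in for its output, and each spatial location is treated as one sequence element (the simplest case). The helper name `feature_map_to_sequence` is made up for this example.

```python
def feature_map_to_sequence(feature_map):
    """Flatten an h x w x d feature map into h*w vectors of length d (row-major)."""
    return [list(cell) for row in feature_map for cell in row]


# Hypothetical CNN output: a 3 x 3 spatial grid with 4 channels per location.
# Each location is filled with its own row-major index so the order is visible.
h, w, d = 3, 3, 4
feature_map = [[[float(r * w + c)] * d for c in range(w)] for r in range(h)]

sequence = feature_map_to_sequence(feature_map)
print(len(sequence), len(sequence[0]))  # sequence length h*w, vector size d
```

Each of the h*w vectors would then go through the same linear projection and position encoding as an ordinary image patch.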
- Splitting images into chunks and then embedding them as vectors allows us to use the Transformer Architecture (ViT)
- ViT learns kernel-like filters and CNN-like receptive fields after sufficient training, and can outperform CNN baselines
- CNNs are still better to use for small datasets