Perceiver: General Perception with Iterative Attention, Jaegle, Gimeno, Brock, Zisserman, Vinyals, Carreira; 2021 - Summary
author: ishank-arora
score: 9 / 10
What is the core idea?
- Inductive biases are baked into models; while they increase efficiency, they restrict a model to one particular modality.
- Enter Perceiver, a model designed to handle arbitrary configurations of different modalities.
- Perceiver is built upon Transformers
- Introducing a small set of latent units creates an attention bottleneck through which the inputs must pass.
- Eliminates the quadratic scaling problem of the classical Transformer
- Decouples network depth from the input size
- Allows scaling to hundreds of thousands of inputs, like ConvNets
- Employs asymmetric attention mechanism to iteratively distill inputs
- Hence, can scale to handle very large inputs
- Model is competitive or outperforms strong specialised models across various modalities
How is it realized (technically)?
- Architecture
- Cross-attention module - maps a byte array and a latent array to a latent array
- A byte array can be something like a pixel array, etc. (the input)
- Generally of large size M
- The size of the latent array is typically much smaller and is a hyperparameter - size N
- So, $N \ll M$.
- Transformer tower maps a latent array to a latent array
- In essence, the cross-attention module projects a high-dimensional input byte array to a fixed-dimensional latent bottleneck
- Then processes it using a deep stack of Transformer-style self-attention blocks in the latent space
- Cross attention and the Transformer are applied in alternation
- Allows the Perceiver to iteratively attend to the input byte array by alternating between cross-attention and the latent self-attention blocks.
- Asymmetric attention
- Authors use the query-key-value (QKV) attention
- Problem with QKV attention in Transformers: the cost of self-attention is quadratic in the input size
- In standard QKV attention, $Q \in \mathbb{R}^{M \times D}$, $K \in \mathbb{R}^{M \times C}$, and $V \in \mathbb{R}^{M \times C}$, where $C$ and $D$ are channel dimensions and $M$ is the input size.
- So $\mathrm{softmax}(QK^T)V$ is $O(M^2)$
- Introduce asymmetry
- $Q$ is a projection of the learned latent array with index dimension $N \ll M$, where $N$ is a hyperparameter.
- $K$ and $V$ are projections of the input byte array
- Resulting cross-attention operation is $O(MN)$
- The output of the cross-attention module takes the shape of the query $Q$ (index dimension $N$)
- Induces a bottleneck
- which the authors exploit by building deep Transformers in the latent space - at a low cost of $O(N^2)$ per layer.
- Hence the latent Transformer has complexity $O(LN^2)$ when considered as a function of the number of layers $L$.
- Notice that a standard Transformer would be $O(LM^2)$
- Overall architecture - $O(MN + LN^2)$ (see the cross-attention sketch after this list)
- Input size (M) and depth (L) are decoupled
- Allows construction of very large networks on large-scale data
- Latent transformer uses GPT-2 architecture
- based on the decoder of the original Transformer
- Iterative cross-attention
- Severity of the bottleneck may restrict the network’s ability to capture all of the necessary details from the input signal
- Hedge against this by adding multiple cross-attend layers
- Allows latent array to iteratively extract information
- Position encodings
- Permutation invariance and position information
- Attention is a permutation-invariant operation which is preserved by the Perceiver model.
- However, permutation invariance does not allow the model to exploit spatial relationships in the input data.
- Hence, they also include position encodings
- Scalable Fourier Features
- Authors use a parameterization of Fourier features
- allows direct representation of the position structure of the input data
- allows control over the number of frequency bands in the position encoding independently of the cutoff frequency
- uniformly sample all frequencies up to a target resolution.
- Frequencies are parameterized to take the values $[\sin(f_k \pi x_d), \cos(f_k \pi x_d)]$,
- where $f_k$ is the $k$-th band of a bank of frequencies spaced equally between $1$ and $\frac{\mu}{2}$. $\frac{\mu}{2}$ can be naturally interpreted as the Nyquist frequency, where $\mu$ is the target sampling rate.
- $x_d$ is the value of the input position along the $d$-th dimension (for images there are $d = 2$ dimensions), with $x_d \in [-1, 1]$ (see the Fourier-feature sketch after this list).
- related to the NeRF position encoding scheme (Mildenhall et al., 2020).
- Position encodings are generally applicable.
- Feature based approach allows the network to learn how to use (or ignore) the position structure
- Position encodings can be easily adapted to a new domain: Fourier features are trivial to construct as long as the input dimensionality is relatively small and known
- Position encodings can be extended to multimodal data - each domain gets its own position encoding.
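To make the asymmetric attention and the $O(MN + LN^2)$ scaling above concrete, here is a minimal single-head sketch of the cross-attend / latent-Transformer alternation. The sizes, layer counts, and raw weight matrices are illustrative assumptions, not the paper's hyperparameters; a real implementation adds multi-head attention, layer norms, MLPs, and residual connections.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (not the paper's hyperparameters): byte array M >> latent size N.
M, N, C, D = 50_176, 512, 64, 64      # M = 224 * 224 flattened "pixels"
num_cross, L_self = 4, 6              # cross-attends and latent layers per cross-attend

byte_array = torch.randn(M, C)        # input bytes, e.g. pixels + Fourier position features
latents = torch.randn(N, D)           # latent array (a learned parameter in a real model)

# Hypothetical projection weights; a real model uses per-layer nn.Linear modules.
Wq_x, Wk_x, Wv_x = torch.randn(D, D), torch.randn(C, D), torch.randn(C, D)  # cross-attention
Wq_l, Wk_l, Wv_l = torch.randn(D, D), torch.randn(D, D), torch.randn(D, D)  # latent self-attention

def attend(q_in, kv_in, Wq, Wk, Wv):
    """softmax(Q K^T / sqrt(D)) V -- cost scales with rows(Q) * rows(K)."""
    Q, K, V = q_in @ Wq, kv_in @ Wk, kv_in @ Wv
    scores = Q @ K.T / (Q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ V

for _ in range(num_cross):
    # Cross-attention: queries from the N latents, keys/values from the M inputs -> O(MN).
    latents = attend(latents, byte_array, Wq_x, Wk_x, Wv_x)
    # Latent Transformer: self-attention over the N latents only -> O(N^2) per layer.
    for _ in range(L_self):
        latents = attend(latents, latents, Wq_l, Wk_l, Wv_l)

print(latents.shape)  # torch.Size([512, 64]) -- independent of the input size M
```

With these illustrative numbers, each cross-attend computes roughly $MN \approx 2.6 \times 10^7$ attention scores and each latent layer only $N^2 \approx 2.6 \times 10^5$, whereas a single self-attention layer over the raw bytes would need $M^2 \approx 2.5 \times 10^9$.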
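And here is a rough sketch of the scalable Fourier-feature position encoding for a 2D image, following the $[\sin(f_k \pi x_d), \cos(f_k \pi x_d)]$ form above. The band count, image size, and the choice to also keep the raw positions are illustrative assumptions rather than the paper's exact configuration.

```python
import math
import torch

def fourier_position_encoding(height, width, num_bands, max_freq):
    # Positions x_d in [-1, 1] along each of the two spatial dimensions.
    ys = torch.linspace(-1.0, 1.0, height)
    xs = torch.linspace(-1.0, 1.0, width)
    pos = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)   # (H, W, 2)

    # f_k: num_bands frequencies spaced equally between 1 and mu/2, where
    # mu/2 plays the role of the Nyquist frequency for the target sampling rate mu.
    freqs = torch.linspace(1.0, max_freq / 2.0, num_bands)             # (K,)

    # [sin(f_k * pi * x_d), cos(f_k * pi * x_d)] for every position, dimension, and band.
    angles = pos.unsqueeze(-1) * freqs * math.pi                       # (H, W, 2, K)
    features = torch.cat([angles.sin(), angles.cos()], dim=-1)         # (H, W, 2, 2K)

    # Keep the raw position as well, then flatten per pixel so the encoding can be
    # concatenated channel-wise with the byte array.
    features = torch.cat([pos.unsqueeze(-1), features], dim=-1)        # (H, W, 2, 2K + 1)
    return features.reshape(height, width, -1)

pe = fourier_position_encoding(224, 224, num_bands=64, max_freq=224)
print(pe.shape)  # torch.Size([224, 224, 258]) = 2 * (2 * 64 + 1) channels per pixel
```

Each of the $H \cdot W = M$ positions then contributes one feature vector that would be concatenated with the corresponding input channels before the first cross-attend.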
How well does the paper perform?
- The Perceiver performs competitively with strong specialised models across modalities.
- ImageNet: top-1 validation accuracy (in %). Compares the baseline from the literature, the baseline with Fourier features, and a Transformer with Fourier features against the Perceiver. The Perceiver performs as well as the baselines.
- Permuted ImageNet: top-1 validation accuracy (in %). Shows the Perceiver's resilience to permuted inputs, thanks to its global attention. Input RF refers to the input receptive field in pixels - the size of the local neighbourhood used for 2D convolutions.
- Audio (AudioSet): very comparable results. In future work, the authors hope to incorporate the insights used in CNN-14 (Kong et al., 2020).
- ModelNet40: top-1 validation accuracy (in %). PointNet++ uses extra geometric features and augmentation techniques not used by the other models in the comparison.
TL;DR
- Uses iterative asymmetric cross-attention to decouple network depth from input size, allowing for large networks
- Motivated to be a generalisable architecture - and thus is a multimodal architecture
- Performs competitively against specialised models with inductive biases baked in.