A simple implementation of the attention mechanism in JAX
I recently started learning JAX. Although I've worked as an ML Engineer for some time now, my experience has been mainly with PyTorch and TensorFlow, so in my free time I decided to dive into JAX. What better way to understand it than by implementing concepts I already know well? In this tutorial, I walk you through a simple implementation of the attention mechanism, a concept I've studied and used over the years. You'll find detailed explanations of both single-head and multi-head attention using JAX and Flax, along with a benchmark of JIT compilation. There may be mistakes and room for improvement, so consider this me learning alongside you.
1. Single-Head Attention
In this section, we create a single-head attention module. The module first uses dense (linear) layers to project the input encodings into three representations: queries, keys, and values. It then computes the dot product between every query and every key (after swapping the key axes so the dimensions line up) to measure similarity. These similarity scores are divided by the square root of the key dimension so they do not grow too large during training, and an optional mask can exclude certain positions. Finally, a softmax converts the scores into attention weights, which are used to take a weighted sum of the value vectors.
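Written out, this is the standard scaled dot-product attention formula. In the notation below, $X_q, X_k, X_v$ stand for `encodings_for_q`, `encodings_for_k`, and `encodings_for_v`; $W_q, W_k, W_v$ for the kernels of the three dense layers; $d_k$ for the key dimension (`d_model` in this code); and $M$ for the optional mask, whose true entries mark positions to ignore:

$$
Q = X_q W_q, \qquad K = X_k W_k, \qquad V = X_v W_v
$$

$$
S_{ij} =
\begin{cases}
-10^{9} & \text{if } M_{ij} \text{ is true},\\[4pt]
\dfrac{(Q K^{\top})_{ij}}{\sqrt{d_k}} & \text{otherwise},
\end{cases}
\qquad
\operatorname{Attention}(X_q, X_k, X_v) = \operatorname{softmax}(S)\, V
$$

The softmax is taken over each row of $S$, that is, over the keys for a given query.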
```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class Attention(nn.Module):
    d_model: int = 2
    row_dim: int = 0
    col_dim: int = 1

    @nn.compact
    def __call__(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # Create dense layers (without bias) to generate queries, keys, and values.
        W_q = nn.Dense(features=self.d_model, use_bias=False, name="W_q")
        W_k = nn.Dense(features=self.d_model, use_bias=False, name="W_k")
        W_v = nn.Dense(features=self.d_model, use_bias=False, name="W_v")

        # Project the input encodings into queries, keys, and values.
        q = W_q(encodings_for_q)
        k = W_k(encodings_for_k)
        v = W_v(encodings_for_v)

        # Swap axes of the key tensor to align dimensions for matrix multiplication.
        k_t = jnp.swapaxes(k, self.row_dim, self.col_dim)

        # Compute dot products between queries and the transposed keys.
        sims = jnp.matmul(q, k_t)

        # Scale the similarity scores by the square root of the key dimension.
        scale = jnp.sqrt(k.shape[self.col_dim])
        scaled_sims = sims / scale

        # If a mask is provided, apply it to ignore specified positions.
        if mask is not None:
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)

        # Apply softmax to convert similarity scores into attention probabilities.
        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)

        # Use the attention weights to compute a weighted sum of the value vectors.
        attention_scores = jnp.matmul(attention_percents, v)
        return attention_scores
```
- Dense Layers for Projections: Three bias-free dense layers compute the queries (`q`), keys (`k`), and values (`v`) from the input encodings. Each projection maps the input into its own learned subspace, so the same token can play different roles as a query, a key, and a value.
- Axis Swapping for Keys: Swapping `row_dim` and `col_dim` transposes the key matrix from (tokens, d_model) to (d_model, tokens), so `jnp.matmul(q, k_t)` yields a (num_queries, num_keys) matrix holding the dot product of every query with every key.
- Scaling: The variance of a dot product grows with the key dimension. Dividing by its square root keeps the scores in a range where the softmax does not saturate, which helps keep gradients from vanishing during training.
- Masking: If certain positions should be ignored (for example, padding tokens), the optional mask replaces their scores (wherever the mask is `True`) with -1e9, so their softmax weights become effectively zero.
- Softmax and Weighted Sum: The softmax turns each row of scaled scores into a probability distribution over the keys. These probabilities weight the value vectors, so each output row is a mixture of values dominated by the most relevant parts of the input.
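The masking bullet is easiest to see with a causal mask, where each token may only attend to itself and earlier tokens. Below is a minimal sketch in plain JAX, without Flax; `scaled_dot_product_attention` is a hypothetical helper that mirrors the steps inside `Attention.__call__` (for 2-D inputs) and uses the same convention that `True` marks a position to ignore. Random arrays stand in for already-projected queries, keys, and values.

```python
import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, mask=None):
    """Hypothetical functional version of the module's computation (2-D inputs)."""
    # (n_q, d) @ (d, n_k) -> (n_q, n_k) similarity matrix
    sims = q @ k.T
    scaled = sims / jnp.sqrt(k.shape[-1])
    if mask is not None:
        # Same convention as the module: True marks a position to ignore
        scaled = jnp.where(mask, -1e9, scaled)
    weights = jax.nn.softmax(scaled, axis=-1)
    return weights @ v, weights


n_tokens, d = 3, 2
q = jax.random.normal(jax.random.PRNGKey(0), (n_tokens, d))
k = jax.random.normal(jax.random.PRNGKey(1), (n_tokens, d))
v = jax.random.normal(jax.random.PRNGKey(2), (n_tokens, d))

# Causal mask: True above the diagonal, so token i cannot see tokens j > i
causal_mask = jnp.triu(jnp.ones((n_tokens, n_tokens), dtype=bool), k=1)

out, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(weights)  # upper triangle is zero and each row sums to 1
```

Because the first token can only see itself, its weight row is `[1, 0, 0]` and its output equals `v[0]`. The same boolean mask can be passed as `mask=` to the Flax `Attention` module.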
2. Multi-Head Attention
Multi-head attention enhances the model's ability to capture different patterns by running several attention heads in parallel. Each head has its own projection weights and processes the same input independently; their outputs are then concatenated along the feature dimension. This lets the model jointly attend to information from different representation subspaces.
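In symbols, for the module that follows, head $i$ has its own projection matrices $W_q^{(i)}, W_k^{(i)}, W_v^{(i)}$ (with $X_q, X_k, X_v$ the input encodings and $d_k$ the key dimension), and the $h$ heads are simply concatenated along the feature axis:

$$
\operatorname{head}_i = \operatorname{softmax}\!\left(\frac{(X_q W_q^{(i)})\,(X_k W_k^{(i)})^{\top}}{\sqrt{d_k}}\right) X_v W_v^{(i)},
\qquad
\operatorname{MultiHead}(X_q, X_k, X_v) = \big[\operatorname{head}_1 \;\|\; \cdots \;\|\; \operatorname{head}_h\big]
$$

In this implementation each head keeps the full `d_model` width and there is no output projection, so for $n$ tokens the result has shape $(n, h \cdot d_{\text{model}})$. The original Transformer formulation differs on both points: each head uses $d_{\text{model}} / h$ features, and the concatenated heads pass through a final linear layer.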
```python
class MultiHeadAttention(nn.Module):
    d_model: int = 2
    row_dim: int = 0
    col_dim: int = 1
    num_heads: int = 1

    def setup(self):
        # Initialize a list of attention heads.
        self.heads = [Attention(d_model=self.d_model,
                                row_dim=self.row_dim,
                                col_dim=self.col_dim)
                      for _ in range(self.num_heads)]

    def __call__(self, encodings_for_q, encodings_for_k, encodings_for_v):
        # Run each attention head independently and collect their outputs.
        head_outputs = [head(encodings_for_q, encodings_for_k, encodings_for_v)
                        for head in self.heads]
        # Concatenate the outputs along the specified dimension.
        return jnp.concatenate(head_outputs, axis=self.col_dim)
```
- Initialization with `setup`: The `setup` method creates `num_heads` separate instances of the single-head `Attention` module. Each instance is its own submodule (Flax names them `heads_0`, `heads_1`, and so on), so each head has its own set of parameters.
- Independent Processing: Each head processes the same input encodings with its own weights, so different heads can capture different aspects of the input.
- Concatenation: Once all heads have produced their outputs, they are concatenated along `col_dim`, the feature axis. With `num_heads` heads of width `d_model`, each token ends up with `num_heads * d_model` features that combine what every head extracted.
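Because every head runs the same computation with different weights, the Python list of heads could in principle be replaced by `jax.vmap` over stacked weight matrices. The sketch below is a plain-JAX illustration, not the Flax module above: `single_head` and the stacked `w_q`/`w_k`/`w_v` arrays are hypothetical names, the weights are random, and it checks the vmapped result against an explicit per-head loop like the one `MultiHeadAttention` uses.

```python
import jax
import jax.numpy as jnp


def single_head(x_q, x_k, x_v, w_q, w_k, w_v):
    # One head: project, score, scale, softmax, and mix the values (no mask)
    q, k, v = x_q @ w_q, x_k @ w_k, x_v @ w_v
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(k.shape[-1]), axis=-1)
    return weights @ v


num_heads, n_tokens, d_model = 2, 3, 2
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (n_tokens, d_model))
# Hypothetical stacked weights: one (d_model, d_model) matrix per head
w_q = jax.random.normal(keys[1], (num_heads, d_model, d_model))
w_k = jax.random.normal(keys[2], (num_heads, d_model, d_model))
w_v = jax.random.normal(keys[3], (num_heads, d_model, d_model))

# vmap over the leading (head) axis of the weights; the inputs are shared by all heads
all_heads = jax.vmap(single_head, in_axes=(None, None, None, 0, 0, 0))(
    x, x, x, w_q, w_k, w_v
)  # shape (num_heads, n_tokens, d_model)

# Match the module's layout: head 0 features, then head 1 features, per token
batched = jnp.transpose(all_heads, (1, 0, 2)).reshape(n_tokens, num_heads * d_model)

# Reference: loop over heads and concatenate, as MultiHeadAttention does
looped = jnp.concatenate(
    [single_head(x, x, x, w_q[h], w_k[h], w_v[h]) for h in range(num_heads)], axis=-1
)
print(batched.shape)  # (3, 4)
```

The vmapped form traces one head computation instead of `num_heads` copies, which can matter once the number of heads grows; the list-based module is easier to read.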
3. Testing the Attention Modules
This section demonstrates how to test both the single-head and multi-head attention modules. We define sample token encodings as JAX arrays, initialize the modules using a random key (essential for Flax parameter initialization), and apply the modules to compute the attention outputs.
```python
# Sample token encodings (3 tokens, each with 2 features)
encodings_for_q = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])
encodings_for_k = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])
encodings_for_v = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])

# Create a random key for parameter initialization.
key = jax.random.PRNGKey(42)

# --- Single-Head Attention Test ---
attention_module = Attention(d_model=2, row_dim=0, col_dim=1)
params = attention_module.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
single_head_output = attention_module.apply(params, encodings_for_q, encodings_for_k, encodings_for_v)
print("Single-head attention output:")
print(single_head_output)

# --- Multi-Head Attention Test (1 head) ---
multi_head_module_1 = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=1)
params_multi1 = multi_head_module_1.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
multi_head_output_1 = multi_head_module_1.apply(params_multi1, encodings_for_q, encodings_for_k, encodings_for_v)
print("Multi-head attention (1 head) output:")
print(multi_head_output_1)

# --- Multi-Head Attention Test (2 heads) ---
multi_head_module_2 = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=2)
params_multi2 = multi_head_module_2.init(key, encodings_for_q, encodings_for_k, encodings_for_v)
multi_head_output_2 = multi_head_module_2.apply(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v)
print("Multi-head attention (2 heads) output:")
print(multi_head_output_2)
```
```text
Single-head attention output:
[[1.668201   2.6169908 ]
 [2.433429   3.3817132 ]
 [0.51508707 1.4933776 ]]
Multi-head attention (1 head) output:
[[-0.7741511  -0.24243875]
 [-1.3947037   0.28557885]
 [-0.08808593 -0.9197984 ]]
Multi-head attention (2 heads) output:
[[-0.7741511  -0.24243875  2.0704143  -2.0301726 ]
 [-1.3947037   0.28557885  0.04033631 -0.86105233]
 [-0.08808593 -0.9197984   3.9204044  -3.142049  ]]
```
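- Sample Data: The arrays represent three tokens with two features each. Queries, keys, and values all come from the same encodings here, which is the self-attention setting. This simplified setup is enough to verify that the attention modules run as intended.
- Random Key for Initialization: In Flax, a random key is needed to initialize model parameters. Using a fixed key makes the results reproducible.
- Module Application: Each module is first initialized with `init` and then applied with `apply` to produce attention outputs. This confirms that the modules run end to end and return the expected shapes: (3, 2) for one head and (3, 4) for two heads. It does not by itself check numerical correctness.

The single-head output differs from the one-head `MultiHeadAttention` output even though both were initialized with the same key. Flax derives each parameter's initialization from the module's path, and inside `MultiHeadAttention` the head's parameters live under `heads_0` rather than at the top level. For the same reason, the first two columns of the two-head output match the one-head output exactly.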
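JAX has no global random state: every random call takes an explicit key, and the same key always produces the same numbers. That is why the fixed `PRNGKey(42)` above gives reproducible parameters. A minimal sketch of this behaviour, independent of Flax:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Same key -> identical "random" numbers, which is what makes init reproducible
a = jax.random.normal(key, (2, 2))
b = jax.random.normal(key, (2, 2))

# Splitting derives new, independent keys; use a fresh subkey for each consumer
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (2, 2))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(a, c))  # False
```

To get a different initialization for the modules above, pass a different key, for example a subkey from `jax.random.split`, to `init`.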
4. Benchmarking & Speed-up with JIT
One of JAX's major strengths is just-in-time (JIT) compilation. `jax.jit` traces a Python function for a given set of input shapes and dtypes and compiles it with XLA into optimized code for the CPU, GPU, or TPU. This section benchmarks the two-head multi-head attention module, comparing its execution time with and without JIT compilation.
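Tracing happens once per input signature: the compiled function is cached and reused until it sees a new shape or dtype. A small sketch that makes this visible by counting traces with a Python side effect, which only runs while JAX traces the function (`trace_count` and `attention_scores` are illustrative names):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the Python function


def attention_scores(q, k):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jax.nn.softmax(q @ k.T / jnp.sqrt(k.shape[-1]), axis=-1)


fast_scores = jax.jit(attention_scores)

x = jnp.ones((3, 2))
fast_scores(x, x)  # first call: trace and compile
fast_scores(x, x)  # same shapes and dtypes: reuses the compiled version
print(trace_count)  # 1

y = jnp.ones((5, 2))
fast_scores(y, y)  # new input shape: traced and compiled again
print(trace_count)  # 2
```

This is also why a warm-up call is needed before timing a jitted function: the first call pays for tracing and compilation.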
```python
import time


# Define a function to repeatedly run multi-head attention (non-JIT) for benchmarking.
def run_multi_head(params, module, iterations=1000):
    for _ in range(iterations):
        out = module.apply(params, encodings_for_q, encodings_for_k, encodings_for_v)
    # JAX dispatches work asynchronously; wait for the last result to finish.
    out.block_until_ready()


# Create a JIT-compiled version of the multi-head attention call for 2 heads.
jit_multi_head = jax.jit(lambda params, q, k, v: multi_head_module_2.apply(params, q, k, v))

# Warm-up call to trigger the JIT compilation process.
jit_multi_head(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v).block_until_ready()

# Benchmark the non-JIT version.
start = time.perf_counter()
run_multi_head(params_multi2, multi_head_module_2, iterations=1000)
end = time.perf_counter()
print("Non-JIT execution time: {:.6f} seconds".format(end - start))

# Benchmark the JIT version (after warm-up).
start = time.perf_counter()
for _ in range(1000):
    out = jit_multi_head(params_multi2, encodings_for_q, encodings_for_k, encodings_for_v)
out.block_until_ready()  # include the time for the last call to actually finish
end = time.perf_counter()
print("JIT execution time: {:.6f} seconds".format(end - start))
```
```text
Non-JIT execution time: 25.080998 seconds
JIT execution time: 0.020293 seconds
```
- Warm-Up Phase: The first call to a JIT-compiled function includes tracing and compilation overhead. Making that call before starting the timer ensures the measurement reflects execution time only, not compilation time.
- Repeated Execution: The multi-head attention module runs 1,000 times so that the total time is large enough to measure reliably.
- Timing Measurement: Using Python's `time.perf_counter()`, we measure and compare the wall-clock time of the non-JIT and JIT-compiled runs. In the run above, the JIT version was roughly 1,200 times faster. Without JIT, every `apply` call re-runs Flax's module setup in Python and dispatches each operation separately; the jitted function instead executes as a single compiled XLA computation.
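The warm-up and timing steps above can be folded into a small reusable helper. This is a sketch under stated assumptions: `time_calls` is a hypothetical name, the function being timed is assumed to return a single JAX array, and a plain-JAX `attention` function stands in for the Flax module. It calls `block_until_ready()` because JAX dispatches work asynchronously, so a Python timer can otherwise stop before the device has finished.

```python
import time

import jax
import jax.numpy as jnp


def time_calls(fn, *args, iterations=1000):
    """Average seconds per call of fn(*args), excluding the first (warm-up) call."""
    # Warm-up: for jitted functions, this triggers tracing and compilation
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(iterations):
        out = fn(*args)
    # JAX dispatches asynchronously; wait for the last result before stopping the clock
    out.block_until_ready()
    return (time.perf_counter() - start) / iterations


def attention(q, k, v):
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(k.shape[-1]), axis=-1)
    return weights @ v


x = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
eager = time_calls(attention, x, x, x, iterations=100)
jitted = time_calls(jax.jit(attention), x, x, x, iterations=100)
print(f"eager: {eager * 1e6:.1f} us/call, jit: {jitted * 1e6:.1f} us/call")
```

The helper reports per-call averages rather than totals; absolute numbers depend on hardware and backend.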
Conclusion
In this tutorial, we explored the implementation of attention mechanisms using JAX and Flax. Moving from single-head attention to multi-head attention, we discussed each component in detail and demonstrated how to speed up execution with JIT compilation. As someone who has predominantly worked with PyTorch and TensorFlow, I found this exercise a valuable opportunity to learn JAX in a practical, hands-on way.
I hope this tutorial has been helpful and clear, even if there might be areas for improvement. Remember, learning is an ongoing process, and I'm excited to continue refining these concepts and implementations. Thank you for joining me on this learning journey—let’s keep exploring and building together!