Introduction
Einops is a powerful tool for tensor operations, including rearrange, reduce, repeat, and more. Its most impressive feature is how much more readable and reliable it makes tensor code. It also provides nn.Module layers that plug directly into PyTorch models to modify tensor shapes.
In this tutorial, we’ll use the MNIST dataset for demonstration purposes.
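The examples below assume a batch of four 28×28 images stored in a tensor named data. One possible way to prepare such a batch (a minimal sketch assuming torchvision is installed; the loading code itself is not the focus of this tutorial):
import torch
from torchvision import datasets, transforms

mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())
# Stack four images into a single tensor of shape (4, 28, 28)
data = torch.stack([mnist[i][0][0] for i in range(4)])
print(data.shape)
>> torch.Size([4, 28, 28])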
Installation
To install the stable version of Einops, use pip:
pip install einops
For the latest version, you can install it directly from the GitHub repository:
pip install git+https://github.com/arogozhnikov/einops
Operations - Rearrange
Changing Axes
The rearrange function allows you to change the shape of a tensor. For instance, treating an image as a matrix, you can perform a transpose: if the original shape is (h, w), transposing changes it to (w, h). With plain PyTorch, you can write:
# Shape of data: (h, w)
transpose_data = data.transpose(0, 1)
print(transpose_data.shape)
>> torch.Size([w, h])
While this code is straightforward, it doesn’t clearly indicate the shape transformations. Using Einops, the same operation is more intuitive:
from einops import rearrange
data = rearrange(data, "h w -> w h")
Here, the transformation is described explicitly: "h w -> w h". This syntax is clearer and easier to understand.
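Both forms produce the same result, which is easy to verify on a single image taken from the MNIST batch (a quick check, not part of the original example):
img = data[0]  # one (28, 28) image
print(torch.equal(rearrange(img, "h w -> w h"), img.transpose(0, 1)))
>> True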
Flattening
Flattening combines multiple axes into a single dimension. For example, consider a tensor with shape (batch_size, height, width):
# Shape of data: (batch_size, height, width)
# Shape of stack_data: (batch_size * height, width)
stack_data = rearrange(data, "b h w -> (b h) w")
You can also rearrange dimensions in a non-sequential order:
stack_data = rearrange(data, "b h w -> h (b w)")
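With the (4, 28, 28) MNIST batch, the two patterns produce the following shapes:
print(rearrange(data, "b h w -> (b h) w").shape)
>> torch.Size([112, 28])
print(rearrange(data, "b h w -> h (b w)").shape)
>> torch.Size([28, 112])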
Unflattening
In addition to combining axes, you can split an axis into multiple dimensions:
decompose_data = rearrange(data, "(b1 b2) h w -> (b1 h) (b2 w)", b1=2)
Here, the batch size of 4 is split into two dimensions: b1=2 and b2=2.
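With the (4, 28, 28) MNIST batch, this arranges the four images into a 2×2 grid:
print(decompose_data.shape)
>> torch.Size([56, 56])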
Operations - Reduce
The reduce function performs a reduction, such as mean or max, while reshaping. It uses the same pattern syntax as rearrange but allows the output to have fewer axes than the input. For instance, consider a tensor with shape (batch_size, height * h1, width * w1).
Mean Averaging
Average pooling reduces the image size by replacing each 2×2 block of pixels with its mean:
from einops import reduce

# Average each 2x2 block; images in the batch are placed side by side along the width
reduce_img = reduce(data, "b (h h1) (w w1) -> h (b w)", "mean", h1=2, w1=2)
# Check the shape
print(f"Before: {data.shape}")
print(f"After: {reduce_img.shape}")
>>> Before: torch.Size([4, 28, 28])
>>> After: torch.Size([14, 56])
Max Pooling
Similarly, you can apply max pooling to reduce size:
reduce_img = reduce(data, "b (h h1) (w w1) -> h (b w)", "max", h1=2, w1=2)
# Check the shape
print(f"Before: {data.shape}")
print(f"After: {reduce_img.shape}")
>>> Before: torch.Size([4, 28, 28])
>>> After: torch.Size([14, 56])
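Besides mean and max, reduce supports other reductions such as sum, min, and prod. For example, a global average collapses both spatial axes into a single value per image:
# One mean pixel value per image: (4, 28, 28) -> (4,)
img_means = reduce(data, "b h w -> b", "mean")
print(img_means.shape)
>> torch.Size([4])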
Operations - Repeat
While reduce decreases dimensions, repeat increases them. For example, to expand an image tensor from (b, h, w) to (b, h, w, 3):
from einops import repeat

# Shape of data: (batch_size, height, width)
repeat_data = repeat(data, "b h w -> b h w c", c=3)
print(repeat_data.shape)
>> torch.Size([4, 28, 28, 3])
Duplicating dimensions can also be done easily:
repeat_data = repeat(data, "b h w -> (b h) (w1 w)", w1=3)
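With the (4, 28, 28) batch, this stacks the batch along the height and tiles each image three times along the width:
print(repeat_data.shape)
>> torch.Size([112, 84])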
PyTorch Layers
Einops also provides its operations as PyTorch layers, so they can be dropped directly into a model:
from einops.layers.torch import Rearrange
model = nn.Sequential(
    nn.Conv2d(...),
    nn.Conv2d(...),
    ...,
    Rearrange("b c h w -> b (c h w)"),
    ...
)
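To make this concrete, here is a minimal runnable sketch for MNIST-sized inputs (the layer sizes are illustrative assumptions, not from the original):
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),   # (b, 1, 28, 28) -> (b, 8, 28, 28)
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),  # -> (b, 16, 28, 28)
    nn.ReLU(),
    Rearrange("b c h w -> b (c h w)"),           # flatten to (b, 16 * 28 * 28)
    nn.Linear(16 * 28 * 28, 10),
)

print(model(torch.randn(4, 1, 28, 28)).shape)
>> torch.Size([4, 10])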
Application - Attention
Einops simplifies implementing scaled dot-product attention, which weights the values by the dot-product similarity between queries and keys. Below is an example implementation:
import torch
import torch.nn as nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, n_head: int, d_in: int, d_model: int, dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        # Projections from the input dimension to queries, keys, and values
        self.q = nn.Linear(d_in, d_model)
        self.k = nn.Linear(d_in, d_model)
        self.v = nn.Linear(d_in, d_model)
        self.out = nn.Linear(d_model, d_in)
        self.dropout = nn.Dropout(p=dropout)
        # Normalize over d_in, the dimension of the residual sum in forward()
        self.layer_norm = nn.LayerNorm(d_in)

    def forward(self, x, mask=None):
        # Split d_model into n_head heads: (b, len, d_model) -> (head, b, len, d_head)
        q = rearrange(self.q(x), 'b l (h q) -> h b l q', h=self.n_head)
        k = rearrange(self.k(x), 'b t (h q) -> h b t q', h=self.n_head)
        v = rearrange(self.v(x), 'b t (h v) -> h b t v', h=self.n_head)
        # Scaled dot-product attention scores
        attn = torch.einsum('hblq,hbtq->hblt', q, k) / q.shape[-1] ** 0.5
        if mask is not None:
            attn = attn.masked_fill(mask[None], float('-inf'))
        attn = torch.softmax(attn, dim=-1)
        # Weight the values by the attention, then merge the heads back together
        output = torch.einsum('hblt,hbtv->hblv', attn, v)
        output = rearrange(output, 'h b l v -> b l (h v)')
        output = self.dropout(self.out(output))
        # Residual connection followed by layer normalization
        output = self.layer_norm(output + x)
        return output
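As a quick sanity check, the module can be instantiated and applied to a random batch (the sizes below are arbitrary, chosen only for illustration):
attn = Attention(n_head=4, d_in=32, d_model=64)
x = torch.randn(8, 10, 32)  # (batch, sequence length, d_in)
print(attn(x).shape)
>> torch.Size([8, 10, 32])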