Beginner Guide to Einsum for Deep Learning – Deep Learning Tutorial

By | December 10, 2021

Einsum (Einstein summation) is a powerful way to express matrix and tensor operations in deep learning. In this tutorial, we will introduce how to use it, with beginners in mind.

Einsum

An einsum expression has the general form:

arg0,arg1,arg2->dst

Input: arg0, arg1, arg2

Output: dst

The input subscripts are on the left of ->, separated by commas, and the output subscripts are on the right.

We will use some examples to help you understand it.

Common operations in einsum

Example 1: Matrix multiplication

Matrix multiplication is defined as:

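\[A_{ij} = \sum_{k=1}^{n} B_{ik} C_{kj}\]

The einsum expression is: ik,kj->ij

We can understand it like this:

[Figure: understanding einsum with matrix multiplication]

First, we remove the summation sign \(\sum\).

Second, we move the output subscripts ij to the right of -> and the input subscripts ik,kj to the left. The index k appears in the inputs but not in the output, so it is the index that is summed over.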

For example:

[Figure: matrix multiplication in einsum]
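The expression ik,kj->ij can be tried out with a short sketch. It assumes NumPy's np.einsum rather than PyTorch, and the small inputs B and C are made up for illustration. The explicit loops only spell out what the subscripts mean.

```python
import numpy as np

# Hypothetical example inputs: B is 2x3 and C is 3x2.
B = np.arange(6).reshape(2, 3)
C = np.arange(6).reshape(3, 2)

# k appears in both inputs but not in the output, so it is summed over.
A = np.einsum('ik,kj->ij', B, C)

# The same computation written as explicit loops over i, j and k.
A_loops = np.zeros((2, 2), dtype=B.dtype)
for i in range(2):
    for j in range(2):
        for k in range(3):
            A_loops[i, j] += B[i, k] * C[k, j]

print(A)
# [[10 13]
#  [28 40]]
```

Both versions should agree with the ordinary matrix product B @ C.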

Example 2: Matrix diagonal elements

It can be defined as:

\[A_i = B_{ii}\]

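In einsum, we can define it as ii->i

[Figure: matrix diagonal elements in einsum]

In this example there is no sum operation. The index i is repeated in the input but also appears in the output, so nothing is summed: einsum simply reads the elements whose row index and column index are equal.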
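One way to see what ii->i does is to compare it with a library diagonal function. This is a minimal sketch assuming NumPy, with a made-up 3x3 matrix B.

```python
import numpy as np

# Hypothetical 3x3 input with entries 0..8.
B = np.arange(9).reshape(3, 3)

# i is repeated in the input and kept in the output: pick B[i, i] for each i.
diag = np.einsum('ii->i', B)

print(diag)  # [0 4 8]
```

The result should match np.diag(B).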

Example 3: Matrix trace

It can be defined as:

\[t = \sum_{k=1}^{n} A_{kk}\]

In einsum, we can write it as kk->

[Figure: matrix trace in einsum]
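The trace expression can be checked the same way. This sketch again assumes NumPy and a made-up 3x3 matrix A. Because the output side of kk-> is empty, the repeated index k is summed and a scalar comes back.

```python
import numpy as np

# Hypothetical 3x3 input with entries 0..8.
A = np.arange(9).reshape(3, 3)

# Empty output: sum A[k, k] over k.
t = np.einsum('kk->', A)

print(t)  # 12
```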

From the three examples above, we can find:

1. The output subscripts in einsum can be left empty, as in Example 3. The expression then returns a scalar.

2. The -> is also optional, for example:

einsum('i', a)  # a
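When -> is left out, NumPy's einsum works out the output itself: every index that appears exactly once is kept, in alphabetical order, and every repeated index is summed. The sketch below assumes NumPy (the variable names are made up) and shows how that rule explains expressions such as 'ji' and 'ii'.

```python
import numpy as np

a = np.array([1, 2, 3])
M = np.arange(6).reshape(2, 3)   # hypothetical 2x3 matrix
S = np.arange(9).reshape(3, 3)   # hypothetical square matrix

same = np.einsum('i', a)             # i appears once, so it is kept: a itself
transposed = np.einsum('ji', M)      # output is ordered as 'ij': the transpose
explicit = np.einsum('ji->ij', M)    # the same thing written explicitly
trace = np.einsum('ii', S)           # i is repeated, so it is summed: the trace

print(transposed.shape)  # (3, 2)
```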

Example 4: Batch matrix multiplication

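The einsum expression is: ijk,ikl->ijl. Here i is the batch index, which is kept in the output, and k is summed over. Here is the example code (the inputs are random, so the printed values will differ on each run):

import torch

a = torch.randn(3, 2, 5)
b = torch.randn(3, 5, 3)
torch.einsum('ijk,ikl->ijl', a, b)
tensor([[[ 1.0886,  0.0214,  1.0690],
         [ 2.0626,  3.2655, -0.1465]],

        [[-6.9294,  0.7499,  1.2976],
         [ 4.2226, -4.5774, -4.8947]],

        [[-2.4289, -0.7804,  5.1385],
         [ 0.8003,  2.9425,  1.7338]]])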
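The batch expression can also be compared with an ordinary batched matrix product. This sketch assumes NumPy instead of PyTorch, and the seed and variable names are chosen only for illustration.

```python
import numpy as np

# Hypothetical random inputs with the same shapes as in the example above.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 2, 5))
b = rng.standard_normal((3, 5, 3))

# i is the batch index; for each i, a (2x5) matrix is multiplied by a (5x3) one.
out = np.einsum('ijk,ikl->ijl', a, b)

print(out.shape)                # (3, 2, 3)
print(np.allclose(out, a @ b))  # True
```

The @ operator on 3-D arrays multiplies the matrices batch by batch, which is what ijk,ikl->ijl describes.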

Vector operations in einsum

We will list some common einsum operations on vectors.

einsum('i', a)           # a
einsum('i->', a)         # sum(a)
einsum('i,i->i', a, b)   # a*b, elementwise product
einsum('i,i->', a, b)    # inner(a, b), a scalar
einsum('i,j->ij', a, b)  # outer(a, b)
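The vector operations listed here can be checked against their usual NumPy equivalents. This is a small sketch assuming np.einsum, with two made-up vectors a and b.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

total = np.einsum('i->', a)          # sum(a)
product = np.einsum('i,i->i', a, b)  # elementwise product a*b
inner = np.einsum('i,i->', a, b)     # inner product, a scalar
outer = np.einsum('i,j->ij', a, b)   # outer product, a 3x3 matrix

print(total, inner)  # 6 32
print(product)       # [ 4 10 18]
```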

Matrix operations in einsum

There are some common operations on matrices. They are:

einsum('ji', A)              # A.T, transpose
einsum('ii->i', A)           # diag(A)
einsum('ii', A)              # trace(A)
einsum('ij->', A)            # sum(A)
einsum('ij->j', A)           # sum(A, axis=0)
einsum('ij->i', A)           # sum(A, axis=1)
einsum('ijk->j', A)          # sum(A, axis=(0, 2))
einsum('ijk->ik', A)         # sum(A, axis=1)
einsum('ij,ij->ij', A, B)    # A*B, Hadamard (elementwise) product
einsum('ij,ji->ij', A, B)    # A*B.T
einsum('ij,ij', A, B)        # sum(A*B), a scalar (not the matrix product)
einsum('ij,kj->ik', A, B)    # inner(A, B)
einsum('ij,kj->ikj', A, B)   # A[:, None]*B
einsum('ij,kl->ijkl', A, B)  # A[:, :, None, None]*B
einsum('ijk,kl->ijl', A, B)  # A @ B, with a 3-D A and a 2-D B
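Each entry in the list can be verified against the expression in its comment. The sketch below assumes NumPy. The square matrices A and B, the 3-D array T and the matrix W are made-up inputs, and the checks dictionary is only a convenient way to collect the comparisons.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))      # hypothetical square matrices
B = rng.standard_normal((3, 3))
T = rng.standard_normal((2, 3, 4))   # hypothetical 3-D array
W = rng.standard_normal((4, 5))      # hypothetical 2-D array

checks = {
    "transpose": np.allclose(np.einsum('ji', A), A.T),
    "diag": np.allclose(np.einsum('ii->i', A), np.diag(A)),
    "trace": np.allclose(np.einsum('ii', A), np.trace(A)),
    "sum": np.allclose(np.einsum('ij->', A), A.sum()),
    "sum axis 0": np.allclose(np.einsum('ij->j', A), A.sum(axis=0)),
    "sum axis 1": np.allclose(np.einsum('ij->i', A), A.sum(axis=1)),
    "sum axes 0,2": np.allclose(np.einsum('ijk->j', T), T.sum(axis=(0, 2))),
    "sum axis 1 (3-D)": np.allclose(np.einsum('ijk->ik', T), T.sum(axis=1)),
    "hadamard": np.allclose(np.einsum('ij,ij->ij', A, B), A * B),
    "A*B.T": np.allclose(np.einsum('ij,ji->ij', A, B), A * B.T),
    "sum(A*B)": np.allclose(np.einsum('ij,ij', A, B), (A * B).sum()),
    "inner": np.allclose(np.einsum('ij,kj->ik', A, B), np.inner(A, B)),
    "A[:, None]*B": np.allclose(np.einsum('ij,kj->ikj', A, B), A[:, None] * B),
    "A[:, :, None, None]*B": np.allclose(
        np.einsum('ij,kl->ijkl', A, B), A[:, :, None, None] * B
    ),
    "T @ W": np.allclose(np.einsum('ijk,kl->ijl', T, W), T @ W),
}

for name, ok in checks.items():
    print(f"{name}: {ok}")
```

Every line should print True. If one does not, the subscripts and the comment next to it disagree.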
