Einsum is a powerful operation for working with matrices and tensors in deep learning. This tutorial introduces how to use it, with beginners in mind.
Einsum
Einsum can be defined as:
Input: arg0, arg1, arg2
Output: dst
In an einsum expression, the input subscripts are written to the left of -> and the output subscripts to the right.
We will use some examples to help you understand it.
Common operations in einsum
Example 1: Matrix multiplication
Matrix multiplication is defined as:
\[A_{ij} = \sum_{k=1}^{n} B_{ik} C_{kj}\]
The einsum expression is: ik,kj->ij
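We can understand it like this:
First, drop the summation sign. Any index that appears in the inputs but not in the output (here k) is summed over automatically.
Second, move the output indices ij to the right of -> and the input indices ik,kj to the left.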
For example:
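A minimal sketch of this correspondence, assuming NumPy's `np.einsum` (it accepts the same subscript string as `torch.einsum`). The matrices and their shapes are illustrative choices, not taken from the tutorial:

```python
import numpy as np

# Illustrative inputs: B has shape (2, 3) and C has shape (3, 2).
B = np.arange(6.0).reshape(2, 3)
C = np.arange(6.0).reshape(3, 2)

# Explicit form of A_ij = sum_k B_ik C_kj.
A_loop = np.zeros((2, 2))
for i in range(2):
    for j in range(2):
        for k in range(3):
            A_loop[i, j] += B[i, k] * C[k, j]

# Same computation: k is absent from the output, so it is summed over.
A_einsum = np.einsum('ik,kj->ij', B, C)

print(np.allclose(A_einsum, A_loop))  # True
```

The loop version spells out what the subscript string encodes: one loop per index, with the summed index k accumulating into the output entry.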
Example 2: Matrix diagonal elements
It can be defined as:
\[A_i = B_{ii}\]
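In einsum, we can define it as ii->i
In this example, there is no sum operation. How should we understand it? The index i appears in the output, so it is kept rather than summed. Repeating it on the input side selects only the entries where both indices are equal.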
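The diagonal case can be checked with a short sketch, again assuming NumPy's `np.einsum`. The 3x3 matrix is an arbitrary example:

```python
import numpy as np

# Arbitrary 3x3 example matrix: [[0, 1, 2], [3, 4, 5], [6, 7, 8]].
B = np.arange(9).reshape(3, 3)

# 'ii->i': the repeated index i is kept in the output, so nothing is summed.
diag_einsum = np.einsum('ii->i', B)

# Explicit form of A_i = B_ii.
diag_loop = np.array([B[i, i] for i in range(3)])

print(diag_einsum)  # [0 4 8]
```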
Example 3: Matrix trace
It can be defined as:
\[t = \sum_{k=1}^{n} A_{kk}\]
In einsum, we can write it as kk->
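A small sketch of the trace, assuming NumPy's `np.einsum` and an arbitrary 3x3 matrix:

```python
import numpy as np

# Arbitrary 3x3 example matrix: [[0, 1, 2], [3, 4, 5], [6, 7, 8]].
A = np.arange(9).reshape(3, 3)

# 'kk->': the output is empty, so the repeated index k is summed over.
t_einsum = np.einsum('kk->', A)

# Explicit form of t = sum_k A_kk.
t_loop = sum(A[k, k] for k in range(3))

print(t_einsum)  # 12
```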
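From the three examples above, we can see:
1. The output subscripts in einsum are optional (see Example 3). When they are empty, the result is a scalar.
2. The -> is also optional. When it is omitted, indices that appear exactly once are kept in alphabetical order and repeated indices are summed over. For example:
einsum('i', a)  # a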
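A hedged sketch of what leaving out -> does, assuming NumPy's `np.einsum` (its implicit mode keeps indices that occur once, in alphabetical order, and sums repeated ones). The arrays are illustrative:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
A = np.arange(6).reshape(2, 3)
S = np.arange(9).reshape(3, 3)

# 'i' alone: i occurs once, so it is kept and the vector comes back unchanged.
same = np.einsum('i', a)

# 'ji': the implicit output is 'ij' (alphabetical), which transposes A.
transposed = np.einsum('ji', A)

# 'ii': i is repeated, so it is summed over, giving the trace.
trace = np.einsum('ii', S)

print(transposed.shape)  # (3, 2)
```

Writing the output subscripts explicitly ('ji->ij', 'ii->') gives the same results and is often easier to read.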
Example 4: Batch matrix multiplication
Here is the example code:
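import torch

a = torch.randn(3, 2, 5)
b = torch.randn(3, 5, 3)
# The batch index i is kept; k is summed over within each batch element.
torch.einsum('ijk,ikl->ijl', a, b)
# Example output (the inputs are random, so the values differ on each run):
tensor([[[ 1.0886,  0.0214,  1.0690],
         [ 2.0626,  3.2655, -0.1465]],

        [[-6.9294,  0.7499,  1.2976],
         [ 4.2226, -4.5774, -4.8947]],

        [[-2.4289, -0.7804,  5.1385],
         [ 0.8003,  2.9425,  1.7338]]])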
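To see what 'ijk,ikl->ijl' computes, the sketch below compares it with an ordinary matrix product applied to each batch element. It is written with NumPy's `np.einsum`, on the assumption that the same subscript string behaves identically there. The random seed is an arbitrary choice:

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
a = rng.standard_normal((3, 2, 5))  # batch of 3 matrices, each 2x5
b = rng.standard_normal((3, 5, 3))  # batch of 3 matrices, each 5x3

# i is the batch index (kept), k is the contracted index (summed).
out = np.einsum('ijk,ikl->ijl', a, b)

# Reference: multiply the matrices one batch element at a time.
per_batch = np.stack([a[i] @ b[i] for i in range(3)])

print(out.shape)                    # (3, 2, 3)
print(np.allclose(out, per_batch))  # True
```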
Vector operations in einsum
We will list some common einsum operations on vectors.
einsum('i', a)  # a
einsum('i->', a)  # sum(a)
einsum('i,i->i', a, b)  # a * b
einsum('i,i->', a, b)  # inner(a, b), a scalar
einsum('i,j->ij', a, b)  # outer(a, b)
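These vector identities can be checked directly. The sketch assumes NumPy's `np.einsum` and two small example vectors:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

total = np.einsum('i->', a)              # sum of all elements: 6.0
elementwise = np.einsum('i,i->i', a, b)  # [4., 10., 18.]
inner = np.einsum('i,i->', a, b)         # 1*4 + 2*5 + 3*6 = 32.0
outer = np.einsum('i,j->ij', a, b)       # shape (3, 3), entry [i, j] = a[i] * b[j]

print(total, inner, outer.shape)
```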
Matrix operations in einsum
There are some common operations on matrices. They are:
('ji', A)  # A.T, transpose
('ii->i', A)  # diag(A)
('ii', A)  # trace(A)
('ij->', A)  # sum(A)
('ij->j', A)  # sum(A, axis=0)
('ij->i', A)  # sum(A, axis=1)
('ijk->j', A)  # sum(A, axis=(0, 2))
('ijk->ik', A)  # sum(A, axis=1)
('ij,ij->ij', A, B)  # A * B, Hadamard (elementwise) product
('ij,ji->ij', A, B)  # A * B.T
('ij,jk', A, B)  # dot(A, B)
('ij,kj->ik', A, B)  # inner(A, B)
('ij,kj->ikj', A, B)  # A[:, None] * B
('ij,kl->ijkl', A, B)  # A[:, :, None, None] * B
('ijk,kl->ijl', A, B)  # matmul(A, B) for a 3-D A and a 2-D B
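The sketch below checks a selection of matrix identities of this kind against their NumPy equivalents. It assumes `np.einsum`. The names `T` and `M`, the shapes, and the seed are illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)        # arbitrary seed
A = rng.standard_normal((3, 3))       # square, so A * B.T is well defined
B = rng.standard_normal((3, 3))
T = rng.standard_normal((2, 3, 4))    # hypothetical 3-D tensor
M = rng.standard_normal((4, 5))       # hypothetical 2-D matrix

# Each entry pairs an einsum result with the NumPy expression it should match.
checks = {
    'transpose':      (np.einsum('ji', A), A.T),
    'diagonal':       (np.einsum('ii->i', A), np.diag(A)),
    'trace':          (np.einsum('ii', A), np.trace(A)),
    'sum all':        (np.einsum('ij->', A), A.sum()),
    'sum axis 0':     (np.einsum('ij->j', A), A.sum(axis=0)),
    'sum axis 1':     (np.einsum('ij->i', A), A.sum(axis=1)),
    'sum axes 0, 2':  (np.einsum('ijk->j', T), T.sum(axis=(0, 2))),
    'sum axis 1, 3D': (np.einsum('ijk->ik', T), T.sum(axis=1)),
    'hadamard':       (np.einsum('ij,ij->ij', A, B), A * B),
    'A * B.T':        (np.einsum('ij,ji->ij', A, B), A * B.T),
    'matrix product': (np.einsum('ij,jk', A, B), np.dot(A, B)),
    'inner':          (np.einsum('ij,kj->ik', A, B), np.inner(A, B)),
    'A[:, None] * B': (np.einsum('ij,kj->ikj', A, B), A[:, None] * B),
    'full outer':     (np.einsum('ij,kl->ijkl', A, B), A[:, :, None, None] * B),
    '3D times 2D':    (np.einsum('ijk,kl->ijl', T, M), np.matmul(T, M)),
}

results = {name: bool(np.allclose(got, want)) for name, (got, want) in checks.items()}
all_match = all(results.values())
print(all_match)  # True
```

Keeping the checks in a dictionary makes it easy to print `results` and see which identity fails if a subscript string is mistyped.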