The TensorFlow tf.split() function splits a tensor into several sub tensors. Here are some examples.
Split a Tensor to Sub Tensors with tf.split()
In this tutorial, we will discuss some tips on using tf.split(), so that you can learn how to use this function correctly.
Syntax of tf.split()
tf.split( value, num_or_size_splits, axis=0, num=None, name='split' )
When using tf.split(), pay close attention to the following parameters.
Important parameters
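value: the tensor you want to split.
num_or_size_splits: determines the size of each sub tensor. It is either an integer, which splits value into that many equal pieces (it must evenly divide the size of value along axis), or a list of sizes, such as [1, 3, 5], whose sum must equal the size of value along axis.
axis: the dimension along which the tensor is split. It defaults to 0.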
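To make the rules for num_or_size_splits concrete, here is a small pure-Python sketch that only does the shape arithmetic. The function name split_shapes is hypothetical and not part of TensorFlow; it assumes a non-empty shape tuple and simply predicts the shapes tf.split() should return, raising ValueError where the sizes do not fit.

```python
def split_shapes(shape, num_or_size_splits, axis=0):
    """Hypothetical helper, not a TensorFlow API: predicts the output shapes of
    tf.split(value, num_or_size_splits, axis) for a tensor with this shape."""
    axis = axis % len(shape)  # this helper also accepts negative axes such as -1
    dim = shape[axis]
    if isinstance(num_or_size_splits, int):
        # An integer n asks for n equal pieces, so n must divide the axis size.
        if dim % num_or_size_splits != 0:
            raise ValueError(f"{num_or_size_splits} does not evenly divide {dim}")
        sizes = [dim // num_or_size_splits] * num_or_size_splits
    else:
        # A list gives the size of each piece; the sizes must add up to the axis size.
        sizes = list(num_or_size_splits)
        if sum(sizes) != dim:
            raise ValueError(f"sizes {sizes} sum to {sum(sizes)}, but axis {axis} has {dim}")
    # Each piece keeps every other dimension and takes its own size along axis.
    return [tuple(shape[:axis]) + (s,) + tuple(shape[axis + 1:]) for s in sizes]


print(split_shapes((2, 3, 4), 2))               # [(1, 3, 4), (1, 3, 4)]
print(split_shapes((2, 3, 4), [1, 2], axis=1))  # [(2, 1, 4), (2, 2, 4)]
print(split_shapes((2, 3, 4), 2, axis=-1))      # [(2, 3, 2), (2, 3, 2)]
```

Calling split_shapes((2, 3, 4), [1, 2]) raises ValueError, mirroring the error tf.split() reports when the sizes do not sum to the size of axis 0.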
Return
tf.split() will return a list which contains sub tensors.
Here we will use some examples to explain how to use this function correctly.
Create a tensor of shape 2 * 3 * 4
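#coding=utf-8
import tensorflow as tf

# tf.random.uniform replaces the TF 1.x name tf.random_uniform, which TensorFlow 2.x no longer provides
w = tf.Variable(tf.random.uniform([2, 3, 4], -1, 1))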
Note that there are 2 elements on axis = 0, 3 elements on axis = 1 and 4 elements on axis = 2.
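One way to see these axis sizes without TensorFlow is to build a nested Python list with the same 2 * 3 * 4 layout and measure the length at each nesting depth. This is only an illustration: the list w below is a plain-Python stand-in for the tensor, and nested_shape is a hypothetical helper that assumes every sub-list at the same depth has the same length.

```python
import random

# A nested list with the same 2 * 3 * 4 layout as the tensor w,
# filled with random values between -1 and 1.
w = [[[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)] for _ in range(2)]


def nested_shape(data):
    """Hypothetical helper: size along each axis, found by walking the first element at each depth."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        if not data:  # an empty list has size 0 and nothing deeper to inspect
            break
        data = data[0]
    return tuple(shape)


print(nested_shape(w))                  # (2, 3, 4)
print(len(w), len(w[0]), len(w[0][0]))  # 2 3 4
```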
To understand the relation between tensor axes and shapes, you can refer to this tutorial:
Understand Tensor Axis and Shape with Examples: A Beginner Guide
Split a tensor into 2 sub tensors on axis = 0
There are only 2 elements on axis = 0, which means the sum of num_or_size_splits should be 2. Because axis defaults to 0, it does not need to be passed here.
sub_w = tf.split(w, num_or_size_splits=[1, 1])
print(type(sub_w))
print(sub_w)
The result (printed in TensorFlow 1.x graph mode) is:
<class 'list'>
[<tf.Tensor 'split:0' shape=(1, 3, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 3, 4) dtype=float32>]
From the result, we can see that:
1. The return value sub_w is a Python list.
2. sub_w contains 2 tensors, and each sub tensor has shape 1 * 3 * 4.
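Along axis = 0, splitting amounts to slicing the outer dimension into consecutive chunks. The sketch below mimics that with a plain nested list; split_axis0 is a hypothetical helper written for illustration, not TensorFlow code, and it reproduces both observations above: the result is a Python list, and each piece has shape 1 * 3 * 4.

```python
import random


def split_axis0(data, sizes):
    """Hypothetical pure-Python analogue of tf.split(data, sizes) on axis 0."""
    if sum(sizes) != len(data):
        # Same rule as tf.split(): the sizes must add up to the length of axis 0.
        raise ValueError("sum of sizes must match the size of axis 0")
    pieces, start = [], 0
    for size in sizes:
        pieces.append(data[start:start + size])  # slicing keeps the inner nesting
        start += size
    return pieces


# A 2 * 3 * 4 nested list with random values between -1 and 1.
w = [[[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)] for _ in range(2)]
sub_w = split_axis0(w, [1, 1])
print(type(sub_w))                                          # <class 'list'>
print([(len(p), len(p[0]), len(p[0][0])) for p in sub_w])  # [(1, 3, 4), (1, 3, 4)]
```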
What happens if the sum of num_or_size_splits is not equal to 2?
sub_w = tf.split(w, num_or_size_splits=[1, 2])
You will get an error: ValueError: Sum of output sizes must match the size of the original Tensor along the split dimension
Split a tensor into 2 sub tensors on axis = 1
There are 3 elements on axis = 1, so the sum of num_or_size_splits must be equal to 3.
sub_w = tf.split(w, num_or_size_splits=[1, 2], axis=1)
You will get two sub tensors, one of shape 2 * 1 * 4 and the other of shape 2 * 2 * 4:
[<tf.Tensor 'split:0' shape=(2, 1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(2, 2, 4) dtype=float32>]
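Splitting along a deeper axis follows the same idea: each sub-list one level down is split, and the i-th pieces are regrouped. The recursive sketch below is a plain-Python illustration under stated assumptions (split_along and shape3 are hypothetical names, and the input is assumed to be a regular nested list); it reproduces the 2 * 1 * 4 and 2 * 2 * 4 shapes from the axis = 1 example.

```python
def split_along(data, sizes, axis=0):
    """Hypothetical pure-Python analogue of tf.split for regular nested lists."""
    if axis == 0:
        if sum(sizes) != len(data):
            raise ValueError("sum of sizes must match the size of the split axis")
        pieces, start = [], 0
        for size in sizes:
            pieces.append(data[start:start + size])
            start += size
        return pieces
    # Split every element one level down, then gather the i-th piece of each.
    split_children = [split_along(child, sizes, axis - 1) for child in data]
    return [[children[i] for children in split_children] for i in range(len(sizes))]


def shape3(x):
    """Shape of a regular 3-level nested list."""
    return (len(x), len(x[0]), len(x[0][0]))


# Deterministic 2 * 3 * 4 data: the value 100*i + 10*j + k sits at index [i][j][k].
w = [[[100 * i + 10 * j + k for k in range(4)] for j in range(3)] for i in range(2)]
sub_w = split_along(w, [1, 2], axis=1)
print([shape3(p) for p in sub_w])  # [(2, 1, 4), (2, 2, 4)]
print(sub_w[0])                    # [[[0, 1, 2, 3]], [[100, 101, 102, 103]]]
```

Calling split_along(w, [1, 1], axis=1) raises ValueError because axis 1 has 3 elements, just as tf.split() rejects sizes that do not sum to the size of the split dimension.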