In TensorFlow, the tf.split() function splits a tensor into several sub tensors along a given axis. In this tutorial, we will write an example to illustrate how to use it.
Step 1. Define a tensor to be split
w = tf.Variable(tf.random_uniform([2,3,4], -1, 1))
Step 2. Split a tensor with an integer
w_1 = tf.split(axis=2, num_or_size_splits=2, value=w)
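With an integer, tf.split cuts the axis into that many pieces of equal size, so the size of the axis must be divisible by the integer. Here axis 2 has size 4, which gives two sub tensors of shape [2, 3, 2]. The rule can be sketched in plain Python; `equal_split_sizes` is a hypothetical helper written only for illustration, not part of TensorFlow.

```python
def equal_split_sizes(dim, num_splits):
    # Hypothetical helper: with an integer, tf.split produces num_splits
    # pieces of equal size, so dim must be divisible by num_splits.
    if dim % num_splits != 0:
        raise ValueError(f"axis size {dim} is not divisible by {num_splits}")
    return [dim // num_splits] * num_splits

# w has shape [2, 3, 4]; axis 2 (size 4) is split into 2 pieces.
print(equal_split_sizes(4, 2))  # [2, 2] -> two sub tensors of shape [2, 3, 2]
```

Asking for 3 pieces of an axis of size 4 would fail, both in this sketch and in tf.split.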
Step 3. Split a tensor with an integer list
w_2 = tf.split(axis=2, num_or_size_splits=[2,1,1], value=w)
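With a list, each entry is the size of one sub tensor along the axis, and the entries must add up to the axis size (2 + 1 + 1 = 4). The sub tensors are consecutive slices, so [2, 1, 1] corresponds to w[:, :, 0:2], w[:, :, 2:3] and w[:, :, 3:4]. A small sketch that turns a size list into those slice bounds; `size_list_to_ranges` is a hypothetical name used only for illustration.

```python
from itertools import accumulate

def size_list_to_ranges(sizes, dim):
    # Hypothetical helper: map a tf.split-style size list to (start, stop) bounds.
    if sum(sizes) != dim:
        raise ValueError(f"sizes {sizes} must sum to the axis size {dim}")
    stops = list(accumulate(sizes))  # [2, 3, 4] for [2, 1, 1]
    starts = [0] + stops[:-1]        # [0, 2, 3]
    return list(zip(starts, stops))

for start, stop in size_list_to_ranges([2, 1, 1], 4):
    print(f"w[:, :, {start}:{stop}]")
```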
Here is the full example code:
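import tensorflow as tf  # TensorFlow 1.x API
import numpy as np

# A 2 x 3 x 4 tensor with values drawn uniformly from [-1, 1)
w = tf.Variable(tf.random_uniform([2, 3, 4], -1, 1))

# Split axis 2 (size 4) into 2 equal sub tensors of shape [2, 3, 2]
w_1 = tf.split(axis=2, num_or_size_splits=2, value=w)

# Split axis 2 into sub tensors of sizes 2, 1 and 1
w_2 = tf.split(axis=2, num_or_size_splits=[2, 1, 1], value=w)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    # Fetch the original tensor and both lists of sub tensors in one call
    w_val, w_1_val, w_2_val = sess.run([w, w_1, w_2])
    print('tf.split original')
    print(w_val)
    print('tf.split with integer 2')
    print(w_1_val)
    print('tf.split with list [2, 1, 1]')
    print(w_2_val)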
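If TensorFlow 1.x is not at hand, the shapes produced by the full example can be checked with a NumPy stand-in. This is an assumption for illustration, not TensorFlow itself: np.random.default_rng(0).uniform replaces tf.random_uniform, and because np.split takes split indices rather than sizes, the list [2, 1, 1] is converted into the boundaries [2, 3] with a cumulative sum.

```python
import numpy as np

# Stand-in for tf.random_uniform([2, 3, 4], -1, 1)
rng = np.random.default_rng(0)
w = rng.uniform(-1, 1, size=(2, 3, 4))

# Integer split: like tf.split, np.split with an integer needs equal pieces.
w_1 = np.split(w, 2, axis=2)

# Size-list split: convert sizes [2, 1, 1] into split indices [2, 3].
sizes = [2, 1, 1]
w_2 = np.split(w, np.cumsum(sizes)[:-1], axis=2)

print([p.shape for p in w_1])  # [(2, 3, 2), (2, 3, 2)]
print([p.shape for p in w_2])  # [(2, 3, 2), (2, 3, 1), (2, 3, 1)]
```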
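The original tensor w is:
The sub tensors produced by splitting with the integer 2 (two tensors of shape [2, 3, 2]) are:
The sub tensors produced by splitting with the list [2, 1, 1] (shapes [2, 3, 2], [2, 3, 1] and [2, 3, 1]) are: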
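Note: pay attention to the order of the sub tensors. tf.split returns them in the order in which they appear along the split axis, so w_2[0] holds w[:, :, 0:2], w_2[1] holds w[:, :, 2:3] and w_2[2] holds w[:, :, 3:4].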
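The ordering can be checked with a NumPy stand-in, assuming np.split behaves like tf.split for consecutive slices (it takes split indices instead of sizes, so [2, 1, 1] becomes [2, 3]). Each piece matches the corresponding slice, and concatenating the pieces in the returned order along the same axis (tf.concat in TensorFlow) rebuilds the original tensor, while a different order does not.

```python
import numpy as np

# Small deterministic stand-in for w with shape [2, 3, 4]
w = np.arange(24).reshape(2, 3, 4)

# Split axis 2 with sizes [2, 1, 1], i.e. at indices 2 and 3
pieces = np.split(w, [2, 3], axis=2)

# Piece i is the i-th consecutive slice along axis 2
assert np.array_equal(pieces[0], w[:, :, 0:2])
assert np.array_equal(pieces[1], w[:, :, 2:3])
assert np.array_equal(pieces[2], w[:, :, 3:4])

# Concatenating in the returned order restores w; reversing the order does not
assert np.array_equal(np.concatenate(pieces, axis=2), w)
assert not np.array_equal(np.concatenate(pieces[::-1], axis=2), w)
print("order preserved")
```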