Split a Tensor into Sub Tensors with tf.split() – TensorFlow Tutorial

By | June 12, 2019

The tf.split() function splits a tensor into several sub tensors in TensorFlow. In this tutorial, we will write an example to show how to use it.

Step 1. Define a tensor to be split

w = tf.Variable(tf.random_uniform([2,3,4], -1, 1))

Step 2. Split a tensor by an integer

w_1 = tf.split(axis=2, num_or_size_splits=2, value=w)
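To see which shapes to expect, here is a minimal pure-Python sketch of the shape arithmetic behind an integer num_or_size_splits: the size of the split axis must be evenly divisible by the integer, and every sub tensor gets an equal share. The helper name equal_split_shapes is hypothetical; it does not call TensorFlow and only mirrors the shape rule.

```python
def equal_split_shapes(shape, num_splits, axis):
    """Return the shapes of the sub tensors produced by an equal split.

    Hypothetical helper: it only reproduces the shape arithmetic of
    tf.split when num_or_size_splits is an integer.
    """
    dim = shape[axis]
    if dim % num_splits != 0:
        # tf.split reports an error when the axis size is not evenly divisible
        raise ValueError(f"axis size {dim} is not divisible by {num_splits}")
    out_shape = list(shape)
    out_shape[axis] = dim // num_splits
    # Every sub tensor has the same shape
    return [list(out_shape) for _ in range(num_splits)]


# w has shape [2, 3, 4]; split axis 2 into 2 equal parts
print(equal_split_shapes([2, 3, 4], 2, axis=2))  # [[2, 3, 2], [2, 3, 2]]
```

Asking for 3 parts along axis 2 raises an error here, because 4 is not divisible by 3.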

Step 3. Split a tensor by a list of integers

w_2 = tf.split(axis=2, num_or_size_splits=[2,1,1], value=w)
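With a list, each entry gives the size of one sub tensor along the split axis, so the entries must add up to the size of that axis. The pure-Python sketch below only reproduces that shape rule; the helper name sized_split_shapes is hypothetical and not part of TensorFlow.

```python
def sized_split_shapes(shape, sizes, axis):
    """Return the shapes of the sub tensors for a list of split sizes.

    Hypothetical helper mirroring the shape rule of tf.split when
    num_or_size_splits is a list.
    """
    dim = shape[axis]
    if sum(sizes) != dim:
        # The sizes must add up to the size of the split axis
        raise ValueError(f"sizes {sizes} sum to {sum(sizes)}, expected {dim}")
    shapes = []
    for size in sizes:
        out_shape = list(shape)
        out_shape[axis] = size  # only the split axis changes
        shapes.append(out_shape)
    return shapes


print(sized_split_shapes([2, 3, 4], [2, 1, 1], axis=2))
# [[2, 3, 2], [2, 3, 1], [2, 3, 1]]
```

A list such as [2, 1] would be rejected, since 2 + 1 does not equal 4.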

Here is the full example code:

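import tensorflow as tf
import numpy as np

# Uses the TensorFlow 1.x API (tf.Session). On TensorFlow 2.x, use
# `import tensorflow.compat.v1 as tf` and call tf.disable_v2_behavior() first.
w = tf.Variable(tf.random_uniform([2, 3, 4], -1, 1))

# Split axis 2 (size 4) into 2 equal sub tensors
w_1 = tf.split(axis=2, num_or_size_splits=2, value=w)

# Split axis 2 into sub tensors of sizes 2, 1 and 1
w_2 = tf.split(axis=2, num_or_size_splits=[2, 1, 1], value=w)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    # Store the results under new names so the graph tensors are not overwritten
    w_val, w_1_val, w_2_val = sess.run([w, w_1, w_2])

    print('tf.split original')
    print(w_val)
    print('tf.split with integer 2')
    print(w_1_val)
    print('tf.split with list [2, 1, 1]')
    print(w_2_val)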

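The original tensor w has shape [2, 3, 4] and holds random values in the range [-1, 1).

Splitting by the integer 2 returns a list of 2 sub tensors, each of shape [2, 3, 2].

Splitting by the integer list [2, 1, 1] returns a list of 3 sub tensors with shapes [2, 3, 2], [2, 3, 1] and [2, 3, 1].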

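Note: pay attention to the order of the sub tensors. They are returned in order along the split axis: with [2, 1, 1], the first sub tensor holds the first two slices along axis 2, the second holds the third slice and the third holds the fourth. The sizes in the list must also add up to the size of that axis (here 2 + 1 + 1 = 4).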
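To make the ordering concrete, the sketch below repeats the split on a plain nested list whose values count upward, so it is easy to see which slices end up in which sub tensor. The function split_last_axis is hypothetical: it is a pure-Python stand-in that only handles splitting along the last axis (axis=2 for a [2, 3, 4] tensor) and does not use TensorFlow.

```python
def split_last_axis(tensor, sizes):
    """Split a nested list into consecutive pieces along its last axis.

    Hypothetical pure-Python stand-in for tf.split(value, sizes, axis=-1);
    it does not use TensorFlow.
    """
    if isinstance(tensor[0], list):
        # Not at the last axis yet: split every sub-list, then regroup
        # so that piece i collects the i-th part of every sub-list
        per_item = [split_last_axis(item, sizes) for item in tensor]
        return [[parts[i] for parts in per_item] for i in range(len(sizes))]
    # Last axis: cut consecutive, non-overlapping slices in list order
    pieces, start = [], 0
    for size in sizes:
        pieces.append(tensor[start:start + size])
        start += size
    return pieces


# A [2, 3, 4] "tensor" whose values count upward: w[d][r][c] = 12*d + 4*r + c
w = [[[12 * d + 4 * r + c for c in range(4)] for r in range(3)] for d in range(2)]

pieces = split_last_axis(w, [2, 1, 1])
print(pieces[0])  # [[[0, 1], [4, 5], [8, 9]], [[12, 13], [16, 17], [20, 21]]]
print(pieces[1])  # [[[2], [6], [10]], [[14], [18], [22]]]
print(pieces[2])  # [[[3], [7], [11]], [[15], [19], [23]]]
```

Joining pieces[0], pieces[1] and pieces[2] back together along the last axis restores w, and split_last_axis(w, [2, 2]) gives the same pieces as an equal split with the integer 2.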
