Understand TensorFlow tf.slice(): Extract a Submatrix From a Tensor – TensorFlow Tutorial

By | August 11, 2020

TensorFlow tf.slice() allows us to extract a submatrix (more generally, a sub-tensor) from a tensor. In this tutorial, we will use some examples to illustrate how to use it.

Syntax

tf.slice() is defined as:

tf.slice(
    input_,
    begin,
    size,
    name=None
)

It extracts the part of input_ that starts at the position begin and has the shape given by size.

Parameters

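input_: the tensor to slice

begin: the starting index along each dimension of input_, i.e. where the submatrix begins

size: the number of elements to take along each dimension of input_; size[i] = -1 means all remaining elements along dimension i

We should notice: len(begin) = len(size) = the rank (number of dimensions) of input_, and begin[i] + size[i] must not exceed the length of dimension i. When size contains no -1, size is exactly the shape of the extracted submatrix.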
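To make these rules concrete, here is a small pure-Python sketch that mimics tf.slice on nested lists. The helper name slice_reference is hypothetical and not part of TensorFlow; it assumes a regular (rectangular) nested list and follows the documented convention that size[i] = -1 means "all remaining elements along axis i".

```python
def slice_reference(data, begin, size):
    """Hypothetical pure-Python sketch of tf.slice semantics on nested lists."""
    if len(begin) != len(size):
        raise ValueError("begin and size must have the same length")
    if not begin:
        if isinstance(data, list):
            raise ValueError("begin/size must have one entry per axis")
        return data  # a single scalar element of the original tensor
    if not isinstance(data, list):
        raise ValueError("begin/size have more entries than the tensor has axes")
    b, s = begin[0], size[0]
    dim = len(data)
    if s == -1:
        s = dim - b  # -1: take everything from b to the end of this axis
    if b < 0 or s < 0 or b + s > dim:
        raise ValueError("slice is out of bounds on this axis")
    # Keep data[b:b+s] on this axis, then slice the remaining axes of each kept item.
    return [slice_reference(item, begin[1:], size[1:]) for item in data[b:b + s]]


x = [[[1, 1, 1], [2, 2, 2]],
     [[3, 3, 3], [4, 4, 4]],
     [[5, 5, 5], [6, 6, 6]]]
print(slice_reference(x, [1, 1, 0], [1, 1, 3]))  # [[[4, 4, 4]]]
print(slice_reference(x, [1, 1, 0], [2, 1, 2]))  # [[[4, 4]], [[6, 6]]]
```

The recursion handles one axis per call, which mirrors the "one begin and one size entry per dimension" rule.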

Here we will use some examples to show how to use tf.slice().

Create an input tensor

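import tensorflow as tf

# tf.Session() is TensorFlow 1.x API; the tf.compat.v1 calls below
# keep this Session-based example working under TensorFlow 2.x as well.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session()
input = tf.constant([[[1, 1, 1], [2, 2, 2]],
                     [[3, 3, 3], [4, 4, 4]],
                     [[5, 5, 5], [6, 6, 6]]])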

The shape of input is [3, 2, 3]: 3 blocks, each with 2 rows of 3 numbers.
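A quick way to confirm this shape without TensorFlow is to walk the nested list and record the length at each level. The helper nested_shape below is hypothetical and assumes the nesting is regular (every row has the same length); in TensorFlow itself, input.shape reports the same static shape.

```python
def nested_shape(data):
    """Hypothetical helper: shape of a regular nested list, one length per level."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        if not data:
            break  # an empty level has no first element to descend into
        data = data[0]  # assumes every element at this level has the same shape
    return shape


input_data = [[[1, 1, 1], [2, 2, 2]],
              [[3, 3, 3], [4, 4, 4]],
              [[5, 5, 5], [6, 6, 6]]]
print(nested_shape(input_data))  # [3, 2, 3]
```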

Extract a submatrix

data = tf.slice(input, [1, 1, 0], [1, 1, 3])
print(sess.run(data))

begin = [1, 1, 0] means the slice starts at input[1][1][0], the element whose value is 4.

size = [1, 1, 3] means we take 1 element along axis 0, 1 element along axis 1 and 3 elements along axis 2.

So the shape of the extracted subtensor is [1, 1, 3].
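Put differently, each axis keeps the half-open index range from begin[i] up to (but not including) begin[i] + size[i]. The short sketch below does that arithmetic for this example; the variable names are illustrative only. Because TensorFlow tensors also support NumPy-style indexing, input[1:2, 1:2, 0:3] should produce the same [1, 1, 3] result.

```python
begin = [1, 1, 0]
size = [1, 1, 3]

# Each axis keeps indices begin[i] <= k < begin[i] + size[i].
ranges = [(b, b + s) for b, s in zip(begin, size)]
print(ranges)  # [(1, 2), (1, 2), (0, 3)]

# The same bounds as Python slice objects, i.e. input[1:2, 1:2, 0:3].
slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
print(slices)  # (slice(1, 2, None), slice(1, 2, None), slice(0, 3, None))
```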


Run this code and you will get:

[[[4 4 4]]]

If we keep the same begin and change size to [2, 1, 2]:

data = tf.slice(input, [1, 1, 0], [2, 1, 2])
print(sess.run(data))


Run this code and we will get:

[[[4 4]]

 [[6 6]]]

The shape of data is [2, 1, 2]
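tf.slice also accepts -1 in size, meaning "all remaining elements along this axis". Starting from index 1 on axis 0 of a [3, 2, 3] tensor leaves 2 elements, so tf.slice(input, [1, 1, 0], [-1, 1, 2]) should return the same result as the example above. The helper effective_size below is a hypothetical sketch of how -1 entries resolve; it is not a TensorFlow function.

```python
def effective_size(shape, begin, size):
    """Hypothetical helper: resolve -1 entries of size to 'all remaining elements'."""
    return [dim - b if s == -1 else s for dim, b, s in zip(shape, begin, size)]


input_shape = [3, 2, 3]  # shape of the example tensor
begin = [1, 1, 0]
print(effective_size(input_shape, begin, [2, 1, 2]))     # [2, 1, 2]
print(effective_size(input_shape, begin, [-1, 1, 2]))    # [2, 1, 2]
print(effective_size(input_shape, begin, [-1, -1, -1]))  # [2, 1, 3]
```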
