TensorFlow tf.slice() allows us to extract a slice (a sub-tensor) from a tensor. In this tutorial, we will use some examples to illustrate how to use it.
Syntax
tf.slice() is defined as:
tf.slice(input_, begin, size, name=None)
It extracts the part of input_ that starts at the index given by begin and has the shape given by size.
Parameters
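input_: the tensor to slice
begin: the starting index on each axis of input_
size: the number of elements to take along each axis of input_, which is also the shape of the result. A value of -1 means "all remaining elements along this axis".
Note that len(begin) = len(size) = the rank of input_ (its number of dimensions), and for every axis i, begin[i] + size[i] must not exceed the length of that axis.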
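To make the begin/size rules concrete, here is a minimal NumPy sketch, assuming NumPy is installed. The helper slice_like_tf is hypothetical (it is not part of TensorFlow); it mimics what tf.slice() computes by turning each (begin[i], size[i]) pair into the half-open range begin[i]:begin[i] + size[i], with size[i] = -1 meaning "to the end of the axis".

```python
import numpy as np

def slice_like_tf(x, begin, size):
    """Hypothetical NumPy helper that mimics tf.slice(x, begin, size)."""
    x = np.asarray(x)
    # begin and size need exactly one entry per axis of x
    if not (len(begin) == len(size) == x.ndim):
        raise ValueError("len(begin) and len(size) must equal the rank of x")
    index = []
    for axis, (b, s) in enumerate(zip(begin, size)):
        dim = x.shape[axis]
        # size -1 means: take every remaining element along this axis
        end = dim if s == -1 else b + s
        if not (0 <= b <= end <= dim):
            raise ValueError(f"slice is out of bounds on axis {axis}")
        index.append(slice(b, end))
    # Slices (never plain integers) keep the rank of x unchanged
    return x[tuple(index)]

x = np.array([[[1, 1, 1], [2, 2, 2]],
              [[3, 3, 3], [4, 4, 4]],
              [[5, 5, 5], [6, 6, 6]]])
print(slice_like_tf(x, [1, 0, 0], [-1, 1, 1]).tolist())  # [[[3]], [[5]]]
```

Because every axis is indexed with a slice rather than an integer, the result always has the same rank as the input, which matches how tf.slice() behaves.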
Here are some examples showing how to use tf.slice().
Create an input tensor
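import tensorflow as tf

# TensorFlow 1.x graph mode; in TensorFlow 2.x call
# tf.compat.v1.disable_eager_execution() and use tf.compat.v1.Session() instead
sess = tf.Session()
input = tf.constant([[[1, 1, 1], [2, 2, 2]],
                     [[3, 3, 3], [4, 4, 4]],
                     [[5, 5, 5], [6, 6, 6]]])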
The shape of input is [3, 2, 3]: 3 blocks, each with 2 rows of 3 numbers.
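A quick way to double-check the shape is to build the same nested list with NumPy (assumed to be installed) and read its shape attribute; each level of nesting is one axis.

```python
import numpy as np

# Same nested values as the tf.constant above
x = np.array([[[1, 1, 1], [2, 2, 2]],
              [[3, 3, 3], [4, 4, 4]],
              [[5, 5, 5], [6, 6, 6]]])

print(x.shape)  # (3, 2, 3): 3 blocks, 2 rows per block, 3 numbers per row
print(x.ndim)   # 3, so begin and size each need 3 entries
```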
Extract a subtensor
data = tf.slice(input, [1, 1, 0], [1, 1, 3])
print(sess.run(data))
begin = [1, 1, 0] means the slice starts at input[1][1][0], which is the first 4 in the tensor.
size = [1, 1, 3] means we take 1 element along axis 0, 1 element along axis 1, and 3 elements along axis 2.
So the shape of the extracted subtensor is [1, 1, 3].
Running this code prints:
[[[4 4 4]]]
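In NumPy terms (a sketch assuming NumPy, not TensorFlow itself), this call selects the same elements as basic slicing with the half-open range begin[i]:begin[i] + size[i] on every axis, i.e. x[1:2, 1:2, 0:3].

```python
import numpy as np

x = np.array([[[1, 1, 1], [2, 2, 2]],
              [[3, 3, 3], [4, 4, 4]],
              [[5, 5, 5], [6, 6, 6]]])
begin, size = [1, 1, 0], [1, 1, 3]

# Build one slice per axis: begin[i] up to (but not including) begin[i] + size[i]
sub = x[tuple(slice(b, b + s) for b, s in zip(begin, size))]

print(sub)        # [[[4 4 4]]]
print(sub.shape)  # (1, 1, 3)
```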
If we keep the same begin and change size:
data = tf.slice(input, [1, 1, 0], [2, 1, 2])
print(sess.run(data))
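Running this code prints:
[[[4 4]]

 [[6 6]]]
The shape of data is [2, 1, 2]: 2 elements along axis 0 (indices 1 and 2), 1 element along axis 1 (index 1), and 2 elements along axis 2 (indices 0 and 1).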
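The same result can be reproduced with NumPy slicing (a sketch assuming NumPy is installed). It also shows one difference worth keeping in mind: indexing an axis with a plain integer drops that axis, whereas tf.slice() always keeps every axis, even when its size is 1.

```python
import numpy as np

x = np.array([[[1, 1, 1], [2, 2, 2]],
              [[3, 3, 3], [4, 4, 4]],
              [[5, 5, 5], [6, 6, 6]]])

# begin=[1, 1, 0], size=[2, 1, 2] -> ranges 1:3, 1:2, 0:2
data = x[1:3, 1:2, 0:2]
print(data.tolist())  # [[[4, 4]], [[6, 6]]]
print(data.shape)     # (2, 1, 2)

# An integer index on axis 1 removes that axis; tf.slice never does this
dropped = x[1:3, 1, 0:2]
print(dropped.shape)  # (2, 2)
```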