tf.roll(): Rolls the Elements of a Tensor Along an Axis – TensorFlow Tutorial

December 9, 2021

In this tutorial, we will use some examples to show you how to use the tf.roll() function correctly.

tf.roll()

It is defined as:

tf.roll(
    input, shift, axis, name=None
)

It rolls the elements of a tensor along an axis. Elements that are shifted past the last position wrap around to the first.

Here shift determines the number of steps and the direction of the roll.
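To make the role of shift concrete, here is a minimal pure-Python sketch of rolling a flat list. The helper roll_1d is hypothetical and written only for illustration; it mimics the wrap-around behaviour described above and is not how TensorFlow implements tf.roll().

```python
def roll_1d(values, shift):
    """Roll a list by `shift` positions with wrap-around (illustrative helper)."""
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # normalise negative and oversized shifts into 0..n-1
    # The last k elements wrap around to the front
    return values[n - k:] + values[:n - k]


data = [1, 2, 3, 4, 5]
print(roll_1d(data, 2))   # positive shift: elements move toward higher indices
print(roll_1d(data, -1))  # negative shift: elements move toward lower indices
```

A positive shift pushes elements toward the end of the list, and a negative shift pushes them toward the start.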

Look at this example:

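# Written against the TensorFlow 1.x graph API (tf.Session)
import tensorflow as tf
import numpy as np

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
# Roll the rows (axis 0) by 2 positions
b = tf.roll(a, 2, axis=0)

# No variables are created here, so these initializers have nothing to initialize
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    _w = sess.run(b)
    print(_w)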

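Here shift = 2 > 0 and axis = 0. You will get this result:

tf.roll() example 1

Because shift = 2 > 0, the rows on axis = 0 move from top to bottom by two positions, and rows pushed past the bottom wrap around to the top. The output rows are [4, 5, 6], [7, 8, 10] and [1, 2, 3].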
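The top-to-bottom movement can be checked without TensorFlow. The sketch below uses a hypothetical helper, roll_rows, that works on a plain list of lists; it only illustrates where each row ends up and makes no claim about TensorFlow's internals.

```python
def roll_rows(matrix, shift):
    """Roll the rows of a list of lists downward by `shift` (illustrative helper)."""
    n = len(matrix)
    result = [None] * n
    for i, row in enumerate(matrix):
        # Row i of the input lands at row (i + shift) % n of the output
        result[(i + shift) % n] = list(row)
    return result


lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
for row in roll_rows(lx, 2):
    print(row)
```

With three rows and shift = 2, row 0 moves to position 2, while rows 1 and 2 wrap around to positions 0 and 1.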

How about shift < 0?

Here is an example:

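# Written against the TensorFlow 1.x graph API (tf.Session)
import tensorflow as tf
import numpy as np

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
# Roll the rows (axis 0) by -1, i.e. one position upward
b = tf.roll(a, -1, axis=0)

# No variables are created here, so these initializers have nothing to initialize
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    _w = sess.run(b)
    print(_w)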

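Here shift = -1 < 0, so the rows on axis = 0 move from bottom to top by one position, and the first row wraps around to the bottom.

Run this code and you will get:

tf.roll() example 2

The output rows are [4, 5, 6], [7, 8, 10] and [1, 2, 3]. Because the tensor has only three rows, this is the same result as shift = 2.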
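One detail worth checking for this particular 3-row input is how shift = -1 compares with shift = 2. The sketch below assumes the wrap-around behaviour amounts to taking the shift modulo the axis length; roll_axis0 is a hypothetical pure-Python helper used only for illustration.

```python
def roll_axis0(matrix, shift):
    """Roll rows with wrap-around, treating the shift modulo the row count."""
    n = len(matrix)
    if n == 0:
        return []
    k = shift % n  # e.g. -1 % 3 == 2
    return [list(row) for row in matrix[n - k:] + matrix[:n - k]]


lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
up_one = roll_axis0(lx, -1)
down_two = roll_axis0(lx, 2)
print(up_one)
print(up_one == down_two)
```

Moving three rows up by one position leaves them in the same order as moving them down by two, so the two shifts are only distinguishable on longer axes.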

How about axis = 1?

We will use two examples to show you how it works.

When shift = 2, which is greater than 0, elements on axis = 1 move from left to right.

Look at example code below:

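# Written against the TensorFlow 1.x graph API (tf.Session)
import tensorflow as tf
import numpy as np

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
# Roll the columns (axis 1) by 2 positions to the right
b = tf.roll(a, 2, axis=1)

# No variables are created here, so these initializers have nothing to initialize
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    _w = sess.run(b)
    print(_w)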

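Run this code and you will get this result:

tf.roll() example 3

The output rows are [2, 3, 1], [5, 6, 4] and [8, 10, 7]: within each row, every element moves two positions to the right and wraps around at the end.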
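Rolling along axis = 1 applies the same shift inside every row. As a rough illustration, the hypothetical helper roll_columns below does this on a plain list of lists; it is a sketch of the behaviour and not TensorFlow code.

```python
def roll_columns(matrix, shift):
    """Roll the elements of each row to the right by `shift` (illustrative helper)."""
    rolled = []
    for row in matrix:
        n = len(row)
        k = shift % n if n else 0
        # The last k elements of the row wrap around to its start
        rolled.append(row[n - k:] + row[:n - k])
    return rolled


lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
for row in roll_columns(lx, 2):
    print(row)
```

The row order is unchanged; only the positions inside each row move.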

If shift = -1 < 0, elements on axis = 1 move from right to left.

Here is an example:

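# Written against the TensorFlow 1.x graph API (tf.Session)
import tensorflow as tf
import numpy as np

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
# Roll the columns (axis 1) by -1, i.e. one position to the left
b = tf.roll(a, -1, axis=1)

# No variables are created here, so these initializers have nothing to initialize
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    _w = sess.run(b)
    print(_w)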

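We will get this result:

tf.roll() example 4

The output rows are [2, 3, 1], [5, 6, 4] and [8, 10, 7]: within each row, every element moves one position to the left and the first element wraps around to the end.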
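Since the sign of shift only sets the direction, rolling by -1 and then by 1 along the same axis should restore the original data. The sketch below checks this with a hypothetical pure-Python helper, roll_axis1, which is an illustration of the idea and not part of TensorFlow.

```python
def roll_axis1(matrix, shift):
    """Roll each row by `shift` with wrap-around (illustrative helper)."""
    # Output position j takes the element that was at position j - shift
    return [
        [row[(j - shift) % len(row)] for j in range(len(row))]
        for row in matrix
    ]


lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
left_one = roll_axis1(lx, -1)       # move every element one step to the left
restored = roll_axis1(left_one, 1)  # move them back one step to the right
print(left_one)
print(restored == lx)
```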
