In this tutorial, we will use some examples to show you how to use the tf.roll() function correctly.
tf.roll()
It is defined as:
tf.roll( input, shift, axis, name=None )
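It rolls the elements of a tensor along an axis. The shift is circular: elements pushed past the last position re-enter at the first.
Here shift determines how many positions the elements move and in which direction: a positive shift moves them toward larger indices, a negative shift toward smaller indices.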
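To make the rule concrete, here is a minimal pure-Python sketch of a 1-D roll. The helper roll_1d is hypothetical (it is not part of TensorFlow); it only assumes the rule described above: element i ends up at index (i + shift) mod n.

```python
def roll_1d(values, shift):
    """Circularly shift a list: element i moves to index (i + shift) % n."""
    n = len(values)
    result = [None] * n
    for i, v in enumerate(values):
        # Python's % always returns a value in [0, n) for n > 0,
        # so negative shifts wrap around correctly.
        result[(i + shift) % n] = v
    return result

print(roll_1d([0, 1, 2, 3, 4], 2))   # [3, 4, 0, 1, 2]
print(roll_1d([0, 1, 2, 3, 4], -1))  # [1, 2, 3, 4, 0]
```

A positive shift pushes the tail of the list around to the front; a negative shift pushes the head around to the back.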
Look at this example:
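import tensorflow as tf
import numpy as np
# TensorFlow 1.x (graph mode) code. Under TensorFlow 2.x, eager execution is
# the default, so you can print b.numpy() directly instead of using a Session.
lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
b = tf.roll(a, 2, axis=0)  # shift = 2 along axis 0 (rows)
# No variables are created here, so these initializers are optional.
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(b)
    print(_w)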
Here shift = 2 (> 0) and axis = 0. Run this code and you will get this result:
[[ 4.  5.  6.]
 [ 7.  8. 10.]
 [ 1.  2.  3.]]
Because shift = 2 > 0, the rows along axis = 0 move from top to bottom by two positions, and rows pushed past the bottom wrap around to the top.
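The hypothetical helper roll_rows below mirrors tf.roll(a, 2, axis=0) on plain Python lists (integers instead of float32), which makes it easy to trace where each row goes. It is a sketch of the semantics, not TensorFlow's implementation.

```python
def roll_rows(matrix, shift):
    """Roll a 2-D list along axis 0: row i moves to row (i + shift) % n."""
    n = len(matrix)
    result = [None] * n
    for i, row in enumerate(matrix):
        result[(i + shift) % n] = list(row)  # copy so the input is untouched
    return result

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
print(roll_rows(lx, 2))  # [[4, 5, 6], [7, 8, 10], [1, 2, 3]]
```

Row [1, 2, 3] moves from index 0 to index 2, while rows [4, 5, 6] and [7, 8, 10] wrap around to indices 0 and 1.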
How about shift < 0?
Here is an example:
import tensorflow as tf
import numpy as np
lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
b = tf.roll(a, -1, axis=0)
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(b)
    print(_w)
Here shift = -1 < 0, so the rows along axis = 0 move from bottom to top by one position, and the first row wraps around to the bottom.
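Run this code and you will get:
[[ 4.  5.  6.]
 [ 7.  8. 10.]
 [ 1.  2.  3.]]
This is the same result as shift = 2: axis 0 has length 3, and rolling by -1 is equivalent to rolling by 3 - 1 = 2.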
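On an axis of length n, any shift (negative or larger than n) is equivalent to shift % n. The sketch below uses two hypothetical helpers, normalize_shift and roll_rows, to check that equivalence on plain lists; it assumes a non-empty matrix so that the modulus is never zero.

```python
def normalize_shift(shift, axis_size):
    """Map any integer shift to the equivalent shift in [0, axis_size)."""
    # Python's % returns a non-negative result when axis_size > 0.
    return shift % axis_size

def roll_rows(matrix, shift):
    """Roll a non-empty 2-D list along axis 0 using a normalized shift."""
    n = len(matrix)
    k = normalize_shift(shift, n)
    # Rolling down by k moves the last k rows to the front.
    return [list(row) for row in matrix[n - k:] + matrix[:n - k]]

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
print(normalize_shift(-1, 3))                 # 2
print(roll_rows(lx, -1) == roll_rows(lx, 2))  # True
print(roll_rows(lx, -1))                      # [[4, 5, 6], [7, 8, 10], [1, 2, 3]]
```

Normalizing first turns the roll into a single slice-and-concatenate, which is why shift = -1 and shift = 2 produce identical output here.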
How about axis = 1?
We will use two examples to show you.
When shift = 2, which is greater than 0, the elements along axis = 1 move from left to right within each row.
Look at example code below:
import tensorflow as tf
import numpy as np
lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
b = tf.roll(a, 2, axis=1)
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(b)
    print(_w)
Run this code and you will get this result:
[[ 2.  3.  1.]
 [ 5.  6.  4.]
 [ 8. 10.  7.]]
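Rolling along axis = 1 applies the same circular rule independently inside each row. The hypothetical helper roll_columns below sketches what tf.roll(a, 2, axis=1) computes, using plain lists in place of a float32 tensor.

```python
def roll_columns(matrix, shift):
    """Roll a 2-D list along axis 1: in each row, column j moves to (j + shift) % m."""
    result = []
    for row in matrix:
        m = len(row)
        new_row = [None] * m
        for j, v in enumerate(row):
            new_row[(j + shift) % m] = v
        result.append(new_row)
    return result

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
print(roll_columns(lx, 2))  # [[2, 3, 1], [5, 6, 4], [8, 10, 7]]
```

The rows themselves stay in place; only the order of values within each row changes.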
If shift = -1 < 0, the elements along axis = 1 move from right to left.
Here is an example:
import tensorflow as tf
import numpy as np
lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
a = tf.convert_to_tensor(lx, dtype=tf.float32)
b = tf.roll(a, -1, axis=1)
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(b)
    print(_w)
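We will get this result, which again matches shift = 2 because axis 1 also has length 3:
[[ 2.  3.  1.]
 [ 5.  6.  4.]
 [ 8. 10.  7.]]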
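tf.roll also accepts 1-D lists for shift and axis, for example tf.roll(a, shift=[2, -1], axis=[0, 1]), to roll along several axes in one call; per the TensorFlow documentation, shifts that reference the same axis are added together. The pure-Python sketch below (roll_2d is a hypothetical helper, limited to non-empty rectangular 2-D lists) shows what such a combined call should compute on lx.

```python
def roll_2d(matrix, shift, axis):
    """Roll a 2-D list along one or more axes, mimicking tf.roll's list form.

    shift and axis are equal-length lists; axis entries must be 0 or 1.
    Assumes a non-empty, rectangular matrix.
    """
    rows, cols = len(matrix), len(matrix[0])
    total = {0: 0, 1: 0}
    for s, ax in zip(shift, axis):
        total[ax] += s  # shifts that target the same axis are summed
    result = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            result[(i + total[0]) % rows][(j + total[1]) % cols] = matrix[i][j]
    return result

lx = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
print(roll_2d(lx, [2, -1], [0, 1]))  # [[5, 6, 4], [8, 10, 7], [2, 3, 1]]
```

The result equals rolling the rows by 2 and then the columns by -1; because the two rolls act on different axes, their order does not matter.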