The TensorFlow tf.where() function selects elements from two tensors according to a boolean condition. In this tutorial, we will discuss how to use it correctly with some examples.
Syntax
tf.where( condition, x=None, y=None, name=None )
Returns elements taken from either x or y, depending on condition. Where an element of condition is True, the result holds the element of x at the same position; otherwise, it holds the element of y.
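The element-wise rule can be checked on small 1-D tensors. This is a minimal sketch that assumes TensorFlow 2.x with eager execution, so results can be read with .numpy() without a session. The names cond_values, x_values and y_values are illustrative only, and the list comprehension restates the same rule in plain Python for comparison.

```python
import tensorflow as tf

# Illustrative inputs (not from the original tutorial)
cond_values = [True, False, True, False]
x_values = [1, 2, 3, 4]
y_values = [10, 20, 30, 40]

r = tf.where(tf.constant(cond_values), tf.constant(x_values), tf.constant(y_values))

# The same rule spelled out element by element in plain Python
expected = [xv if c else yv for c, xv, yv in zip(cond_values, x_values, y_values)]

print(r.numpy().tolist())  # [1, 20, 3, 40]
print(expected)            # [1, 20, 3, 40]
```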
Parameters explained
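condition: A Tensor of type bool. For each position, it decides whether the element comes from x or from y.
x: A Tensor whose elements are returned where condition is True. It should have the same shape as condition.
y: A Tensor with the same shape and dtype as x. Its elements are returned where condition is False.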
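The signature also gives x and y a default of None. When both are omitted, tf.where() does not select values at all; it returns the coordinates of the True elements of condition as an int64 tensor, one row per True element in row-major order. A minimal sketch, assuming TensorFlow 2.x eager execution and reusing the condition values from the example later in this tutorial:

```python
import tensorflow as tf

condition = tf.constant([[True, False, False],
                         [False, True, False],
                         [True, True, True]])

# With x and y omitted, tf.where returns the [row, column] index of every True element
indices = tf.where(condition)

print(indices.dtype)
print(indices.numpy().tolist())  # [[0, 0], [1, 1], [2, 0], [2, 1], [2, 2]]
```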
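To use this function correctly, keep in mind that:
- condition must be a tensor of type bool
- x and y should have the same shape and dtype, and condition should have the same shape as x and y (TensorFlow 2.x relaxes this: the three shapes only need to be broadcastable to a common shape)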
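One common way to satisfy both requirements is to derive condition from x with a comparison, which yields a bool tensor of the same shape, and to build y with tf.zeros_like(x), which copies the shape and dtype of x. A minimal sketch, assuming TensorFlow 2.x eager execution; the threshold 4.0 is an arbitrary illustrative value:

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

# A comparison produces a bool tensor with the same shape as x
condition = x > 4.0
# zeros_like gives y the same shape and dtype as x
y = tf.zeros_like(x)

r = tf.where(condition, x, y)
print(condition.dtype)
print(r.numpy().tolist())  # [[0.0, 0.0, 0.0], [0.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
```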
We will use some examples to explain them.
tf.where() example
Create the condition, x and y tensors
import tensorflow as tf
import numpy as np

condition = tf.Variable(np.array([[True, False, False], [False, True, False], [True, True, True]]), dtype=tf.bool, name='condition')
x = tf.Variable(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=tf.float32, name='x')
y = tf.Variable(np.array([[11, 12, 13], [14, 15, 16], [17, 18, 19]]), dtype=tf.float32, name='y')
Here condition has type bool, and condition, x and y all have the same shape, (3, 3).
r = tf.where(condition, x, y)
The r tensor will be:
[array([[ 1., 12., 13.],
       [14.,  5., 16.],
       [ 7.,  8.,  9.]], dtype=float32)]
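The output format above suggests the original was evaluated in a TensorFlow 1.x session. Under TensorFlow 2.x eager execution (an assumption about your environment), the same example can be run directly, here with tf.constant instead of tf.Variable for brevity. np.where applies the same element-wise rule, so it serves as an independent cross-check:

```python
import numpy as np
import tensorflow as tf

condition = tf.constant([[True, False, False],
                         [False, True, False],
                         [True, True, True]])
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)
y = tf.constant([[11, 12, 13], [14, 15, 16], [17, 18, 19]], dtype=tf.float32)

r = tf.where(condition, x, y)
print(r.numpy())

# Cross-check against NumPy's element-wise selection
assert np.array_equal(r.numpy(), np.where(condition.numpy(), x.numpy(), y.numpy()))
```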
From the result we can find:
If an element of condition is True, r holds the element of x at the same position; otherwise, it holds the element of y. This is why condition, x and y normally have the same shape.
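In TensorFlow 2.x the shapes only need to be broadcastable, so, for example, y can be a scalar that fills every position where condition is False. A minimal sketch, assuming TensorFlow 2.x eager execution; the fill value 0.0 is illustrative:

```python
import tensorflow as tf

condition = tf.constant([[True, False, False],
                         [False, True, False],
                         [True, True, True]])
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)

# The scalar 0.0 is converted to float32 and broadcast to the shape of x
r = tf.where(condition, x, 0.0)
print(r.numpy().tolist())  # [[1.0, 0.0, 0.0], [0.0, 5.0, 0.0], [7.0, 8.0, 9.0]]
```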
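TensorFlow 1.x has one exception: condition may also be a 1-D tensor whose length equals the first dimension of x. In that case, each element of condition selects a whole row. For example, if condition is:
condition = tf.Variable(np.array([True, False, True]), dtype=tf.bool, name='condition')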
r will be (this is the TensorFlow 1.x behavior):
[[ 1.  2.  3.]
 [14. 15. 16.]
 [ 7.  8.  9.]]
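The row-selection result depends on the TensorFlow version. In TensorFlow 2.x, tf.where() broadcasts a shape-(3,) condition against the trailing dimension of x, so it chooses columns within every row; the 1.x behavior remains available as tf.compat.v1.where(). A minimal sketch comparing the two, assuming TensorFlow 2.x eager execution:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)
y = tf.constant([[11, 12, 13], [14, 15, 16], [17, 18, 19]], dtype=tf.float32)
condition = tf.constant([True, False, True])

# TF 1.x semantics: a rank-1 condition matches the first dimension and picks whole rows
rows = tf.compat.v1.where(condition, x, y)
# TF 2.x semantics: condition is broadcast to shape (3, 3), so it picks columns in each row
cols = tf.where(condition, x, y)

print(rows.numpy().tolist())  # [[1.0, 2.0, 3.0], [14.0, 15.0, 16.0], [7.0, 8.0, 9.0]]
print(cols.numpy().tolist())  # [[1.0, 12.0, 3.0], [4.0, 15.0, 6.0], [7.0, 18.0, 9.0]]
```

If code written for TensorFlow 1.x relies on row selection, reshaping condition to shape (3, 1) with tf.reshape(condition, [-1, 1]) makes tf.where() in 2.x broadcast it across columns instead, which selects whole rows again.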