Tutorial Example

Understand TensorFlow tf.boolean_mask(): Remove Data From a Tensor – TensorFlow Tutorial

If you need to mask out (remove) some elements of a tensor, tf.boolean_mask() may be a good choice. In this tutorial, we will use some examples to show you how to use it correctly.

Syntax

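In TensorFlow 1.x, tf.boolean_mask() is defined as follows (in TensorFlow 2.x, axis comes before name, so passing axis as a keyword argument works in both):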

tf.boolean_mask(
    tensor,
    mask,
    name='boolean_mask',
    axis=None
)

where:

tensor: An N-D tensor.

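mask: A K-D boolean tensor or numpy.ndarray, where K <= N and K must be known statically. Its shape must match the first K dimensions of tensor (or the K dimensions starting at axis, when axis is set). This is the key argument: elements whose mask value is False are removed from tensor, and elements whose mask value is True are kept.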

name: A name for this operation (optional).

axis: A 0-D int Tensor representing the axis in tensor to mask from. By default, axis is 0, which masks from the first dimension; otherwise K + axis <= N.
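The rules above can be summarised in a small NumPy sketch, based on the behaviour described in the TensorFlow documentation: the K mask dimensions (starting at axis) are flattened into one, and only the positions where the mask is True, in row-major order, are kept. The helper name boolean_mask_np is hypothetical, handles only a non-negative axis, and is not TensorFlow's implementation; the TensorFlow documentation itself gives tensor[mask] as the NumPy equivalent for the default axis.

```python
import numpy as np


def boolean_mask_np(tensor, mask, axis=None):
    """Hypothetical NumPy sketch of tf.boolean_mask() semantics (non-negative axis only)."""
    tensor = np.asarray(tensor)
    mask = np.asarray(mask, dtype=bool)
    axis = 0 if axis is None else axis  # axis=None behaves like axis=0
    k = mask.ndim
    # The mask must match the K tensor dimensions starting at `axis`
    if tensor.shape[axis:axis + k] != mask.shape:
        raise ValueError("mask shape must match tensor.shape[axis:axis + mask.ndim]")
    # Flatten those K dimensions into a single dimension
    flat = tensor.reshape(tensor.shape[:axis] + (-1,) + tensor.shape[axis + k:])
    # Keep only the positions where the (row-major) flattened mask is True
    keep = np.flatnonzero(mask.reshape(-1))
    return np.take(flat, keep, axis=axis)


x = np.array([[2, 2, 3], [6, 7, 2], [1, 2, 2]], dtype=np.float32)
print(boolean_mask_np(x, [True, False, True]))          # removes the second row
print(boolean_mask_np(x, [True, False, True], axis=1))  # removes the second column
print(boolean_mask_np(x, x > 2))                        # K = N = 2, 1-D result: [3. 6. 7.]
```

With a 1-D mask (K = 1) whole rows or columns are kept or removed; with a mask of the same shape as x (K = N) single elements are selected, so the result collapses to one dimension.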

The examples below show how to use this function.

Remove a row from a tensor on axis = 0

import tensorflow as tf
import numpy as np

x = np.array([[2,2,3],[6,7,2],[1,2,2]], dtype = np.float32)

# keep the first and third rows, remove the second row of x
mask = np.array([True, False, True])

Here mask is a numpy.ndarray. Its second value is False, which means the second row of x (index 1) will be removed.

x2 = tf.boolean_mask(x, mask, axis = 0)

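# TensorFlow 1.x graph mode: evaluate x2 in a session
# (in TensorFlow 2.x, eager execution is the default, so print(x2.numpy()) works directly)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(x2))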

Running this code gives x2:

[[2. 2. 3.]
 [1. 2. 2.]]
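Typing the mask by hand works for small tensors. To remove the row at some index programmatically, one option is to start from an all-True mask and set that position to False. The helper name remove_row_mask below is hypothetical, and the NumPy expression x[mask] stands in for the tf.boolean_mask(x, mask, axis = 0) call above so the sketch runs without a session:

```python
import numpy as np


def remove_row_mask(n_rows, index):
    """Hypothetical helper: a 1-D boolean mask that is False only at `index`."""
    mask = np.ones(n_rows, dtype=bool)
    mask[index] = False
    return mask


x = np.array([[2, 2, 3], [6, 7, 2], [1, 2, 2]], dtype=np.float32)
mask = remove_row_mask(x.shape[0], 1)
print(mask)     # [ True False  True]
print(x[mask])  # same rows as tf.boolean_mask(x, mask, axis = 0)
```

Because a numpy.ndarray is an accepted mask type, the mask built this way can be passed to tf.boolean_mask() unchanged.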

The mask can also be a tensor instead of a numpy.ndarray:

mask = np.array([True, False, True])
# convert the numpy mask to a tf.bool tensor
maskx = tf.convert_to_tensor(mask, dtype = tf.bool)

x2 = tf.boolean_mask(x, maskx, axis = 0)

with tf.Session() as sess:
    print(sess.run(x2))

Running this code prints the same x2:

[[2. 2. 3.]
 [1. 2. 2.]]
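In practice the mask is often computed from the data rather than written by hand. For instance, to drop every row whose first value is less than 2, a comparison on the first column yields a 1-D boolean mask; in TensorFlow the call would be tf.boolean_mask(x, x[:, 0] >= 2, axis = 0). The sketch below uses the NumPy equivalent x[mask] so it runs without a session, and the threshold 2 is only an illustrative choice:

```python
import numpy as np

x = np.array([[2, 2, 3], [6, 7, 2], [1, 2, 2]], dtype=np.float32)

# Boolean mask from a condition: True where the first column is >= 2
mask = x[:, 0] >= 2
print(mask)  # [ True  True False]

# Keep only the rows where the mask is True (the last row is removed)
x2 = x[mask]
print(x2)
# [[2. 2. 3.]
#  [6. 7. 2.]]
```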