The numpy.squeeze() function removes axes of length 1 from an array. In this tutorial, we will use some simple examples to show you how to use it correctly.
Syntax
numpy.squeeze(a, axis=None)
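It returns a with all dimensions of length 1 removed (or only the ones chosen with axis). According to the NumPy documentation, the result is always a itself or a view into a, not a copy.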
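The sketch below checks two consequences of that return behavior, using throwaway arrays b and c that are not part of the tutorial's examples. Squeezing an array whose axes all have length 1 gives a 0-d array rather than a Python number, and the squeezed result shares memory with its input, so writing to it also changes the original.

```python
import numpy as np

# A 1x1 array: every axis has length 1, so every axis is removed
b = np.array([[5]])
s = np.squeeze(b)
print(s.shape)  # () -> a 0-d array
print(s.ndim)   # 0

# squeeze does not copy the data: the result is a view into the input
c = np.zeros((1, 3))
c_sque = np.squeeze(c)
print(np.shares_memory(c, c_sque))  # True

# Writing through the view changes the original array as well
c_sque[0] = 7.0
print(c[0, 0])  # 7.0
```

If you need an independent array, call .copy() on the squeezed result.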
Parameters
a: the input data, usually a numpy array (any array-like object, such as a nested list, is accepted)
axis: None (the default), an int, or a tuple of ints. Selects a subset of the length-1 axes to remove. If you select an axis whose length is not 1, a ValueError is raised.
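To make the axis parameter concrete, here is a small sketch on a hypothetical array x of shape (1, 3, 1). It compares the default axis=None with a single positive axis and a negative axis, which counts from the end just like ordinary NumPy indexing.

```python
import numpy as np

x = np.zeros((1, 3, 1))

print(np.squeeze(x).shape)           # (3,)   axis=None removes every length-1 axis
print(np.squeeze(x, axis=0).shape)   # (3, 1) only axis 0 is removed
print(np.squeeze(x, axis=-1).shape)  # (1, 3) axis=-1 is the last axis (axis 2 here)
```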
Here is an example:
import numpy as np
a = [[[10, 2, 3]]]
a = np.array(a)
print(np.shape(a))
a is a numpy array with shape (1, 1, 3).
a_sque = np.squeeze(a)
print(np.shape(a_sque))
The shape of a_sque is (3,): axis 0 and axis 1, both of length 1, are removed.
a_sque will be:
[10  2  3]
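The same result can also be obtained in two other ways, sketched below: the ndarray method a.squeeze(), and calling np.squeeze directly on the nested list without converting it first. The variable names m and from_list are only for illustration.

```python
import numpy as np

a = np.array([[[10, 2, 3]]])

# The ndarray method does the same job as the function
m = a.squeeze()

# np.squeeze also accepts a plain nested list and converts it for you
from_list = np.squeeze([[[10, 2, 3]]])

print(m.shape, from_list.shape)      # (3,) (3,)
print(np.array_equal(m, from_list))  # True
```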
If you only want to remove a specific axis, use the axis parameter.
Note: you can only select axes whose length is 1; selecting any other axis raises an error.
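Here is a minimal sketch of that error, assuming an array with shape (2, 1, 3, 1) built with np.arange (different values from the example that follows). Axis 0 has length 2, so NumPy rejects the request with a ValueError instead of silently ignoring it.

```python
import numpy as np

a = np.arange(6).reshape(2, 1, 3, 1)  # shape (2, 1, 3, 1)

caught = False
try:
    np.squeeze(a, axis=0)  # axis 0 has length 2, not 1
except ValueError as err:
    caught = True
    print("ValueError:", err)
```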
a = [[[10, 2, 3]], [[10, 2, 3]]]
a = np.array(a)
a = np.reshape(a, [2, 1, 3, 1])
print(np.shape(a))
The shape of a is (2, 1, 3, 1). Axis 1 and axis 3 both have length 1, which means we can pass axis=1 or axis=3 to numpy.squeeze().
Here is an example:
a_sque = np.squeeze(a, axis=1)
print(np.shape(a_sque))
print(a_sque)
The shape of a_sque will be (2, 3, 1)
It will be:
[[[10]
  [ 2]
  [ 3]]

 [[10]
  [ 2]
  [ 3]]]
a_sque = np.squeeze(a, axis=3)
print(np.shape(a_sque))
print(a_sque)
The shape of a_sque will be (2, 1, 3)
It will be:
[[[10  2  3]]

 [[10  2  3]]]
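Because axis also accepts a tuple of ints, both length-1 axes of this array can be removed in a single call. The sketch below rebuilds the same a and checks that axis=(1, 3) gives the same (2, 3) result as the default axis=None, which removes every length-1 axis. The names both and every are only for illustration.

```python
import numpy as np

a = np.array([[[10, 2, 3]], [[10, 2, 3]]])
a = np.reshape(a, [2, 1, 3, 1])

both = np.squeeze(a, axis=(1, 3))  # remove axis 1 and axis 3 in one call
every = np.squeeze(a)              # axis=None: remove every length-1 axis

print(both.shape)                   # (2, 3)
print(np.array_equal(both, every))  # True
print(both)
```

Passing an explicit tuple is safer than axis=None when other axes might happen to have length 1 at runtime (for example, a batch of size 1), since only the axes you name are removed.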