In the NumPy or PyTorch API, what is `squeeze` used for?
`squeeze` removes a dimension with a size of 1. E.g. if you have a tensor with a shape of `[2, 1, 3, 1, 4, 1]`, you could squeeze dims 1, 3, and 5 by passing the `dim` argument to `squeeze` (a tuple of dims is accepted since PyTorch 2.0; older versions take a single int). Calling `squeeze()` on a tensor without specifying a dimension will remove all dimensions with a size of 1. The docs also explain this behavior.
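A small sketch of the behavior described above, using the example shape from the answer. It chains single-dim `squeeze` calls (highest index first, so the remaining indices don't shift) to stay compatible with PyTorch versions before 2.0, and shows the NumPy counterpart `np.squeeze` with a tuple of axes. The variable names are just for illustration.

```python
import numpy as np
import torch

# Example tensor with size-1 dims at positions 1, 3 and 5.
x = torch.zeros(2, 1, 3, 1, 4, 1)

# Remove only dim 1.
one_dim = x.squeeze(1)                        # shape [2, 3, 1, 4, 1]

# Remove dims 5, 3 and 1 one at a time, highest index first so the
# indices of the dims still to be removed don't shift.
chosen = x.squeeze(5).squeeze(3).squeeze(1)   # shape [2, 3, 4]

# No dim given: every size-1 dim is removed.
all_dims = x.squeeze()                        # shape [2, 3, 4]

# In PyTorch, squeezing a dim whose size is not 1 is a no-op.
unchanged = x.squeeze(0)                      # shape [2, 1, 3, 1, 4, 1]

# NumPy: np.squeeze accepts a tuple of axes. Unlike PyTorch, it raises
# a ValueError if any selected axis does not have size 1.
a = np.zeros((2, 1, 3, 1, 4, 1))
np_chosen = np.squeeze(a, axis=(1, 3, 5))     # shape (2, 3, 4)

print(one_dim.shape, chosen.shape, all_dims.shape, unchanged.shape, np_chosen.shape)
```

Note the asymmetry: PyTorch silently ignores a `dim` that isn't size 1, while NumPy errors out, so the same call can behave differently across the two libraries.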
Yes, I get that part; however, my question was: under what conditions would one use this?
The use case explains itself: whenever you need to remove a size-1 dimension, you can use `squeeze`. E.g. if you are dealing with a single sample somewhere in your code and would like to remove the batch dimension in order to plot the image tensor via `plt.imshow`.
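A minimal sketch of the single-sample plotting case, assuming a hypothetical model output with a batch of one RGB image in `[batch, channels, height, width]` layout. `plt.imshow` expects `[height, width]` or `[height, width, channels]`, so after dropping the batch dim the channels are moved last with `permute`. The `plt.imshow` call is left commented out so the snippet runs without a display.

```python
import torch

# Hypothetical output: one RGB image in [batch, channels, height, width] layout.
batch = torch.rand(1, 3, 32, 32)

# Drop only the batch dimension (dim 0 has size 1).
img = batch.squeeze(0)               # [3, 32, 32]

# plt.imshow wants [height, width, channels], so move channels last.
img_hwc = img.permute(1, 2, 0)       # [32, 32, 3]

# A single-channel image [1, 1, H, W] can be squeezed all the way to [H, W].
gray = torch.rand(1, 1, 32, 32).squeeze()  # [32, 32]

# import matplotlib.pyplot as plt
# plt.imshow(img_hwc.numpy())
# plt.show()
```

Passing an explicit `dim` (here `squeeze(0)`) is usually safer than a bare `squeeze()`: if another dimension happened to be 1 (e.g. a height of 1), the bare call would drop it too.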