How to plot a heatmap with PyTorch

I want to plot a heatmap of features with PyTorch, but I do not know how to do it.

One way would be to convert the tensor to an ndarray and use seaborn/matplotlib to plot the heatmap.
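A minimal sketch of that approach, assuming the features are already a 2-D tensor. The helper name `tensor_to_heatmap`, the random `features` tensor and the output file `heatmap.png` are hypothetical. It draws with matplotlib's `imshow`; passing the same NumPy array to `seaborn.heatmap` would also work.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch


def tensor_to_heatmap(t: torch.Tensor, path: str = "heatmap.png"):
    """Plot a 2-D tensor as a heatmap, save it to `path`, and return the ndarray."""
    if t.dim() != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(t.shape)}")
    # detach() drops autograd history (numpy() fails on tensors that require grad),
    # cpu() copies GPU data to host memory, numpy() exposes it as an ndarray
    arr = t.detach().cpu().numpy()
    fig, ax = plt.subplots()
    im = ax.imshow(arr, cmap="viridis", aspect="auto")
    fig.colorbar(im, ax=ax)  # colour scale next to the plot
    fig.savefig(path)
    plt.close(fig)
    return arr


features = torch.rand(8, 16)  # hypothetical 2-D feature map
arr = tensor_to_heatmap(features)
```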

Thank you for your help.


I have a tensor with shape [x, y, z, z]. First I converted it to NumPy:
Tensor_a = Tensor_a.cpu().numpy()
Then I tried to plot it with:
ax = sns.heatmap(Tensor_a, linewidths=0.5)

but I got this error:
raise ValueError('Must pass 2-d input')

Any help please?
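The error comes from `seaborn.heatmap`, which only accepts 2-D data, while a tensor of shape [x, y, z, z] is 4-D. A minimal sketch of two ways to reduce it to a 2-D [z, z] array, assuming the dimensions mean [batch, channels, height, width]. The random stand-in tensor of shape (2, 4, 5, 5) and the file name `heatmaps.png` are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Hypothetical stand-in for Tensor_a with shape [x, y, z, z],
# read here as [batch, channels, height, width] (an assumption)
Tensor_a = torch.randn(2, 4, 5, 5)

# Option 1: pick one sample and one channel -> shape [z, z]
single = Tensor_a[0, 0].detach().cpu().numpy()

# Option 2: average over the channel dimension of one sample -> shape [z, z]
mean_map = Tensor_a[0].mean(dim=0).detach().cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
sns.heatmap(single, linewidths=0.5, ax=axes[0])
axes[0].set_title("sample 0, channel 0")
sns.heatmap(mean_map, linewidths=0.5, ax=axes[1])
axes[1].set_title("sample 0, mean over channels")
fig.savefig("heatmaps.png")
plt.close(fig)
```

Which reduction makes sense depends on what the dimensions actually represent. To see every channel, one option is to loop over `Tensor_a[0]` and draw one subplot per channel.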