Hi, here is a reproducible example of how I set some values of a tensor to 0 at previously computed indices. Suppose I have a tensor a of shape (n_sensors, batch_size, dim1, dim2). For each batch element, I want to find the 5 positions along dim2 with the highest values (after averaging over sensors and summing over dim1) and set those positions to 0 for every sensor and every value of dim1. That is, I want to do this:
import matplotlib.pyplot as plt
import torch
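a = torch.rand((5,4,10,15))  # (n_sensors, batch_size, dim1, dim2)
# Per batch: average over sensors, then sum over dim1 -> (batch_size, dim2)
top_att = torch.sum(a.mean(0),axis=1)
# Indices of the 5 largest scores along dim2 -> (batch_size, 5)
top_att = top_att.argsort(axis = 1)[:, -5:]
top_att = top_att.sort()[0]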
print(top_att.shape)
plt.figure()
plt.imshow(a.mean(0)[0])
plt.colorbar(fraction=0.046, pad=0.04)
plt.clim(0,1)
plt.show()
for sensor in range(a.shape[0]):
    for time_step in range(a.shape[2]):
        for batch in range(a.shape[1]):
            indexes = top_att[batch]
            a[sensor,batch,time_step,indexes] = 0
plt.figure()
plt.imshow(a.mean(0)[0])
plt.colorbar(fraction=0.046, pad=0.04)
plt.clim(0,1)
plt.show()
This works, but I am looking for a more efficient way to do it without loops. How can I vectorize this?
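One loop-free option, sketched here as a suggestion rather than a definitive answer, is to build a boolean mask of shape (batch_size, dim2) with Tensor.scatter_ and let broadcasting apply it to every sensor and every dim1 row through Tensor.masked_fill_. The function name zero_top_k_vectorized is hypothetical, and this sketch swaps argsort(...)[:, -5:] for torch.topk, which selects the same positions as long as no scores are tied (assumed here, as with continuous random data).

```python
import torch


def zero_top_k_vectorized(a: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Hypothetical helper: zero, in place, the k highest-scoring dim2 positions per batch.

    a has shape (n_sensors, batch_size, dim1, dim2). The score of each
    (batch, dim2) pair is the sensor mean summed over dim1, as in the loop version.
    Assumes no tied scores, since topk and argsort may break ties differently.
    """
    # (batch_size, dim2): average over sensors, then sum over dim1
    scores = a.mean(0).sum(dim=1)
    # (batch_size, k): indices of the k largest scores along dim2
    top_idx = scores.topk(k, dim=1).indices
    # (batch_size, dim2) boolean mask, True at the selected positions
    mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(1, top_idx, True)
    # Broadcast the mask over sensors (dim 0) and dim1 (dim 2), then zero in place
    a.masked_fill_(mask[None, :, None, :], 0)
    return a
```

With a = torch.rand((5,4,10,15)), calling zero_top_k_vectorized(a, 5) replaces the triple loop. Both the index selection and the fill are single tensor operations, and the mask is created on the same device as a, so the same code should also run on a GPU tensor.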
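Another possible approach keeps the original argsort selection and writes the zeros with advanced indexing instead of a mask: a batch index of shape (batch_size, 1) broadcasts against top_att of shape (batch_size, k), so one assignment covers every sensor and every dim1 row. The helper name zero_top_k_indexing is hypothetical; treat this as a sketch, not a tested library routine.

```python
import torch


def zero_top_k_indexing(a: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Hypothetical helper: same selection as the loop version, written without loops.

    a has shape (n_sensors, batch_size, dim1, dim2) and is modified in place.
    """
    # (batch_size, k): the k dim2 positions with the largest sensor-mean summed over dim1
    top_att = a.mean(0).sum(dim=1).argsort(dim=1)[:, -k:]
    # (batch_size, 1): batch index that broadcasts against top_att to (batch_size, k)
    batch_idx = torch.arange(a.shape[1], device=a.device).unsqueeze(1)
    # Advanced indices on dims 1 and 3, full slices on sensors and dim1:
    # sets a[s, b, t, top_att[b, j]] = 0 for every s, t, b, j in one operation
    a[:, batch_idx, :, top_att] = 0
    return a
```

Because the two advanced indices are separated by a slice, PyTorch (like NumPy) moves the broadcast index dimensions to the front of the indexed view, but that has no effect when assigning a scalar. Sorting top_att is not needed for either vectorized version, since the order of the indices does not change which elements are zeroed.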