Change specified indices of torch tensors

Hi, here is a reproducible example of setting some values of a tensor to 0 at previously computed indices. Suppose I have a tensor a of shape (n_sensors, batch_size, dim1, dim2). For each batch, I take the 5 indices along dim2 with the highest score (the tensor averaged over sensors and summed over dim1), and I want to set the values at those indices to 0 for every sensor and every value of dim1. That is, I want to do this:

import matplotlib.pyplot as plt
import torch 

a = torch.rand((5,4,10,15))
top_att = torch.sum(a.mean(0), dim=1)
top_att = top_att.argsort(dim=1)[:, -5:]
top_att = top_att.sort().values
print(top_att.shape)

plt.figure()
plt.imshow(a.mean(0)[0])
plt.colorbar(fraction=0.046, pad=0.04)
plt.clim(0,1)
plt.show()

for sensor in range(a.shape[0]):
    for time_step in range(a.shape[2]):
        for batch in range(a.shape[1]):
            indexes = top_att[batch]
            a[sensor,batch,time_step,indexes] = 0
            

plt.figure()
plt.imshow(a.mean(0)[0])
plt.colorbar(fraction=0.046, pad=0.04)
plt.clim(0,1)
plt.show()

This works fine, but I am looking for a more efficient way to do it without loops. How can I do that?
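A possible loop-free sketch, not taken from the original post. It assumes top_att is an int64 tensor of shape (batch_size, k), as produced by the argsort code above. The function names zero_top_k_loop, zero_top_k_indexing and zero_top_k_mask are hypothetical, chosen for illustration. The first function reproduces the loop as a reference; the other two replace it with advanced indexing and with a broadcast boolean mask, and the asserts check that all three agree.

```python
import torch


def zero_top_k_loop(a, top_att):
    # Reference version: the triple loop from the question, applied to a copy.
    out = a.clone()
    for sensor in range(out.shape[0]):
        for time_step in range(out.shape[2]):
            for batch in range(out.shape[1]):
                out[sensor, batch, time_step, top_att[batch]] = 0
    return out


def zero_top_k_indexing(a, top_att):
    # batch_idx has shape (batch_size, 1) and broadcasts against
    # top_att (batch_size, k), selecting every (batch, dim2 index) pair.
    # The ":" slices keep all sensors (dim 0) and all dim1 values (dim 2).
    out = a.clone()
    batch_idx = torch.arange(out.shape[1], device=top_att.device).unsqueeze(1)
    out[:, batch_idx, :, top_att] = 0
    return out


def zero_top_k_mask(a, top_att):
    # Boolean mask of shape (batch_size, dim2), True at the selected indices.
    mask = torch.zeros(a.shape[1], a.shape[3], dtype=torch.bool, device=a.device)
    batch_idx = torch.arange(a.shape[1], device=a.device).unsqueeze(1)
    mask[batch_idx, top_att] = True
    # Reshape to (1, batch_size, 1, dim2) so it broadcasts over sensors and dim1.
    return a.masked_fill(mask[None, :, None, :], 0)


torch.manual_seed(0)
a = torch.rand((5, 4, 10, 15))
# Same index computation as in the question, written with dim= instead of axis=.
top_att = a.mean(0).sum(dim=1).argsort(dim=1)[:, -5:].sort().values

ref = zero_top_k_loop(a, top_att)
assert torch.equal(zero_top_k_indexing(a, top_att), ref)
assert torch.equal(zero_top_k_mask(a, top_att), ref)
```

Both loop-free versions return a new tensor. To modify a in place, drop the clone() in the indexing version or use masked_fill_ in the mask version. For the index selection itself, torch.topk(score, 5, dim=1).indices returns the same set of indices (barring ties) without a full argsort. The order differs, but the zeroing step does not depend on it.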