How to efficiently extract sparse points from a heatmap

I am working on a vision task and use a fully convolutional network to produce a heatmap for each image. I now want to add supervision at certain points of the heatmaps, so I need to read the values of specific pixels and compute a loss on them. Those pixels are specified in the inputs. There are usually fewer than 1000 of them, and their number varies from batch to batch. Currently I use nested for-loops to gather the values, like this:

import torch

def calc(heatmap, points):
    # heatmap: tensor of size (32, 32, 128, 128); dim 0 indexes the samples in the batch
    # points: nested lists; points[i] holds the point groups for sample i
    results = []
    for i in range(heatmap.size(0)):
        # points[i] is a list with fewer than 30 groups
        for j in points[i]:
            # j is a list of fewer than 32 index tuples k, each read as heatmap[i, k[0], k[1], k[2]]
            ans = [heatmap[i, k[0], k[1], k[2]] for k in j]
            # average the values sampled for this group
            results.append(torch.mean(torch.stack(ans), dim=0))
    return results
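The loop reads one scalar at a time, so every point costs a separate tensor operation. A first step toward vectorizing it is to flatten the nested lists into 1-D index tensors. This is a minimal sketch, assuming each group `j` is a list of index tuples `k` read as `heatmap[i, k[0], k[1], k[2]]`, as in the code above. `flatten_points` is a hypothetical helper name. It still loops in Python, but only over plain integers, so it could run inside a DataLoader `collate_fn` (in worker processes when `num_workers > 0`) rather than in the training step.

```python
import torch

def flatten_points(points):
    """Hypothetical helper: flatten nested points into 1-D index tensors.

    Assumes points[i] is a list of groups for sample i and each group is a
    list of index tuples k, read as heatmap[i, k[0], k[1], k[2]].
    Returns (b, c, y, x, group_id, num_groups), where b, c, y, x and
    group_id are int64 tensors with one entry per point.
    """
    b, c, y, x, group_id = [], [], [], [], []
    num_groups = 0
    for i, groups in enumerate(points):
        for group in groups:
            for k in group:
                b.append(i)
                c.append(k[0])
                y.append(k[1])
                x.append(k[2])
                group_id.append(num_groups)  # each point remembers its group
            num_groups += 1
    tensors = [torch.tensor(v, dtype=torch.long) for v in (b, c, y, x, group_id)]
    return (*tensors, num_groups)
```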

However, this is very slow: computing the heatmap takes about 0.1 s, but extracting the values from it takes about 0.5 s. Is there an efficient way to extract these sparse points?
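One common way to remove the per-point Python loop is PyTorch advanced indexing: passing four index tensors to `heatmap[...]` gathers every point in one vectorized operation, and the per-group averages can then be computed with `index_add` and `torch.bincount`. This is a minimal sketch, assuming the points are already available as flat `torch.long` tensors `b, c, y, x` (one entry per point) plus a `group_id` tensor that numbers each point's group from `0` to `num_groups - 1`. The function name `gather_group_means` is hypothetical. It returns a single tensor of group means instead of a Python list, and gradients flow back to `heatmap` through both the indexing and `index_add`.

```python
import torch

def gather_group_means(heatmap, b, c, y, x, group_id, num_groups):
    """Hypothetical vectorized replacement for the nested loops.

    heatmap: (B, C, H, W) tensor.
    b, c, y, x, group_id: 1-D int64 tensors with one entry per point,
    on the same device as heatmap.
    Returns a (num_groups,) tensor with the mean value of each group.
    """
    # Advanced indexing: values[n] == heatmap[b[n], c[n], y[n], x[n]]
    values = heatmap[b, c, y, x]
    # Sum the values belonging to each group (out-of-place, autograd-friendly)
    sums = torch.zeros(num_groups, dtype=values.dtype, device=values.device)
    sums = sums.index_add(0, group_id, values)
    # Number of points per group, used to turn sums into means
    counts = torch.bincount(group_id, minlength=num_groups).to(values.dtype)
    return sums / counts
```

If the index tensors are built on the CPU, move them to the heatmap's device once (for example `b = b.to(heatmap.device)`) before calling the function. A group with no points would yield `nan` (0/0), whereas the original loop raises an error on an empty `torch.stack`.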