How to get a certain element of every batch sample

For example, the output of the network has shape [batch_size, w, h] = [16, 7, 8]. How can I get the top-left element of every batch sample? Is there a PyTorch operator that does this?

You could just index your output:

import torch

batch_size = 10
output = torch.randn(batch_size, 7, 8)
output[:, 0, 0]  # top-left element of each sample, shape [batch_size]

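This returns the element at position [0, 0] of every batch sample as a 1-D tensor of shape [batch_size]. No special operator is needed; plain indexing is enough.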
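Here is a small sketch that checks this, assuming the [16, 7, 8] shape from the question. It confirms the shape of the result and compares it against an explicit per-sample loop. The 4-D tensor with 3 channels is a hypothetical extra case, included only to show how the ellipsis handles additional leading dimensions.

```python
import torch

batch_size = 16
output = torch.randn(batch_size, 7, 8)  # [batch_size, w, h], as in the question

# ':' keeps the whole batch dimension; 0, 0 picks the top-left position
top_left = output[:, 0, 0]
print(top_left.shape)  # torch.Size([16])

# Equivalent (but slower) explicit loop over the batch, for comparison
looped = torch.stack([output[i, 0, 0] for i in range(batch_size)])
print(torch.equal(top_left, looped))  # True

# Hypothetical case with a channel dimension: [batch_size, channels, w, h].
# The ellipsis keeps all leading dimensions and indexes only the last two.
output_4d = torch.randn(batch_size, 3, 7, 8)
top_left_4d = output_4d[..., 0, 0]
print(top_left_4d.shape)  # torch.Size([16, 3])
```

Basic indexing like this returns a view that shares memory with `output`. Writing into `top_left` in place therefore also changes `output`. Call `.clone()` on the result if you need an independent copy.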


Yes, thank you very much.