For example, the output of the network has shape batch_size x w x h = 16 x 7 x 8. How can I get the top-left element of every sample in the batch? Is there some PyTorch operator to do it?
You could just index your output:
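import torch

batch_size = 10
output = torch.randn(batch_size, 7, 8)
# ":" keeps the whole batch dimension; 0, 0 selects row 0, column 0 of each sample
output[:, 0, 0]  # shape: (batch_size,)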
This returns the element at position [0, 0] for every sample in the batch, as a 1-D tensor of length batch_size.
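A minimal sketch to check this, assuming the 16 x 7 x 8 shape from the question (the random values are just placeholders). It confirms that `output[:, 0, 0]` gives the same values as looping over the batch, shows how slicing with `0:1` keeps the singleton dimensions if you need them, and uses a hypothetical 4-D (N, C, H, W) output to show that `...` (Ellipsis) picks the top-left pixel regardless of how many leading dimensions there are.

```python
import torch

# Shape taken from the question: (batch_size, w, h) = (16, 7, 8)
output = torch.randn(16, 7, 8)

# Basic indexing: element [0, 0] of every sample -> shape (16,)
top_left = output[:, 0, 0]
print(top_left.shape)  # torch.Size([16])

# Same values as an explicit Python loop over the batch dimension
looped = torch.stack([output[i, 0, 0] for i in range(output.shape[0])])
assert torch.equal(top_left, looped)

# Slicing with 0:1 instead of 0 keeps the singleton dims -> shape (16, 1, 1)
top_left_keepdim = output[:, 0:1, 0:1]
print(top_left_keepdim.shape)  # torch.Size([16, 1, 1])

# Hypothetical NCHW output (16 x 3 x 7 x 8): Ellipsis covers all leading dims,
# so [..., 0, 0] takes the top-left pixel of every channel -> shape (16, 3)
output_nchw = torch.randn(16, 3, 7, 8)
top_left_nchw = output_nchw[..., 0, 0]
print(top_left_nchw.shape)  # torch.Size([16, 3])
```

Plain indexing like this returns a view into `output` (no copy and no extra operator needed), so it is also cheap inside a training loop.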
Yes, thank you very much.