Hi!
I’m trying to check the shape of tensors (N, C, H, W) in my code.
I know N will change, but I want to check that C, H, and W are correct.
What would be a good way to do this?
Thanks a lot!
`tensor.size()` or `tensor.shape` will return a `torch.Size` object, which you can then compare against a reference as seen here:
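```python
import torch

x = torch.randn(2, 3, 4, 5)
# Comparing the full shape fails because it includes the batch dimension N
print(x.size() == torch.Size([3, 4, 5]))
# False
# Slicing off N leaves (C, H, W), which matches the reference
print(x.size()[1:] == torch.Size([3, 4, 5]))
# True
```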
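If you want the check to fail loudly inside a training loop rather than print a boolean, you could wrap the same slicing comparison in a small helper. This is only a sketch: `check_chw` is a hypothetical name, not a PyTorch API, and it assumes inputs should always be 4-D (N, C, H, W). Because `torch.Size` subclasses `tuple`, the slice can be compared directly against a plain tuple.

```python
import torch


def check_chw(x: torch.Tensor, expected_chw: tuple) -> None:
    """Hypothetical helper: raise ValueError unless x is 4-D with (C, H, W) == expected_chw.

    The batch dimension N (x.shape[0]) is deliberately not checked.
    """
    if x.dim() != 4:
        raise ValueError(
            f"expected a 4-D (N, C, H, W) tensor, got {x.dim()}-D with shape {tuple(x.shape)}"
        )
    # torch.Size subclasses tuple, so the slice compares equal to a plain tuple.
    if x.shape[1:] != tuple(expected_chw):
        raise ValueError(
            f"expected (C, H, W) = {tuple(expected_chw)}, got {tuple(x.shape[1:])}"
        )


# Any batch size passes as long as C, H, W match.
check_chw(torch.randn(2, 3, 4, 5), (3, 4, 5))
check_chw(torch.randn(7, 3, 4, 5), (3, 4, 5))

try:
    check_chw(torch.randn(2, 1, 4, 5), (3, 4, 5))
except ValueError as e:
    print(e)  # expected (C, H, W) = (3, 4, 5), got (1, 4, 5)
```

Raising an exception (instead of using `assert`) keeps the check active even when Python runs with `-O`, which strips assert statements.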
Thanks very much for your help, ptrblck! This method will help me catch many errors!