(some form of) Pattern Matching for Tensors


I'm trying to check the shape of (N, C, H, W) tensors in my code.
I know N will change, but I want to check that C, H, and W are correct.
What would be a good way to do this?

Thanks a lot!

`tensor.size()` and `tensor.shape` both return a `torch.Size` object, which you can compare against a reference shape. To ignore the varying batch dimension N, slice it off before comparing:

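```python
import torch

x = torch.randn(2, 3, 4, 5)  # N=2, C=3, H=4, W=5

print(x.size() == torch.Size([3, 4, 5]))
# False: the full size still includes N
print(x.size()[1:] == torch.Size([3, 4, 5]))
# True: slicing off N leaves (C, H, W)
```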
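If you check shapes in many places, the same slicing comparison can be wrapped in a small helper. This is a minimal sketch, not a PyTorch API: `check_chw` is a hypothetical name, and it assumes 4-D (N, C, H, W) inputs. Because `torch.Size` is a subclass of `tuple`, you can pass `x.shape` directly; the sketch is written against plain tuples so it runs without PyTorch installed.

```python
from typing import Sequence, Tuple


def check_chw(shape: Sequence[int], expected_chw: Tuple[int, int, int]) -> int:
    """Hypothetical helper: validate an (N, C, H, W) shape, ignoring N.

    Accepts x.shape (a torch.Size, which is a tuple subclass) or any
    sequence of ints. Returns N so the caller can use the batch size.
    """
    shape = tuple(shape)
    expected = tuple(expected_chw)
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) shape, got {shape}")
    if shape[1:] != expected:
        raise ValueError(f"expected (C, H, W) = {expected}, got {shape[1:]}")
    return shape[0]  # N is allowed to vary, so it is returned, not checked


# Usage with plain tuples; with PyTorch you would pass x.shape instead.
n = check_chw((2, 3, 4, 5), (3, 4, 5))
print(n)  # 2
```

Raising `ValueError` instead of using a bare `assert` keeps the check active even when Python runs with `-O`, which strips assertions, and the message reports which dimensions were wrong.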

Thanks very much for your help, ptrblck! This method will help me catch many errors!
