Let’s say I want to check if a parameter is in a list of parameters, for instance to add it to an optimizer param group:
import torch, torchvision
m = torchvision.models.resnet18()
p0 = next(iter(m.fc.parameters()))
p_list = list(m.parameters())
p0 in p_list
Raises RuntimeError: The size of tensor a (7) must match the size of tensor b (512) at non-singleton dimension 3
But this should just be comparing object identities, not the values/shapes of the tensors. What's going on here? As a workaround I could compare object ids, but that doesn't seem like it should be necessary.
Thanks!
When you use p0 in p_list, Python calls the list's __contains__ method, which checks each element e with x is e or x == e (identity first, then equality). The identity check would eventually match the parameter you're looking for, but before that element is reached, == is evaluated against the other, non-identical parameters.
PyTorch overloads the == operator for tensors to perform element-wise comparison rather than identity comparison, and when two tensors' shapes can't be broadcast together it raises the RuntimeError you observed.
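For example, with freshly created tensors (just a minimal sketch of the behavior, independent of the model above):

import torch

a = torch.zeros(3)
b = torch.zeros(3)
c = torch.zeros(4)

print(a == b)   # tensor([True, True, True]) -- element-wise comparison, not identity
# a == c        # would raise RuntimeError: size 3 vs 4 at non-singleton dimension 0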
As a workaround for comparing object identities, I would suggest using the is operator.
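For example, something like this (a sketch reusing m, p0 and p_list from your snippet; the optimizer param-group split and the learning rates are just a hypothetical illustration):

# identity-based membership test
print(any(p0 is p for p in p_list))   # True

# e.g. building optimizer param groups without tripping over ==
fc_params = list(m.fc.parameters())
other_params = [p for p in m.parameters() if all(p is not q for q in fc_params)]
optimizer = torch.optim.SGD(
    [{"params": other_params}, {"params": fc_params, "lr": 1e-3}],
    lr=1e-2,
)

Comparing id(p0) == id(p) would give the same result, but is expresses the intent more directly.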