Confused by RuntimeError when checking for parameter in list

Let’s say I want to check if a parameter is in a list of parameters, for instance to add it to an optimizer param group:

import torch, torchvision
m = torchvision.models.resnet18()
p0 = next(iter(m.fc.parameters()))
p_list = list(m.parameters())
p0 in p_list

Raises RuntimeError: The size of tensor a (7) must match the size of tensor b (512) at non-singleton dimension 3

But I should just be comparing object identities, not values/shapes of tensors. What’s going on here? As a workaround, I could compare object ids but this doesn’t seem like it should be necessary.
Thanks!

When you write p0 in p_list, Python calls the list's __contains__ method. It iterates through the elements in order. For each one, it first checks whether it is the very same object (identity), and only if it is not does it fall back to the == operator.
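The Python language reference describes x in y for a list as equivalent to any(x is e or x == e for e in y), so identity is tried before ==. The short sketch below makes that order visible. Loud is a hypothetical class (not part of PyTorch) that counts how often its __eq__ runs, and contains_like_list is likewise only an illustrative, spelled-out version of in.

```python
class Loud:
    """Hypothetical stand-in whose == is never true and counts its calls."""

    calls = 0

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        Loud.calls += 1
        return False  # value equality never matches; only identity can

    __hash__ = object.__hash__  # defining __eq__ alone would disable hashing


def contains_like_list(seq, x):
    """Spelled-out version of `x in seq` for a list, per the language reference."""
    return any(x is e or x == e for e in seq)


a, b, c = Loud("a"), Loud("b"), Loud("c")
items = [a, b, c]

Loud.calls = 0
print(c in items, Loud.calls)                    # True 2: == ran on a and b, c matched by identity
Loud.calls = 0
print(contains_like_list(items, c), Loud.calls)  # True 2: same order of checks
Loud.calls = 0
print(Loud("d") in items, Loud.calls)            # False 3: no identity match, == ran on every element
```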
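PyTorch overloads the == operator for tensors to perform an element-wise comparison (returning a boolean tensor) rather than an identity comparison. The first element of p_list is conv1.weight, with shape (64, 3, 7, 7), not p0 (fc.weight, with shape (1000, 512)). So == is evaluated on two tensors whose shapes cannot be broadcast together, and PyTorch raises the RuntimeError you observed. Even if the shapes did match, the result would be a boolean tensor with more than one element, whose truth value is ambiguous, so the check would still fail with a different RuntimeError.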
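Here is a pure-Python sketch that imitates the relevant tensor behavior without needing PyTorch. FakeTensor and FakeMask are hypothetical stand-ins, not PyTorch classes. Their == compares element by element and raises on a length mismatch (a loose imitation of failed broadcasting). The resulting mask refuses to be converted to bool unless it has exactly one element. As far as I know, real tensors produce the same four outcomes in these situations.

```python
class FakeTensor:
    """Hypothetical 1-D stand-in for a tensor; == compares element by element."""

    def __init__(self, *values):
        self.values = list(values)

    def __eq__(self, other):
        if len(self.values) != len(other.values):
            # Loose imitation of PyTorch refusing to broadcast incompatible shapes
            raise RuntimeError("shapes do not match")
        return FakeMask([x == y for x, y in zip(self.values, other.values)])

    __hash__ = object.__hash__  # keep identity-based hashing despite defining __eq__


class FakeMask:
    """Hypothetical result of ==; only a single-element mask has a truth value."""

    def __init__(self, flags):
        self.flags = flags

    def __bool__(self):
        if len(self.flags) != 1:
            raise RuntimeError("truth value of a multi-element mask is ambiguous")
        return self.flags[0]


def try_in(x, seq):
    """Run `x in seq`, returning the error message instead of raising."""
    try:
        return x in seq
    except RuntimeError as e:
        return "RuntimeError: %s" % e


w = FakeTensor(1.0, 2.0)
other = FakeTensor(0.0, 0.0, 0.0)

print(try_in(w, [w, other]))                       # True: identity matches before any ==
print(try_in(w, [other, w]))                       # RuntimeError: shapes do not match
print(try_in(w, [FakeTensor(1.0, 2.0), w]))        # RuntimeError: ... ambiguous
print(try_in(FakeTensor(5.0), [FakeTensor(5.0)]))  # True, although it is a different object
```

The second and third calls fail even though w is in the list, because an earlier element is compared with == first. The last call reports a match for a different object that merely holds an equal single value.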

As a workaround that compares object identities, I would suggest using the is operator:

any(p0 is p for p in p_list)

or with id:

any(id(p0) == id(p) for p in p_list)
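Both one-liners use only is or id(), so the tensors' == never runs. For the optimizer use case in the question, it may be convenient to wrap them in small helpers. contains_identical and split_by_identity below are hypothetical names, and Param is a stand-in class so the sketch runs without PyTorch. Collecting the ids into a set once avoids rescanning the list for every lookup.

```python
def contains_identical(seq, x):
    """True if x itself (not merely an equal value) is an element of seq."""
    # Only `is` is used, so no element's __eq__ ever runs.
    return any(x is e for e in seq)


def split_by_identity(all_items, chosen):
    """Split all_items into (picked, rest), matching elements of chosen by identity.

    id() values are collected into a set once, so each lookup is O(1) instead of
    a scan. ids are only meaningful while the objects are alive, which holds as
    long as something (e.g. the model) still references them.
    """
    all_items = list(all_items)  # allow one-shot iterators such as generators
    chosen_ids = {id(x) for x in chosen}
    picked = [x for x in all_items if id(x) in chosen_ids]
    rest = [x for x in all_items if id(x) not in chosen_ids]
    return picked, rest


class Param:
    """Hypothetical stand-in for a parameter, used only to exercise the helpers."""

    def __init__(self, name):
        self.name = name


conv_w, fc_w, fc_b = Param("conv1.weight"), Param("fc.weight"), Param("fc.bias")
all_params = [conv_w, fc_w, fc_b]

head, backbone = split_by_identity(all_params, [fc_w, fc_b])
print([p.name for p in head])      # ['fc.weight', 'fc.bias']
print([p.name for p in backbone])  # ['conv1.weight']
print(contains_identical(all_params, fc_w))                # True
print(contains_identical(all_params, Param("fc.weight")))  # False: same name, different object
```

With the model from the question, the call would presumably look like split_by_identity(m.parameters(), m.fc.parameters()). The two lists could then be passed as separate param groups, e.g. torch.optim.SGD([{"params": backbone}, {"params": head, "lr": 1e-2}], lr=1e-3). The learning rates here are arbitrary.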