I am instantiating a model, and I would like to pass the device as an argument from a configuration dataclass and then, inside the model class, call:
```python
self.to(config.device)
```
However, when I check, the model is not on `'cuda'`.
Is there any way to do this?
It works for me:
```python
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 2)
        # Move all registered parameters/buffers once the layers exist
        self.to(device)


model = MyModel("cuda")
print(model.fc1.weight.device)
# cuda:0
print(model.fc2.weight.device)
# cuda:0
```
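If the model still ends up on the CPU, one possible cause (an assumption, since the original model code isn't shown) is ordering: `nn.Module.to()` only moves parameters, buffers and submodules that are already registered when it runs. The sketch below contrasts calling `self.to()` before the layers are created with calling it last, and also shows that a plain tensor attribute is not moved while a registered buffer is. `Config`, `EarlyMove` and `LateMove` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class Config:
    # Hypothetical config dataclass; falls back to the CPU when CUDA is absent.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class EarlyMove(nn.Module):
    # Pitfall: .to() runs before any layer is registered, so it moves nothing.
    def __init__(self, config: Config):
        super().__init__()
        self.to(config.device)
        self.fc = nn.Linear(1, 1)  # created after the move: stays on the CPU


class LateMove(nn.Module):
    # Working order: register everything first, then call .to() last.
    def __init__(self, config: Config):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.register_buffer("scale", torch.ones(1))  # buffers are moved by .to()
        self.plain = torch.ones(1)  # plain tensor attribute: NOT moved by .to()
        self.to(config.device)


cfg = Config()
print(EarlyMove(cfg).fc.weight.device)  # cpu, even when CUDA is available
late = LateMove(cfg)
print(late.fc.weight.device, late.scale.device, late.plain.device)
```

On a machine with a GPU, the last line should report `cuda:0` for the layer weight and the buffer but `cpu` for `plain`, which is why tensors created in `__init__` are usually registered with `register_buffer` (or wrapped in `nn.Parameter`).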
That said, I'm not sure why you want to call it inside the model's `__init__` rather than after creating the object.
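A minimal sketch of the "after creating the object" alternative, still reading the device from a dataclass. `TrainConfig` is a hypothetical name, and the CPU fallback is an assumption so the snippet also runs without a GPU. It relies on `nn.Module.to()` moving the module in place and returning the module itself, so the call can be chained onto the constructor.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class TrainConfig:
    # Hypothetical config; use "cuda" only if a GPU is visible.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class MyModel(nn.Module):
    # Device-agnostic model: no .to() call inside __init__.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.fc1(x))


cfg = TrainConfig()
# nn.Module.to() moves parameters/buffers in place and returns the module.
model = MyModel().to(cfg.device)

# Inputs have to live on the same device as the model's parameters.
x = torch.ones(4, 1, device=cfg.device)
out = model(x)
print(model.fc1.weight.device, out.shape)
```

Keeping placement outside the class lets the same model be built on the CPU (for example in tests) and moved to the GPU only where it is actually trained.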