import torch
import torch.nn as nn
class Model(nn.Module):
    """Toy module whose forward pass multiplies the input by ``self.temp_weight``.

    ``temp_weight`` is not created in ``__init__``; the caller must assign it
    on the instance before the module is called.
    """

    def __init__(self, dim=None):
        """Create two learnable 1-D parameters of length ``dim``.

        Args:
            dim: Length of each parameter vector. When omitted, the
                module-level ``D`` is used, so a bare ``Model()`` keeps working.
        """
        super().__init__()
        if dim is None:
            # Resolved at call time: D only has to exist before the first Model().
            dim = D
        # Creation order (weight_mul, then weight) is kept so that seeded
        # random draws land on the same parameters as before.
        self.weight_mul = nn.Parameter(torch.randn(dim,))
        self.weight = nn.Parameter(torch.randn(dim,))

    def forward(self, x):
        """Return ``x * self.temp_weight``."""
        return x * self.temp_weight
# Repro script: every .cuda() call below needs a CUDA device.
torch.manual_seed(0)
# D is read inside Model.__init__, so it must be set before Model() is called.
D = 5
x = torch.randn(D,).cuda()
model = Model()
model.cuda()
# temp_weight is built from the two parameters here, i.e. BEFORE the
# cpu()/cuda() round trip on the following statement. It is a plain attribute
# holding the result of a multiplication, not a Parameter.
model.temp_weight = model.weight * model.weight_mul
# NOTE(review): presumably this device round trip leaves temp_weight (and the
# autograd graph it captured) untouched while the parameters themselves are
# moved, so the graph may no longer line up with the tensors handed to
# backward(inputs=...) below -- confirm against the PyTorch autograd docs.
model.cpu(); model.cuda()
output = model(x)
# `inputs=` asks autograd to accumulate gradients only into the listed tensors.
output.sum().backward(inputs=[model.weight, model.weight_mul])
print("model.weight.grad", model.weight.grad)
print("model.weight_mul.grad", model.weight_mul.grad)
To my surprise, both .grad attributes are None. There are two ways to get backward() to work: one is to remove the inputs argument from backward(), the other is to remove model.cpu(); model.cuda(). But why?
In this case, weight_mul gets a gradient but model.weight doesn’t. With the inputs argument removed, both get gradients. The computation graphs according to torchviz are identical in both cases. So somehow module.cpu() is disconnecting the computation graph, but not totally?
Possibly related: this works:
# Variant: presumably meant to replace Model.forward, multiplying by the
# registered parameter directly instead of a precomputed attribute.
def forward(self, x):
    """Return ``x * self.weight``."""
    x = x * self.weight
    return x


# Driver for this variant; the .cuda() calls need a CUDA device.
x = torch.randn(D,).cuda()
model = Model().cuda()
output = model(x)
# The module is moved to the CPU after the forward pass but before backward().
model.cpu()
output.sum().backward()
But this raises a device-mismatch error in backward():
# Variant: presumably meant to replace Model.forward, multiplying by an
# attribute that the driver below assigns on the instance.
def forward(self, x):
    """Return ``x * self.temp_weight``."""
    x = x * self.temp_weight
    return x


# Driver for this variant; the .cuda() calls need a CUDA device.
x = torch.randn(D,).cuda()
# Here the multiplier is a free-standing leaf tensor, not a module parameter.
weight_mul = torch.randn(D,).cuda().requires_grad_()
model = Model().cuda()
# temp_weight is computed while the module's parameters are on the GPU.
model.temp_weight = model.weight * weight_mul
output = model(x)
# The module is moved to the CPU after the forward pass but before backward().
model.cpu()
output.sum().backward()