`backward(inputs=...)` produces no gradients after the model is moved between devices

import torch
import torch.nn as nn

class Model(nn.Module):
    """Toy module that scales its input by an externally supplied tensor.

    It owns two learnable vectors, ``weight`` and ``weight_mul``, each of
    length ``dim`` and initialised with ``torch.randn``. ``forward`` does not
    use them directly. It multiplies the input by ``self.temp_weight``, which
    is NOT created in ``__init__`` and must be assigned on the instance before
    the first forward call (e.g. ``model.temp_weight = model.weight * model.weight_mul``).

    NOTE(review): ``temp_weight`` is a plain tensor attribute, not a Parameter
    or registered buffer. ``Module.cpu()``/``.cuda()`` therefore presumably do
    not move it, and its autograd graph keeps referring to the parameters as
    they were when it was computed. Confirm this when debugging device moves.
    """

    def __init__(self, dim=None):
        """Create the two parameter vectors.

        Args:
            dim: Length of ``weight`` and ``weight_mul``. When omitted, the
                module-level global ``D`` is used. The global is looked up here,
                at construction time, so existing ``Model()`` calls keep working
                as long as ``D`` is defined before the model is built.
        """
        super().__init__()
        if dim is None:
            dim = D
        # Creation order (weight_mul first) is kept so seeded runs reproduce
        # the same initial values as before.
        self.weight_mul = nn.Parameter(torch.randn(dim))
        self.weight = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        """Return ``x * self.temp_weight`` (elementwise, with broadcasting).

        Raises:
            AttributeError: if ``temp_weight`` has not been assigned yet.
        """
        return x * self.temp_weight

# Repro 1: temp_weight is built from both module parameters while they live on
# the GPU. The module is then round-tripped CPU -> GPU before forward/backward.
torch.manual_seed(0)
D = 5

x = torch.randn(D,).cuda()
model = Model()

model.cuda()
# temp_weight is computed here, BEFORE the device round trip below. It is
# stored as a plain attribute on the module.
model.temp_weight = model.weight * model.weight_mul

# Moving the module away and back is what triggers the problem described in
# the post. Removing this line makes the gradients appear.
model.cpu(); model.cuda()
output = model(x)
# Backward restricted to the two parameters. According to the post, both
# .grad values printed below are None. Dropping `inputs=` populates them.
# NOTE(review): presumably the device moves give each parameter a fresh
# gradient-accumulator node, so the graph behind temp_weight no longer ends at
# the nodes that `inputs=` selects. Confirm against PyTorch internals.
output.sum().backward(inputs=[model.weight, model.weight_mul])

print("model.weight.grad", model.weight.grad)
print("model.weight_mul.grad", model.weight_mul.grad)

To my surprise, both `.grad` attributes are `None`. There are two ways to make `backward()` work: either drop the `inputs` argument to `backward()`, or remove `model.cpu(); model.cuda()`. But why?

# Repro 2: same as repro 1, except that weight_mul is a standalone CPU leaf
# tensor rather than a parameter of the module. model.cpu()/model.cuda() never
# touch it.
torch.manual_seed(0)
D = 5

x = torch.randn(D,).cuda()
weight_mul = torch.randn(D,).requires_grad_()
model = Model()

model.cuda()
# weight_mul reaches the graph through its .cuda() copy, so gradients flow back
# to the CPU leaf `weight_mul` itself.
model.temp_weight = model.weight * weight_mul.cuda()

model.cpu(); model.cuda()
output = model(x)

# According to the post, only weight_mul receives a gradient here, and
# model.weight.grad stays None. The unrestricted call below (commented out)
# gives both of them gradients.
output.sum().backward(inputs=[model.weight, weight_mul])
# output.sum().backward()

print("model.weight.grad", model.weight.grad)
print("weight_mul.grad", weight_mul.grad)

In this case, `weight_mul` gets a gradient but `model.weight` doesn't. If I remove `inputs`, both get gradients. The computation graphs rendered by torchviz are identical in both cases. So does `module.cpu()` somehow disconnect the computation graph, but only partially?

Possibly related: this works:

    def forward(self, x):
        """Scale ``x`` elementwise by the module's ``weight`` parameter."""
        return x * self.weight

# Variant A (reported to work): forward multiplies by self.weight directly.
# The module is moved to the CPU after forward has run but before backward.
x = torch.randn(D,).cuda()
model = Model().cuda()
output = model(x)
model.cpu()
output.sum().backward()

But this raises a device-mismatch error in `backward()`:

    def forward(self, x):
        """Scale ``x`` elementwise by the externally assigned ``self.temp_weight``."""
        return x * self.temp_weight

# Variant B (reported to fail): forward multiplies by temp_weight, which is
# computed on the GPU from model.weight and a GPU leaf tensor. Moving the module
# to the CPU before backward raises a device-mismatch error, per the post.
# The error goes away when model.cpu() is removed.
x = torch.randn(D,).cuda()
weight_mul = torch.randn(D,).cuda().requires_grad_()
model = Model().cuda()

model.temp_weight = model.weight * weight_mul
output = model(x)
model.cpu()

# NOTE(review): likely the backward of `model.weight * weight_mul` reads the
# saved leaf `model.weight`, whose data is now on the CPU, while the incoming
# gradient is on the GPU. Confirm.
output.sum().backward()

unless `model.cpu()` is removed.