Hi, thanks for your reply.
I’m using autograd.Function with nn.module now, but I get this error when calling loss.backward() for mySys.
I’ve seen your reply below and found my custom backward function’s output sim_out is detached from the computation graph(.grad_fn=None).
Here is a snippet of my code:
class B_nd(torch.autograd.Function):
def forward(ctx, input):
ctx.save_for_backward(input)
# some non-differentiable functions
return out
def backward(ctx, grad_output):
input, = ctx.saved_tensors
input.requires_grad_(True)
sim_model = UNet()
sim_model.load_state_dict(torch.load('/Desktop/CodeFolder/2023/pth/800.pth', map_location='cuda:0'))
sim_model.to(torch.device("cuda"))
sim_model.zero_grad()
sim_out = sim_model(input)
print(sim_out.grad_fn)
sim_out.backward(gradient=grad_output)
gradient = input.grad
return gradient
class mySys(nn.Module):
def __init__(self):
super(mySys, self).__init__()
self.A = nn.Conv2d(3, 3, kernel_size=1)
self.C = nn.Conv2d(3, 3, kernel_size=1)
def forward(self, x):
a_out = self.A(x)
b_out = B_nd.apply(a_out)
c_out = self.C(b_out)
return c_out
Any method or suggestion to attach the UNet in the backward part of B to the computation graph?