Removing the effect of a forward and backward pass

Hi, I have a problem that has been bothering me for a month; any suggestion would be appreciated.

I have a custom network with 2 forward passes. The gradients from the 1st forward pass are first multiplied by a predefined constant and then accumulated with the gradients from the 2nd forward pass.
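A minimal sketch of the accumulation described above, assuming a toy `nn.Linear` model, `nn.MSELoss` and random data in place of the real network (all placeholders, not the author's code). It checks that scaling the 1st pass's weight gradient by `const` and then calling `backward()` again leaves `const * g1 + g2` in `.grad`, and that `const == 0` with `grad = None` leaves exactly `g2`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for the real model, loss and data (assumptions for illustration).
model = nn.Linear(4, 2)
criterion = nn.MSELoss()
x1, y1 = torch.randn(8, 4), torch.randn(8, 2)
x2, y2 = torch.randn(8, 4), torch.randn(8, 2)


def grads_of(x, y):
    """Return a copy of the weight gradient from a single, fresh backward pass."""
    model.zero_grad(set_to_none=True)
    criterion(model(x), y).backward()
    return model.weight.grad.clone()


g1 = grads_of(x1, y1)
g2 = grads_of(x2, y2)


def combined_grad(const):
    """Scale the 1st pass's weight gradient by `const`, then accumulate the 2nd pass."""
    model.zero_grad(set_to_none=True)
    criterion(model(x1), y1).backward()   # forward/backward no.1
    if const == 0:
        model.weight.grad = None          # drop pass 1's weight gradient entirely
    else:
        model.weight.grad *= const        # in-place scaling
    criterion(model(x2), y2).backward()   # forward/backward no.2 accumulates into .grad
    return model.weight.grad.clone()


print(torch.allclose(combined_grad(0.5), 0.5 * g1 + g2))  # True
print(torch.allclose(combined_grad(0), g2))               # True
```

Only `weight` is handled here on purpose; `model.bias.grad` still carries the 1st pass, which is exactly what happens to any parameter left out of a hand-written list.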

A special case is when constant == 0: then I need “model.weight.grad = None” rather than “model.weight.grad *= constant”, due to this post.
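The linked post is not reproduced here, so this is only one plausible reason for preferring `None`: multiplying by zero does not clear non-finite values, because `0 * inf` and `0 * nan` are both `nan`, and those NaNs then survive the accumulation in the 2nd backward pass. A minimal sketch with a hypothetical parameter `p` and a made-up 1st-pass gradient:

```python
import torch
import torch.nn as nn


def second_pass_grad(clear_with_none):
    """Hypothetical demo: clear a non-finite pass-1 gradient, then run pass 2."""
    p = nn.Parameter(torch.ones(3))
    p.grad = torch.tensor([1.0, float("inf"), float("nan")])  # pretend pass-1 gradient
    if clear_with_none:
        p.grad = None          # forget pass 1 entirely
    else:
        p.grad *= 0            # 0 * inf and 0 * nan are both nan
    (2 * p).sum().backward()   # pass 2: gradient of sum(2p) is 2 for every element
    return p.grad


print(second_pass_grad(clear_with_none=True))   # tensor([2., 2., 2.])
print(second_pass_grad(clear_with_none=False))  # tensor([2., nan, nan])
```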

The problem is that when constant == 0, even after zeroing all the gradients generated in the 1st forward pass, I do not get the expected results, i.e. the results of running the network without the 1st forward pass.

Here is example code showing what I am doing; I have removed the unnecessary parts.

code part1 (what I need)

# forward no.1
output = model(input)
loss = criterion(output, target)
loss.backward()

# gradients multiplied with a predefined constant (const)
if const == 0:
        model.module.conv1.weight.grad = None
        model.module.conv2.weight.grad = None
        ...
else:
        model.module.conv1.weight.grad *= const
        model.module.conv2.weight.grad *= const
        ...

# forward no.2
output = model(input)
loss = criterion(output, target)
loss.backward()

# sgd step
optimizer.step()
optimizer.zero_grad()
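Because of the `...`, every parameter in code part1 has to be listed by hand, and any one that is missed (a bias or a batch norm weight, for example) silently keeps its 1st-pass gradient. A sketch of a hypothetical helper `scale_grads` that applies the same rule to every parameter, assuming all of them should be treated this way. Calling it on the `DataParallel` wrapper (`model` rather than `model.module`) reaches the same parameters.

```python
import torch
import torch.nn as nn


def scale_grads(model: nn.Module, const: float) -> None:
    """Hypothetical helper: scale every existing .grad by `const`, or drop it when const == 0."""
    for p in model.parameters():
        if p.grad is None:
            continue                # parameter received no gradient in pass 1
        if const == 0:
            p.grad = None           # same effect as model.zero_grad(set_to_none=True)
        else:
            p.grad.mul_(const)      # in-place, like `grad *= const`


# Toy model with conv and batch norm layers (stand-in for the real network).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(1, 2, 3), nn.BatchNorm2d(2))
model(torch.randn(4, 1, 5, 5)).pow(2).sum().backward()

scale_grads(model, 0)
print(all(p.grad is None for p in model.parameters()))  # True
```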

code part2 (my expected results when constant == 0)

## forward no.1
# output = model(input)
# loss = criterion(output, target)
# loss.backward()

## gradients multiplied with a predefined constant (const)
# if const == 0:
        # model.module.conv1.weight.grad = None
        # model.module.conv2.weight.grad = None
        # ...
# else:
        # model.module.conv1.weight.grad *= const
        # model.module.conv2.weight.grad *= const
        # ...

# forward no.2
output = model(input)
loss = criterion(output, target)
loss.backward()

# sgd step
optimizer.step()
optimizer.zero_grad()

When I run my code, the results are far from what I expect. The expected results for constant == 0 can be reproduced with code part2 (i.e. with all code related to the 1st forward pass commented out).

Does this mean that I cannot remove the effect of the 1st forward pass with “model.weight.grad = None”?
What is the right way to remove the effect of the 1st forward pass?

Put another way, I’d like to know how I can remove the effect of the 3 lines below.

output = model(input)
loss = criterion(output, target)
loss.backward()
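Clearing `.grad` only undoes the backward call. The forward call can still leave traces: in training mode batch norm layers update their running statistics (buffers), and dropout advances the random number generator, so the 2nd pass draws different masks. A sketch with hypothetical helpers `snapshot`/`restore` that save both before the 1st pass and put them back afterwards, assuming a CPU model and a toy network (on GPU, `torch.cuda.get_rng_state_all()` / `torch.cuda.set_rng_state_all()` would be needed as well):

```python
import torch
import torch.nn as nn


def snapshot(model: nn.Module):
    """Hypothetical helper: copy every buffer (e.g. BN running stats) and the CPU RNG state."""
    buffers = {name: b.detach().clone() for name, b in model.named_buffers()}
    return buffers, torch.get_rng_state()


def restore(model: nn.Module, state) -> None:
    """Hypothetical helper: write the saved buffers and RNG state back."""
    buffers, rng_state = state
    with torch.no_grad():
        for name, b in model.named_buffers():
            b.copy_(buffers[name])
    torch.set_rng_state(rng_state)


def run(first_pass: bool, undo: bool):
    """Return (output of forward no.2, BN running_mean) for a toy model (an assumption)."""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(0.5))
    x = torch.randn(8, 4)
    if first_pass:
        state = snapshot(model)
        model(x).sum().backward()           # forward/backward no.1
        model.zero_grad(set_to_none=True)   # const == 0: drop its gradients
        if undo:
            restore(model, state)           # also undo its side effects
    out = model(x)                          # forward no.2
    return out, model[1].running_mean.clone()


ref_out, ref_mean = run(first_pass=False, undo=False)
out, mean = run(first_pass=True, undo=True)
print(torch.equal(out, ref_out), torch.equal(mean, ref_mean))  # True True
```

With `undo=False` the gradients are still cleared, yet both the output of forward no.2 and `running_mean` differ from the reference run.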

Thanks for reading this post; any suggestions would be appreciated.

Are you using batch norm layers in the model?
Note that during training the running estimates are updated in every forward pass, so even if you properly zero out the gradients and “remove” the backward pass, the 1st forward pass has still modified these buffers, and any output computed with them (e.g. in eval mode) will differ from a run without that pass.
Could you try setting your model to .eval() to disable this effect for the sake of debugging?
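A quick way to see the effect described above, assuming a standalone `nn.BatchNorm1d` layer and random input: in training mode a forward pass alone (no backward, no optimizer step) changes `running_mean`, while after `.eval()` it does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(16, 3) + 5.0   # shifted so the batch mean is clearly non-zero

bn.train()
before = bn.running_mean.clone()
with torch.no_grad():          # no gradients at all, only a forward pass
    bn(x)
print(torch.equal(before, bn.running_mean))  # False: the buffer moved

bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
print(torch.equal(before, bn.running_mean))  # True: eval() freezes the running stats
```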