My nn.Module
had an attribute that lives outside of the training loop but accumulates gradient history across iterations, like this:
import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = torch.tensor([1., 2.], requires_grad=True)
        self.bad_variable_used_across_loop = torch.tensor([-1.])

    def forward(self, x):
        # The new value is built from the old one, so the previous batch's
        # autograd graph gets chained onto this batch's graph.
        self.bad_variable_used_across_loop = x @ self.variable + self.bad_variable_used_across_loop
        some_result = x @ self.variable + self.bad_variable_used_across_loop
        return some_result
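A quick way to see the problem (a minimal check, assuming the Foo class above): after a single forward pass the stored attribute carries a grad_fn, i.e. it is still attached to that batch's autograd graph, so the next batch's graph will be chained onto it.

model = Foo()
x = torch.randn([10, 2])
model(x)
# Prints an autograd node (e.g. <AddBackward0 ...>) rather than None,
# showing the attribute is still hooked into this batch's graph.
print(model.bad_variable_used_across_loop.grad_fn)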
Here I make bad_variable_used_across_loop
an attribute of Foo
only to record its value for later use. But this tensor keeps gradient flowing across batches!
To solve this, detach it in place with model.bad_variable_used_across_loop.detach_()
at the end of each training iteration. Note the trailing underscore: .detach() alone returns a new tensor and leaves the stored attribute attached to the graph.
import time

model = Foo()
for step in range(100000):
    start = time.time()
    x = torch.randn([10, 2])
    loss = model(x).sum()
    loss.backward()
    end = time.time()
    model.bad_variable_used_across_loop.detach_()  # detach in place so the next batch starts from a graph-free tensor
    print(f'step {step:05d}: {end-start:.2f}s')
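If you prefer not to mutate the tensor in place, an equivalent alternative (a sketch under the same setup) is to replace the stored attribute with a detached copy:

model.bad_variable_used_across_loop = model.bad_variable_used_across_loop.detach()

Either way the value is unchanged; only its link to the current batch's autograd graph is cut, so each backward() traverses one batch's graph only.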