Training gradually slows down with each batch

My nn.Module has an attribute that looks independent of the training loop, but it keeps accumulating an autograd graph across iterations, like this:

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = torch.tensor([1., 2.], requires_grad=True)
        self.bad_variable_used_across_loop = torch.tensor([-1.])
    def forward(self, x):
        # Each new value is built from the previous one, so the autograd
        # graph attached to this attribute grows by one step per call.
        self.bad_variable_used_across_loop = x @ self.variable + self.bad_variable_used_across_loop
        some_result = x @ self.variable + self.bad_variable_used_across_loop
        return some_result

Here I make bad_variable_used_across_loop an attribute of Foo only to record its value for later use. But because each new value is computed from the previous one, this tensor keeps the gradient graph alive across batches!
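You can watch the graph grow by walking the recorded tensor's grad_fn chain after a few forward passes. A minimal sketch (the graph_size helper below is my own, not part of the original code):

def graph_size(tensor):
    # Roughly count the autograd nodes reachable from tensor.grad_fn.
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

model = Foo()
for step in range(3):
    model(torch.randn([10, 2]))
    # The count increases every iteration because each forward chains
    # onto the graph recorded by the previous one.
    print(graph_size(model.bad_variable_used_across_loop))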
To solve this, detach it from the graph at the end of each training step, for example with model.bad_variable_used_across_loop.detach_(). Note the in-place detach_(): plain .detach() only returns a new detached tensor and leaves the attribute's graph untouched.

import time

model = Foo()
for step in range(100000):
    start = time.time()
    x = torch.randn([10, 2])
    loss = model(x).sum()
    loss.backward()
    end = time.time()
    # Detach in place so the recorded value drops its autograd history
    # and the next forward starts from a graph-free tensor.
    model.bad_variable_used_across_loop.detach_()
    print(f'step {step:05d}: {end - start:.2f}s')
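If you would rather not mutate the tensor in place, reassigning a detached copy works the same way. A sketch under the same setup:

model = Foo()
for step in range(5):
    x = torch.randn([10, 2])
    loss = model(x).sum()
    loss.backward()
    # Equivalent to detach_(): store a copy with no autograd history so
    # the next step starts from a fresh leaf tensor.
    model.bad_variable_used_across_loop = model.bad_variable_used_across_loop.detach()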