Training slows down gradually with each batch

My nn.Module had an attribute that seemed to live outside the training loop but actually kept the autograd graph connected across iterations, like this:

```python
import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = torch.tensor([1., 2.], requires_grad=True)
        self.bad_variable_used_across_loop = torch.tensor([-1.])

    def forward(self, x):
        # The new value is computed from the old one, so every call chains
        # onto the autograd graph built during the previous call.
        self.bad_variable_used_across_loop = x @ self.variable + self.bad_variable_used_across_loop
        some_result = x @ self.variable + self.bad_variable_used_across_loop
        return some_result
```

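Here I made bad_variable_used_across_loop an attribute of Foo only to record its value for later use. But because it is computed from self.variable, it carries a grad_fn, and each forward pass builds on the graph from the previous one, so the autograd graph keeps growing across batches!
To fix this, detach it at the end of each training iteration. Since .detach() returns a new tensor instead of modifying the existing one, the result must be assigned back: model.bad_variable_used_across_loop = model.bad_variable_used_across_loop.detach().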

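```python
import time
import torch

model = Foo()
for step in range(100000):
    start = time.time()
    x = torch.randn([10, 2])
    loss = model(x).sum()
    loss.backward()
    end = time.time()
    # detach() returns a new tensor, so rebind the attribute to cut the graph
    model.bad_variable_used_across_loop = model.bad_variable_used_across_loop.detach()
    print(f'step {step:05d}: {end-start:.2f}s')
```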
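One way to see the growth without timing anything is to count the autograd nodes reachable from the model output after each forward pass. The sketch below re-declares the Foo module from above so it runs on its own; count_graph_nodes and run are illustrative helpers, not PyTorch APIs, and only forward passes are performed. It compares three variants: never detaching, calling .detach() without using the result, and assigning the detached tensor back to the attribute.

```python
import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = torch.tensor([1., 2.], requires_grad=True)
        self.bad_variable_used_across_loop = torch.tensor([-1.])

    def forward(self, x):
        self.bad_variable_used_across_loop = x @ self.variable + self.bad_variable_used_across_loop
        return x @ self.variable + self.bad_variable_used_across_loop


def count_graph_nodes(t):
    """Count distinct autograd nodes reachable from t.grad_fn (illustrative helper)."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # next_functions holds (node, input_index) pairs; node is None for inputs without grad
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


def run(mode, steps=5):
    """mode: 'none' (never detach), 'discard' (bare .detach()), 'rebind' (assign it back)."""
    torch.manual_seed(0)
    model = Foo()
    counts = []
    for _ in range(steps):
        out = model(torch.randn(10, 2))
        counts.append(count_graph_nodes(out))
        if mode == "discard":
            model.bad_variable_used_across_loop.detach()  # result thrown away: no effect
        elif mode == "rebind":
            model.bad_variable_used_across_loop = model.bad_variable_used_across_loop.detach()
    return counts


counts_none = run("none")
counts_discard = run("discard")
counts_rebind = run("rebind")
print(counts_none)
print(counts_discard)
print(counts_rebind)
```

Only the last variant keeps the graph size constant; the first two grow with every step. When loss.backward() is called every step, as in the loop above, a backward pass over such a chained graph has to walk the whole history if retain_graph=True is used (or if none of the chained operations saved tensors for backward), which is consistent with the gradual slowdown; otherwise PyTorch typically raises "Trying to backward through the graph a second time" on the second step.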

Hi,

Small note:
If your “variable” is a learnt parameter, it should be an nn.Parameter.
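To illustrate the difference, the sketch below (the class names WithTensor and WithParameter are hypothetical) compares a plain tensor attribute with an nn.Parameter. Only the nn.Parameter is registered with the module, so only it appears in named_parameters() (and therefore in an optimizer built from model.parameters()) and in state_dict().

```python
import torch


class WithTensor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: NOT registered with the module
        self.variable = torch.tensor([1., 2.], requires_grad=True)


class WithParameter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter: registered automatically when assigned as an attribute
        self.variable = torch.nn.Parameter(torch.tensor([1., 2.]))


plain_names = [name for name, _ in WithTensor().named_parameters()]
param_names = [name for name, _ in WithParameter().named_parameters()]
print(plain_names)  # []
print(param_names)  # ['variable']
print(list(WithParameter().state_dict().keys()))  # ['variable']
```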

Hi, I'm seeing a similar issue during training, but in my case, if I stop the process, load the latest saved model and continue training, the speed goes back to normal. I know I can work around it by restarting the process, but I'm wondering why it happens and how to actually fix it. I'm using a 2080 Ti, and this happens both when training YOLO and with my own custom model.
Thank you in advance.

Hi,

Given what you describe, I would guess the issue is in how you initialize your Parameters when there is no checkpoint to load.

Hey, could you please explain the use of .detach() a bit more in the case of accumulating the loss? Sorry if my question is too basic, I’m still new to PyTorch.

```python
import torch
from tqdm import tqdm

epoch_loss = 0
n_train = len(train_loader)
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
    for batch in train_loader:
        net.train()
        imgs = batch['image']
        true_masks = batch['labels']
        imgs = imgs.to(device=device, dtype=torch.float32)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)
        logits, probs, masks_pred = net(imgs)
        logits = torch.squeeze(logits, 1)
        loss = criterion(logits, true_masks)
        epoch_loss += loss.item() / n_train
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # to be continued
```

If I understand correctly, in this case epoch_loss = 0 prevents the loss from being carried over from one iteration to the next, and this is what .detach() does?
Thank u.

Hi,

In this case, you don’t need to call .detach() explicitly: .item() returns a plain Python number, so epoch_loss never holds on to the autograd graph.
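A minimal sketch of the difference, using a toy parameter w and a toy loss in place of the net and criterion from the question: accumulating the loss tensor itself produces a tensor with a grad_fn that chains back through every iteration's graph, while accumulating loss.item() only ever stores a Python float.

```python
import torch

torch.manual_seed(0)
w = torch.nn.Parameter(torch.randn(3))   # toy parameter standing in for the network
optimizer = torch.optim.SGD([w], lr=0.1)

epoch_loss_tensor = 0      # accumulates the loss tensor itself
epoch_loss_float = 0.0     # accumulates loss.item(), as in the question
for _ in range(3):
    x = torch.randn(3)
    loss = (w * x).sum() ** 2        # toy loss standing in for criterion(...)
    epoch_loss_tensor += loss        # tensor with a grad_fn: graph stays attached
    epoch_loss_float += loss.item()  # plain Python float: no graph
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(type(epoch_loss_tensor).__name__, epoch_loss_tensor.grad_fn is not None)  # Tensor True
print(type(epoch_loss_float).__name__)  # float
```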

Hi, thanks for your explanation. It makes sense to me now.

Hey, I'm facing the same problem: within an epoch, each batch takes longer than the one before. I don't quite understand how you solved it. Could you explain?
