Training gradually slows down with each batch

My nn.Module had a variable that appears to live outside the training loop but accumulates gradient history across iterations, like this:

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = torch.tensor([1., 2.], requires_grad=True)
        # only meant to record a value for later use, not to be trained
        self.bad_variable_used_across_loop = torch.tensor([-1.])

    def forward(self, x):
        # reassigning the attribute keeps extending the graph from previous batches
        self.bad_variable_used_across_loop = x @ self.variable + self.bad_variable_used_across_loop
        some_result = x @ self.variable + self.bad_variable_used_across_loop
        return some_result

Here I make bad_variable_used_across_loop an attribute of Foo only to record its value for later use. But this variable keeps the autograd graph flowing across batches, so every batch's graph stays alive and each iteration gets slower!
To solve this, detach it at the end of each training iteration, e.g. with model.bad_variable_used_across_loop.detach_() (note the in-place detach_; a plain .detach() without reassignment would leave the stored attribute attached to the graph).

import time

model = Foo()
for step in range(100000):
    start = time.time()
    x = torch.randn([10, 2])
    loss = model(x).sum()
    loss.backward()
    end = time.time()
    model.bad_variable_used_across_loop.detach_()  # detach it so this step's graph can be freed
    print(f'step {step:05d}: {end-start:.2f}s')

Hi,

Small note:
If your "variable" is a learnt parameter, it should be an nn.Parameter.
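
For reference, a minimal sketch of what that would look like for the example above (same attribute name, kept only for illustration):

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter registers the tensor so it appears in model.parameters(),
        # gets updated by the optimizer, and is saved in the state_dict
        self.variable = torch.nn.Parameter(torch.tensor([1., 2.]))

    def forward(self, x):
        return x @ self.variable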

Hi, a similar issue occurs for me while training, but in my case, after I stop the process, load the latest model, and continue training, the speed goes back to normal. I know the problem can be worked around by restarting the process, but I am wondering why it happens and how I can solve it properly. I am using a 2080 Ti, and this happens both while training YOLO and my customized model.
Thank you in advance.

Hi,

Given what you describe, I would guess the issue is in the way you initialize your Parameters when there is no checkpoint to load.

Hey, could you please explain the use of .detach() a bit more in the case of accumulating the loss? Sorry if my question is too basic, I'm still new to PyTorch.

epoch_loss = 0
n_train = len(train_loader)
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
    for batch in train_loader:
        net.train()
        imgs = batch['image']
        true_masks = batch['labels']
        imgs = imgs.to(device=device, dtype=torch.float32)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)
        logits, probs, masks_pred = net(imgs)
        logits = torch.squeeze(logits, 1)
        loss = criterion(logits, true_masks)
        epoch_loss += loss.item() / n_train
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # to be continued

If I understood it correctly, in this case epoch_loss = 0 wouldn't let the loss be carried over from one iteration to the next, and that is what .detach() refers to?
Thank you.

Hi,

In this case, you don't need to explicitly call .detach(), because you extract a Python number directly with .item(), and that already breaks the graph.
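
To make that concrete, a small sketch of the options (model and data are placeholders, not taken from the code above):

running_loss = 0.
for x in data:
    loss = model(x).sum()
    # running_loss += loss           # bad: keeps the graph of every iteration alive
    # running_loss += loss.detach()  # fine: a tensor with no graph attached
    running_loss += loss.item()      # fine: a plain Python float, no graph at all
    loss.backward()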


Hi, thanks for your explanation. It makes sense to me now.

Hey, I am facing the same problem, where the time per batch grows within an epoch. I don't quite understand how you solved this problem?


Setting pin_memory=True solved my problem.
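
That is just a flag on the DataLoader; a minimal sketch (the dataset, batch size, and worker count are placeholders):

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,  # page-locked host memory makes host-to-GPU copies faster
)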

Could you elaborate on that? I believe I am facing a similar issue but don't know how to solve it.

Just ran into this issue in some code where ~50k minibatches are processed for training each epoch.
Processing an epoch starts at ~13.6 it/s and ends up at ~6 it/s.
After some digging, it turns out the slowdown is caused by this operation, which is executed every minibatch:

stats = stats + new_list

new_list holds 32 values. Unfortunately, this is a very expensive operation, since it creates a brand-new list for stats, and the cost depends on the size of the list. stats is reset every epoch, but it still grows quickly while iterating over minibatches; toward the end of an epoch, stats holds about 1.6 million elements.
New elements should instead be added directly to the existing list stats, using an in-place operation such as:

stats += new_list        # in-place list __iadd__, equivalent to extend
stats.extend(new_list)   # extends the existing list in place
stats.append(1.)         # appends a single element in place

These operations won't create a new list.
Using the second method, stats.extend(new_list), keeps the processing time of each minibatch at the same level: ~13.6 it/s during the entire epoch.
Note: all values are detached and on the CPU.
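
For reference, a tiny sketch of the difference in cost (the sizes mirror the numbers above; the loop count is illustrative):

stats = []
new_list = [0.] * 32

for _ in range(50_000):
    # stats = stats + new_list   # builds a brand-new list of len(stats) + 32 each time,
    #                            # so the per-minibatch cost grows as stats grows
    stats.extend(new_list)       # in place: cost depends only on len(new_list)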