Large RAM usage at the start of training

Hey guys, I have a training loop, which looks as follows:

    for epoch in range(arg.epochs):
        for step, (seq, target, reference_val) in enumerate(TrainLoader):
            model.train()
            seq, target, reference_val = seq.to(device), target.to(device), reference_val.to(device)
            model.zero_grad()
            preds = model(seq)
            loss = criterion(preds, target)
            loss.backward()
            if step % 100 == 0:
                if arg.BCELoss:
                    print('reference: ', reference_val[:5].unsqueeze(-1), '\n target: ', target[:5], '\n pred: ',
                          torch.sigmoid(preds[:5]))
                else:
                    print('reference: ', reference_val[:5].unsqueeze(-1), '\n target: ', target[:5], '\n pred: ',
                          preds[:5])

            optimizer.step()
            # decay the LR only until it reaches arg.lr_end
            if optimizer.param_groups[0]["lr"] > arg.lr_end:
                scheduler.step()
          

At the time of the first print statement, about 4 GB of RAM is consumed. During the very first optimizer step, however, almost 10 GB is needed. Later in training, usage stays constant at around 4 GB. Does anyone know what causes this behavior and how to prevent it?
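
To narrow down where the jump happens, here is a minimal sketch of how the step could be instrumented (assuming the numbers above refer to host RAM; psutil and the helper rss_gb are just illustrative):

    import os
    import psutil

    _process = psutil.Process(os.getpid())

    def rss_gb():
        # resident set size of the current process, in GB
        return _process.memory_info().rss / 1024 ** 3

    # hypothetical placement inside the inner loop above:
    print(f'after forward:  {rss_gb():.2f} GB')
    loss.backward()
    print(f'after backward: {rss_gb():.2f} GB')
    optimizer.step()
    print(f'after step:     {rss_gb():.2f} GB')

Logging the value right before and after optimizer.step() should show whether the jump happens inside the step itself.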

I don't see anything out of the ordinary in your training pass. Maybe it's something related to precision during the optimizer run?

UPD: right, in this case memory consumption will increase on every run.
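
If the jump really does coincide with the first optimizer.step(), one thing worth checking (just a guess, since the optimizer itself isn't shown above) is how much state the optimizer allocates lazily on that first call; Adam-style optimizers, for example, create extra per-parameter buffers the first time step() runs. A rough sketch for measuring that state:

    import torch

    def optimizer_state_gb(optimizer):
        # sum the sizes of all tensors held in the optimizer's per-parameter state
        total_bytes = 0
        for param_state in optimizer.state.values():
            for value in param_state.values():
                if torch.is_tensor(value):
                    total_bytes += value.numel() * value.element_size()
        return total_bytes / 1024 ** 3

    print('state before first step:', optimizer_state_gb(optimizer), 'GB')
    optimizer.step()
    print('state after first step: ', optimizer_state_gb(optimizer), 'GB')

If the printed sizes don't account for the difference, the spike probably comes from temporary allocations inside the step rather than from persistent optimizer state.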