I have a situation where each mini-batch contains multiple nested data items, and the model needs to be trained on each of them:
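import torch

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization of the fed-back output
    for i in range(data.shape[0]):  # iterate over the first dimension of data
        current_data = data[i, ...]
        start_to_current_data = data[:i+1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)  # the previous output is fed back as an input
        loss = criterion(op, target)
        # loss.backward() and optimizer.step() presumably follow here (not shown in the original code)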
But when I start training, I get the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Setting retain_graph=True increases memory usage to the point where I cannot train the model. How can I fix this?
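A minimal, self-contained reproduction of this error might look like the sketch below. FeedbackModel, the shapes, the random targets, and the SGD optimizer are hypothetical stand-ins, since the post does not show the model, the criterion, or where backward() is called; the sketch assumes loss.backward() and optimizer.step() run once per inner iteration.

```python
import torch
import torch.nn as nn

class FeedbackModel(nn.Module):
    """Hypothetical model: combines the current input with its own previous output."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(2 * dim, dim)

    def forward(self, x, prev_op):
        return self.lin(torch.cat([x, prev_op], dim=-1))

torch.manual_seed(0)
dim, steps = 4, 3
model = FeedbackModel(dim)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

data = torch.randn(steps, 1, dim)     # plays the role of data after squeeze(0)
targets = torch.randn(steps, 1, dim)  # hypothetical targets

op = torch.zeros(1, dim)  # zero initialization for the first step
error = None
try:
    for i in range(data.shape[0]):
        op = model(data[i], op)  # op still carries the previous step's graph
        loss = criterion(op, targets[i])
        optimizer.zero_grad()
        loss.backward()          # frees the saved tensors of every graph it traverses
        optimizer.step()
except RuntimeError as e:
    error = str(e)               # raised in the second iteration

print(error)
```

In the second iteration, backward() follows op back into the first iteration's graph, whose saved tensors were already freed by the first backward() call.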
I’m not sure this is the cause of your issue, but it looks like your model output op isn’t used in the loss calculation at all?
It was a typo. I fixed it.
I think there might be an issue because the same data tensor is used multiple times, and I am wondering whether gradients are being computed for it. Can you see what happens when you set data.requires_grad = False when it is returned by batch.get("data")?
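One way to check this suggestion is to inspect requires_grad on the tensor the DataLoader returns. In the sketch below, SequenceDataset and its shapes are hypothetical; it only assumes the dataset returns a dict with a "data" entry, as in the question. Tensors produced by the default collate function do not require gradients unless the dataset created them that way, so in the usual case setting requires_grad = False would change nothing.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SequenceDataset(Dataset):
    """Hypothetical dataset returning a dict with a "data" entry, like the question's."""
    def __init__(self, n: int = 8, steps: int = 5, dim: int = 4):
        self.samples = torch.randn(n, steps, dim)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return {"data": self.samples[idx]}

loader = DataLoader(SequenceDataset(), batch_size=1)
batch = next(iter(loader))           # default collate keeps the dict structure
data = batch.get("data").squeeze(0)  # shape (5, 4), mirroring the question
print(tuple(data.shape), data.requires_grad)
```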
A miscellaneous note: you can try timing it to see whether there is an actual difference, but I don’t think the line op = torch.zeros(size) #zero_initializations will actually save time, given how allocations are done. If anything, it might just make autograd more confusing if the assignment happens with the actual tensor values rather than just the object references.
Setting op = torch.zeros(size) #zero_initializations has nothing to do with timing. It is required by the network architecture: the model's output is fed back as an input together with the other input, and for the very first step it is initialized with zeros.
@ptrblck Can you help me out in this case?
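This might be the culprit: since op is fed back into the model, each new output stays attached to the computation graph of the previous iterations, so calling backward() in a later iteration tries to backpropagate through graphs whose intermediate results were already freed.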
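The history carried by op can be inspected through its grad_fn attribute. In this sketch a single nn.Linear layer is a hypothetical stand-in for the model: after a forward pass op has a grad_fn linking it to the graph, and detach_() removes that link in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 4)     # hypothetical one-layer "model"
op = torch.zeros(1, 4)    # zero initialization, no graph attached yet

op = lin(op)              # step 1: op is now the output of a graph node
op = lin(op)              # step 2: this graph reaches back through step 1
has_history = op.grad_fn is not None

op.detach_()              # cut the history in place
print(has_history, op.grad_fn)
```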
Assuming you don’t want to backpropagate through the previous iterations, you could detach the tensor via op.detach_() in each iteration.
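Applied to a hypothetical setup (FeedbackModel, shapes, random targets, and SGD are stand-ins; some_transformation_func is not shown in the thread), the suggested fix could look like the sketch below. Here op.detach_() is called after backward() and optimizer.step(), so each backward() only covers the current step and never reaches an already-freed graph.

```python
import torch
import torch.nn as nn

class FeedbackModel(nn.Module):
    """Hypothetical model: combines the current input with its own previous output."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(2 * dim, dim)

    def forward(self, x, prev_op):
        return self.lin(torch.cat([x, prev_op], dim=-1))

torch.manual_seed(0)
dim, steps = 4, 5
model = FeedbackModel(dim)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

data = torch.randn(steps, 1, dim)     # hypothetical sequence, one entry per step
targets = torch.randn(steps, 1, dim)  # hypothetical targets

op = torch.zeros(1, dim)  # zero initialization for the first step
losses = []
for i in range(data.shape[0]):
    op = model(data[i], op)
    loss = criterion(op, targets[i])
    optimizer.zero_grad()
    loss.backward()       # only traverses the current step's graph
    optimizer.step()
    op.detach_()          # drop the history so the next step does not reach this graph
    losses.append(loss.item())

print(len(losses))
```

Because op no longer carries a graph into the next iteration, retain_graph=True is not needed and memory stays bounded per step; the trade-off is that gradients do not flow through earlier iterations.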