I have a training loop in which I would like to update the parameters only every `n_update` batches. I cannot simply increase the batch size.
My current code looks like this:
```python
y = torch.Tensor().to(device)
w = torch.Tensor().to(device)
loss_mean = 0

for batch, (X, yi, _) in enumerate(dataloader):
    # device
    X = X.float().to(device)
    yi = yi.float().to(device)

    # send batch through network
    wi = model(X)

    if (batch > 0) and (batch % n_update == 0) or (batch == len(dataloader) - 1):
        # compute loss and backpropagate
        loss = loss_fn(w, y)
        loss_mean += loss.item()

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        # assign new first values to w and y
        w = wi.clone()
        y = yi.clone()
    else:
        # if not update step, concat along batch dimension
        w = torch.cat([w, wi.clone()], dim=0)
        y = torch.cat([y, yi.clone()], dim=0)
```
In every loop iteration, the batch predictions `wi` and labels `yi` are collected in `w` and `y` by concatenating them along the batch dimension (`dim=0`). Thus, after `n_update` steps, the first dimension of `w` and `y` has size `n_update * batch_size`. Then, every `n_update` steps, the gradients are computed based on the entire accumulated batch of size `n_update * batch_size`. Furthermore, after running a parameter update, I reassign `w` and `y` to the current batch's values, so that this batch's information is not lost during training.
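Just to make the intended bookkeeping explicit, here is a toy sketch of the accumulation (with made-up values `batch_size=4`, `n_update=3`, and random tensors standing in for the model output and labels):

```python
import torch

batch_size, n_update = 4, 3   # toy values, just for illustration

w = torch.Tensor()            # accumulated predictions
y = torch.Tensor()            # accumulated labels

for batch in range(n_update):
    wi = torch.randn(batch_size)   # stand-in for model(X)
    yi = torch.randn(batch_size)   # stand-in for the batch labels
    w = torch.cat([w, wi], dim=0)
    y = torch.cat([y, yi], dim=0)

print(w.shape)  # torch.Size([12]) -> n_update * batch_size elements
```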
Executing this script yields the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor ] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Initially, I ran the code without the `.clone()` suffixes and without `retain_graph=True`; this did not work. I read in other posts that these fixes might solve the problem, but they did not help either. The problem seems to be that I perform an inplace operation somewhere; however, my `model()` is free of inplace operations (for instance, I write `x = x + a` instead of `x += a`). Therefore, I suspect the issue is related to the structure of the training loop (maybe the concatenation?). Strangely, the training loop does not even run for `n_update=1`, which should recover the original training loop, where updates are performed as usual after each batch.
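For reference, the plain per-batch loop that I would expect `n_update=1` to effectively reduce to looks roughly like this (same `model`, `loss_fn`, and `optimizer` as above):

```python
for batch, (X, yi, _) in enumerate(dataloader):
    X = X.float().to(device)
    yi = yi.float().to(device)

    wi = model(X)             # forward pass on the current batch only
    loss = loss_fn(wi, yi)    # loss on this batch, no accumulation

    optimizer.zero_grad()
    loss.backward()           # no retain_graph needed here
    optimizer.step()
```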
I appreciate any help on this! Thanks!