Hey,
I have a training loop where I would like to update the parameters only every `n_update` batches. I cannot simply increase `batch_size`.
My current code looks like this:
```python
y = torch.Tensor().to(device)
w = torch.Tensor().to(device)
loss_mean = 0
for batch, (X, yi, _) in enumerate(dataloader):
    # move to device
    X = X.float().to(device)
    yi = yi.float().to(device)
    # send batch through network
    wi = model(X)
    if (batch > 0) and (batch % n_update == 0) or (batch == len(dataloader) - 1):
        # compute loss and backpropagate
        loss = loss_fn(w, y)
        loss_mean += loss.item()
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        # assign new first values to w and y
        w = wi.clone()
        y = yi.clone()
    else:
        # if not an update step, concatenate along the batch dimension
        w = torch.cat([w, wi.clone()], dim=0)
        y = torch.cat([y, yi.clone()], dim=0)
```
In every iteration, the batch predictions `wi` and labels `yi` are collected in `w` and `y` by concatenating along the batch dimension (`dim=0`). Thus, after `n_update` steps, the first dimension of `w` and `y` has size `n_update * batch_size`. Then, every `n_update` steps, the gradients are computed based on the entire accumulated batch of size `n_update * batch_size`. Furthermore, after running a parameter update, I reset `w` and `y` to the current batch's values, so that this batch's information is not lost during training.
Executing this script yields the following error:
```
RuntimeError: one of the variables needed for gradient computation has been modified by an
inplace operation: [torch.cuda.FloatTensor [5]] is at version 3; expected version 2 instead.
Hint: the backtrace further above shows the operation that failed to compute its gradient. The
variable in question was changed in there or anywhere later. Good luck!
```
Initially, I ran the code without the `.clone()` calls and without `retain_graph=True`; this did not work. I read in other posts that these fixes might solve the problem, but they did not work either. The problem seems to be that I perform an in-place operation somewhere; however, my `model()` is free of in-place operations (for instance, I write `x = x + a` instead of `x += a`). Therefore, I suspect the issue is related to the structure of the training loop (maybe the concatenation?). Strangely, the training loop does not even run for `n_update = 1`, which should recover the original training loop, where updates are performed as usual after each batch.
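For comparison, here is the more conventional gradient-accumulation pattern I am trying to replicate: call `backward()` on each mini-batch loss (so the graph can be freed immediately) and only call `optimizer.step()` every `n_update` batches. This is a minimal self-contained sketch; the toy data, model, and loss are placeholders, not my actual setup:

```python
import torch
from torch import nn

torch.manual_seed(0)

# placeholder data/model standing in for my real dataloader, model, and loss_fn
X_all = torch.randn(12, 4)
y_all = torch.randn(12, 1)
dataloader = list(zip(X_all.split(3), y_all.split(3)))  # 4 mini-batches of size 3

model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
n_update = 2

w_before = model.weight.detach().clone()

optimizer.zero_grad()
for batch, (X, yi) in enumerate(dataloader):
    wi = model(X)
    # scale so the accumulated gradient matches the mean over n_update batches
    loss = loss_fn(wi, yi) / n_update
    loss.backward()  # gradients accumulate in .grad; no retain_graph needed
    if (batch + 1) % n_update == 0 or batch == len(dataloader) - 1:
        optimizer.step()
        optimizer.zero_grad()
```

Note that here each mini-batch's graph is released right after its `backward()` call, so no parameter modified by `optimizer.step()` is ever needed by a stale graph.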
I appreciate any help on this! Thanks!
Best, JZ