Trying to update parameters only every n_update batches raises an error

Hey,

I have a training loop in which I would like to update the parameters only every n_update batches. I cannot simply increase batch_size.

My current code looks like this:

y = torch.Tensor().to(device)
w = torch.Tensor().to(device)
loss_mean = 0
for batch,(X,yi,_) in enumerate(dataloader):
        
    # device
    X  = X.float().to(device)
    yi = yi.float().to(device)
    
    # send batch through network
    wi = model(X)

    if (batch>0) and (batch%n_update==0) or (batch==len(dataloader)-1):

        # compute loss and backpropagate
        loss = loss_fn(w, y)
        loss_mean += loss.item()
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        # assign new first values to w and y
        w = wi.clone()
        y = yi.clone()
    else:
        # if not update step, concat along batch dimension
        w = torch.cat([w,wi.clone()],dim=0)
        y = torch.cat([y,yi.clone()],dim=0)

In every iteration, the batch predictions wi and labels yi are collected in w and y by concatenating them along the batch dimension (dim=0). Thus, after n_update steps, the first dimension of w and y has size n_update*batch_size. Every n_update steps, the gradients are then computed on this entire accumulated batch of size n_update*batch_size. Furthermore, after running a parameter update, I set w and y to the current batch's values so that this batch's information is not lost during training.

Executing this script yields the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an
inplace operation: [torch.cuda.FloatTensor [5]] is at version 3; expected version 2 instead.
Hint: the backtrace further above shows the operation that failed to compute its gradient. The
variable in question was changed in there or anywhere later. Good luck!

Initially, I ran the code without the .clone() calls and without retain_graph=True; this did not work. I read in other posts that these changes might solve the problem, but they did not help either. The problem seems to be that I perform an in-place operation somewhere; however, my model() is free of in-place operations (for instance, I write x = x + a instead of x += a). Therefore, I suspect that the problem is related to the structure of the training loop (maybe the concatenation?). Strangely, the training loop does not even run for n_update=1, which should recover the original training loop, where updates are performed as usual after each batch.

I appreciate any help on this! Thanks!

Best, JZ

The issue is most likely caused by retain_graph=True, as it keeps the computation graph with the stale forward activations from previous iterations alive, although the parameters were already updated.
Could you explain why you are using this argument? It's usually not needed and causes these types of errors.
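The mechanism behind this error message can be reproduced in a few lines. The sketch below assumes a made-up two-layer model and SGD (not JZ's setup): it keeps the output of one forward pass, updates the parameters with optimizer.step(), and then tries to backpropagate through the kept output. `_version` is PyTorch's private in-place version counter, read here only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model; the second Linear saves its weight for backward.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Two forward passes with the same, not yet updated, parameters.
stale_out = model(torch.randn(8, 3))  # kept around, like `w` across loop iterations
fresh_out = model(torch.randn(8, 3))

version_before = model[2].weight._version  # private autograd version counter

optimizer.zero_grad()
fresh_out.pow(2).mean().backward()
optimizer.step()  # modifies the parameters in place

version_after = model[2].weight._version

# Backpropagating through the stale graph now fails: it saved model[2].weight
# at an older version, but the tensor has since been changed in place.
try:
    stale_out.pow(2).mean().backward()
    error = None
except RuntimeError as exc:
    error = str(exc)

print(version_before, version_after)
print(error)
```

In the original loop, w and y are seeded with wi and yi right after optimizer.step(), but wi was computed before that step, so the next update backpropagates through exactly such a stale graph. This happens with or without retain_graph=True, and also for n_update=1.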

Hey,

thanks for the reply!
I was using it because I read in another post that it might help; I was just trying it experimentally. I also ran the code without retain_graph, and that did not work either. As you say, I now also suspect that it is related to overwriting variables from previous loop iterations and then trying to compute the gradient through them after they have already been overwritten. For this to work, I think I will have to rearrange the order in which things are computed in the loop. I am working on a minimal code example right now.

Best, JZ
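Rearranging the loop can look roughly like the sketch below, which keeps the concatenation approach but only ever backpropagates through outputs computed with the current parameters. `train_concat`, the toy `nn.Linear` model, `nn.MSELoss` and the random batches are hypothetical stand-ins for the real model, loss and dataloader. Outputs are collected in Python lists and concatenated once per update, which avoids calling torch.cat on a growing tensor in every iteration.

```python
import torch
import torch.nn as nn


def train_concat(model, loss_fn, optimizer, batches, n_update):
    """Hypothetical helper: one optimizer step per n_update batches.

    Outputs are concatenated only at the update step, so every graph is
    consumed by exactly one backward() call and nothing computed before
    optimizer.step() is reused afterwards.
    """
    w_parts, y_parts = [], []
    n_steps = 0
    for batch, (X, yi) in enumerate(batches):
        # Forward pass with the parameters as they are *now*.
        w_parts.append(model(X))
        y_parts.append(yi)

        is_last = batch == len(batches) - 1
        if (batch + 1) % n_update == 0 or is_last:
            w = torch.cat(w_parts, dim=0)  # n_update * batch_size rows (fewer for a last partial group)
            y = torch.cat(y_parts, dim=0)
            loss = loss_fn(w, y)
            optimizer.zero_grad()
            loss.backward()  # no retain_graph=True needed
            optimizer.step()
            n_steps += 1
            # Start an empty window instead of seeding it with the current batch.
            w_parts, y_parts = [], []
    return n_steps


torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters())
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(7)]

steps_3 = train_concat(model, nn.MSELoss(), optimizer, batches, n_update=3)
steps_1 = train_concat(model, nn.MSELoss(), optimizer, batches, n_update=1)
print(steps_3, steps_1)  # 3 7
```

The key change is that the window is emptied after optimizer.step() instead of being refilled with an output from before the step; with n_update=1 this reduces to the usual one-step-per-batch loop.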

So, here is the example. I tried to keep it as short as possible while including the necessary steps. A dummy model is created; the loss function is the one I actually use in my full code.

[If you don’t care about my application, skip directly to code example below :D]
In my actual application, each batch consists of shifted windows of the same time series. The loss function therefore uses only the last sample of the batch. However, I need the entire batch in order to batch-normalize the most recent window using mean and std statistics from the past windows in the batch. If I used all of the batch's individual outputs to evaluate the loss, this would introduce forward leakage into the training process: the mean and std statistics are computed over all samples in the batch, so information from later samples would flow into the normalization of earlier samples. This is why I can only use the last output of each batch ([-1] in loss_fn, which corresponds to the most recent time window) to evaluate the loss. This is fine as long as batch_size is small, because then I get many estimates of the loss. It is also why I want to update only every n_update batches: otherwise, calling the optimizer on every batch is really time-consuming. I actually wanted to ask this in a separate post, but can you recommend better techniques for batch normalization of time series? I definitely need batchnorm, as otherwise covariate shift messes up my gradients.
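The leakage argument can be checked numerically. The sketch below assumes the windows are ordered oldest to newest along the batch dimension and uses `nn.BatchNorm1d` with `affine=False` in training mode; `normalize_last_window` is a hypothetical helper that recomputes BatchNorm's output for the last window by hand from the statistics of that window and the windows before it.

```python
import torch
import torch.nn as nn


def normalize_last_window(x, eps=1e-5):
    """Normalize only the most recent window x[-1] with per-channel statistics
    taken over the whole batch of windows (shape N x C x L), i.e. over the
    current window and the N - 1 windows that precede it in time."""
    mean = x.mean(dim=(0, 2), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(0, 2), keepdim=True)  # biased, as BatchNorm uses
    return (x[-1:] - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(36, 6, 1)  # 36 time-ordered windows, 6 channels, length 1
bn = nn.BatchNorm1d(6, affine=False)
bn.train()

# The last row of BatchNorm's training-mode output is exactly this quantity ...
same_as_bn = torch.allclose(bn(x)[-1:], normalize_last_window(x), atol=1e-5)

# ... while an earlier row changes when a later window changes (the leakage).
x_future_changed = x.clone()
x_future_changed[-1] += 10.0
leaks = not torch.allclose(bn(x)[0], bn(x_future_changed)[0])
print(same_as_bn, leaks)  # True True
```

The first check shows that the last output depends only on the current and earlier windows; the second shows that an earlier output depends on a later window, which is what restricting the loss to [-1] avoids.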

[Example Code]
So here is the example. I changed the approach: instead of first collecting the batch outputs with concatenation and then computing the loss on the n_update*batch_size sized tensors, I now compute the loss of each batch, add up the losses and, every n_update batches, backpropagate the sum, take an optimizer step and reset the gradients. This leads to much more compact code, which now works and should be correct, right?

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 36
n_update   = 5

# tiny model class
class Test(nn.Module):

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(6)
    
    def forward(self,x): 
        return self.bn(x)
    
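# weird loss function: only the last (most recent) sample of the batch enters the loss
def Loss(w,y):
    # w: batch_size x 6 x 1
    # y: batch_size x batch_size x 6 (y[-1] summed over dim 0 gives 1 x 6)
    # matmul (1 x 6) @ (6 x 1) returns 1 x 1; squeeze() turns it into a scalar
    return torch.exp(-torch.matmul(y[-1].sum(dim=0,keepdim=True),w[-1])).squeeze()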

# random values for X and y
_Xi = torch.rand((batch_size,6,1))
_yi = torch.rand((batch_size,batch_size,6))

# model and loss
model = Test().to(device)
loss_fn = Loss

optimizer = torch.optim.Adam(params=model.parameters())

model.train()

n_batches = 11
loss_mean = 0
loss_opt  = 0.0  # running sum of batch losses (becomes a tensor after the first addition)
for batch in range(n_batches):

    # device
    Xi = _Xi.float().to(device)
    yi = _yi.float().to(device)

    # send batch through network
    wi = model(Xi)

    # compute batch loss
    loss       = loss_fn(wi, yi)
    loss_opt   = loss_opt + loss # not inplace
    loss_mean += loss.item()

    # if update step: exactly n_update batches per update, plus the final partial group
    # (batch>0 and batch%n_update==0 would put n_update+1 batches into the first update)
    if (batch+1) % n_update == 0 or batch == n_batches-1:

        # backpropagate
        loss_opt.backward()
        optimizer.step()

        # gradient reset now after the optimizer update
        optimizer.zero_grad()

        # reset cumulative loss
        loss_opt = 0.0

Best, JZ
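A common variant of the summed-loss loop is to call backward() on every batch, let PyTorch accumulate the gradients in .grad, and call optimizer.step() and optimizer.zero_grad() every n_update batches. The sketch below (a toy nn.Linear model, nn.MSELoss and random batches, all assumptions) checks that both variants produce the same gradients; the per-batch version frees each batch's graph right after its backward() instead of keeping all n_update graphs alive until a single backward() call.

```python
import torch
import torch.nn as nn


def grads_from_summed_loss(model, loss_fn, batches):
    """Summed-loss approach: add the batch losses, call backward() once."""
    model.zero_grad()
    total = sum(loss_fn(model(X), y) for X, y in batches)
    total.backward()
    return [p.grad.clone() for p in model.parameters()]


def grads_from_per_batch_backward(model, loss_fn, batches):
    """Alternative: call backward() per batch; .grad accumulates across calls,
    and each batch's graph is freed right after its backward()."""
    model.zero_grad()
    for X, y in batches:
        loss_fn(model(X), y).backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(5)]

g_sum = grads_from_summed_loss(model, loss_fn, batches)
g_acc = grads_from_per_batch_backward(model, loss_fn, batches)
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_sum, g_acc)))  # True
```

Dividing each batch loss by n_update gives the mean instead of the sum; with Adam this mostly does not matter, since its update is largely invariant to a constant rescaling of the gradients.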