RuntimeError: Trying to backward through the graph a second time... and other errors

I’d appreciate it if someone could shed some light on this.

I’ve encountered this problem while trying to train my model (policy gradient for asset allocation, hence the terms below: weights, rebalance, and pv for portfolio value), but every time I try to solve the issue I stumble upon another error:

If it is of any help:

  • data_loader is a custom class instance for retrieving batches.
  • pvm is a buffer (a torch tensor) where the weight vectors w are saved. At each iteration, I retrieve some of them and then replace them with the output of the model (a simplified sketch is given right after this list).
  • env.get_rewards(pv_prev, pv_next) = torch.log(pv_next / pv_prev)
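
For reference, this is roughly what pvm does. It is only a simplified sketch; the real class is longer, and the names and shapes here are purely illustrative:

import torch

class PVM:
    """Portfolio-vector memory: one weight vector per time step."""
    def __init__(self, n_steps, n_assets, batch_size):
        # start from uniform weights, shape (n_steps, n_assets, 1)
        self.buffer = torch.ones(n_steps, n_assets, 1) / n_assets
        self.batch_size = batch_size

    def get_weights(self, t_last):
        # read back the weights stored for the batch ending at t_last
        return self.buffer[t_last - self.batch_size:t_last]

    def update_weights(self, w_opt, t_last):
        # overwrite the stored weights with the latest network output
        self.buffer[t_last - self.batch_size:t_last] = w_opt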

Training loop:

for t in range(batch_size + n_batches, train_data_length):
    for t_last in range(t - n_batches, t):

        opt.zero_grad()

        # batch ending at time step t_last
        X_batch, y_batch = data_loader.get_batch(t_last)

        # weight vectors stored in the buffer on the previous pass
        w_prev = pvm.get_weights(t_last)

        # new weights proposed by the network
        w_opt = model(X_batch, w_prev)

        # overwrite the buffer so the next iteration reads these as w_prev
        pvm.update_weights(w_opt, t_last)

        # portfolio value before and after rebalancing
        pv_prev, pv_next = env.rebalance(y_batch, w_prev, w_opt, t_last)

        # rewards = torch.log(pv_next / pv_prev)
        rewards = env.get_rewards(pv_prev, pv_next)

        # maximise the mean log-return
        loss = -torch.mean(rewards)
        print(loss.data.numpy())

        loss.backward()
        opt.step()

As it is, at the second iteration it throws RuntimeError: Trying to backward through the graph a second time...

There are some posts suggesting to:

  1. change loss.backward() to loss.backward(retain_graph=True) (here)
  2. use .detach() on the output of the network (if I’m not mistaken), i.e. model(X_batch, w_prev).detach() (here); both options are sketched right after this list.
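
To be concrete, these are the two modifications I tried, shown only on the relevant lines of the loop above (everything else unchanged):

# Suggestion 1: keep the graph alive so the next backward can reuse it
loss.backward(retain_graph=True)

# Suggestion 2: cut the graph at the network output before it is stored and used
w_opt = model(X_batch, w_prev).detach()
pvm.update_weights(w_opt, t_last)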

If I use the first approach, I get RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 21, 1, 1]] is at version 2.... The only tensor with 21 in one of its dimensions is the first concat variable in the network’s forward pass:

def forward(self, X_batch, w_batch):
    # drop the cash component; keep only the stock weights
    w_stocks = w_batch[:, 1:, :]

    out = F.relu(self.cnn1(X_batch))
    out = F.relu(self.cnn2(out))

    # append the previous weights as an extra feature map before the last conv
    concat = torch.cat([out, w_stocks.unsqueeze(dim=1)], dim=1)
    out = self.cnn3(concat).squeeze(dim=1)

    # prepend a zero cash bias and normalise with a softmax
    cash_bias = torch.zeros(self.batch_size, 1, 1)
    concat = torch.cat([cash_bias, out], dim=1)
    out = F.softmax(concat, dim=1)
    return out

If I use the second one, then the backward pass stops working.
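
My understanding of why, sketched on a single iteration (simplified, under the assumption that w_prev read back from the buffer carries no graph either): once the output is detached, nothing in the loss depends on the model parameters anymore, so there is nothing for backward to propagate.

w_opt = model(X_batch, w_prev).detach()            # w_opt.requires_grad is now False
pv_prev, pv_next = env.rebalance(y_batch, w_prev, w_opt, t_last)
rewards = env.get_rewards(pv_prev, pv_next)        # no path back to the model parameters
loss = -torch.mean(rewards)
loss.backward()                                    # "stops working": no gradients reach the model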

Or maybe I should use:

loss.backward(create_graph=True)