I’d appreciate it if someone could shed some light on this. I’ve run into this problem while trying to train my model (policy gradient for asset allocation, hence the terms weights, rebalance, and pv for portfolio value below), and every time I try to fix it I stumble upon another error.
If it is of any help:

- `data_loader` is a custom class instance for retrieving batches.
- `pvm` is a buffer (a torch tensor) where the weight vectors `w` are saved. At each iteration I retrieve some of them and then replace them with the output of the model.
- `env.get_rewards(pv_prev, pv_next)` returns `torch.log(pv_next / pv_prev)`.
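
For context, a minimal sketch of what I mean by `pvm` and the reward (the class body and shapes here are simplified stand-ins, not my actual code; only the log-return formula is exact):

```python
import torch

class PVM:
    """Portfolio-vector memory: one weight vector per time step (simplified)."""
    def __init__(self, n_steps, n_assets):
        # initialized with uniform weights
        self.memory = torch.full((n_steps, n_assets), 1.0 / n_assets)

    def get_weights(self, t):
        return self.memory[t]

    def update_weights(self, w, t):
        self.memory[t] = w  # in-place write into the buffer

def get_rewards(pv_prev, pv_next):
    # reward = log-return of the portfolio value
    return torch.log(pv_next / pv_prev)
```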
Training loop:
```python
for t in range(batch_size + n_batches, train_data_length):
    for t_last in range(t - n_batches, t):
        opt.zero_grad()
        X_batch, y_batch = data_loader.get_batch(t_last)
        w_prev = pvm.get_weights(t_last)           # weights stored at t_last
        w_opt = model(X_batch, w_prev)
        pvm.update_weights(w_opt, t_last)          # overwrite the buffer with the new weights
        pv_prev, pv_next = env.rebalance(y_batch, w_prev, w_opt, t_last)
        rewards = env.get_rewards(pv_prev, pv_next)
        loss = -torch.mean(rewards)
        print(loss.item())
        loss.backward()
        opt.step()
```
As it is, at the second iteration it throws `RuntimeError: Trying to backward through the graph a second time...`.
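
My assumption about the mechanism (which this toy loop reproduces) is that a tensor produced inside one iteration's graph, here standing in for the weights kept in `pvm`, is reused in the next iteration without being detached first:

```python
import torch

param = torch.randn(3, requires_grad=True)
stored = param * 2.0                 # stands in for the weights kept in pvm

for step in range(2):
    loss = (stored * param).sum()
    loss.backward()                  # frees the graph that produced `stored`
    stored = stored * 0.9            # still chained to the freed graph
# second iteration raises: "Trying to backward through the graph a second time"
```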
There are some posts suggesting to:

- change `loss.backward()` to `loss.backward(retain_graph=True)` (here);
- use `.detach()` on the output of the network (if I’m not mistaken), that is `model(X_batch, w_prev).detach()` (here).
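
To be concrete, this is how I applied each suggestion (only the changed lines of the loop body):

```python
# variant 1: allow backpropagating through the same graph again
loss.backward(retain_graph=True)

# variant 2: detach the network output (my reading of the suggestion)
w_opt = model(X_batch, w_prev).detach()
```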
If I use the first approach, I get `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 21, 1, 1]] is at version 2...`. The only tensor with 21 in one of its dimensions is the first `concat` variable in the network’s forward pass:
```python
def forward(self, X_batch, w_batch):
    w_stocks = w_batch[:, 1:, :]                   # drop the cash weight
    out = F.relu(self.cnn1(X_batch))
    out = F.relu(self.cnn2(out))
    # inject the previous weights as an extra feature map
    concat = torch.cat([out, w_stocks.unsqueeze(dim=1)], dim=1)
    out = self.cnn3(concat).squeeze(dim=1)
    cash_bias = torch.zeros(self.batch_size, 1, 1)
    concat = torch.cat([cash_bias, out], dim=1)    # prepend the cash asset
    out = F.softmax(concat, dim=1)
    return out
```
If I use the second one, the backward pass stops working altogether: detaching the output cuts the loss off from the model’s parameters, so nothing gets trained.
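
For what it’s worth, my current guess is that the weights carried across iterations through `pvm` keep old graphs alive, and that only the stored copy should be detached while the tensor used for the loss stays in the graph. A sketch of what I mean (an untested assumption on my part, not a confirmed fix):

```python
w_opt = model(X_batch, w_prev)              # stays in the graph for the loss
pvm.update_weights(w_opt.detach(), t_last)  # only the buffered copy is cut loose

pv_prev, pv_next = env.rebalance(y_batch, w_prev, w_opt, t_last)
rewards = env.get_rewards(pv_prev, pv_next)
loss = -torch.mean(rewards)
loss.backward()                             # fresh graph each iteration
opt.step()
```

Is that the right way to handle the buffer, or am I still missing something?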