I saw a similar topic, but I don't understand why my code contains an in-place operation.
code:

```python
l = self.backward()
```

with:

```python
def backward(self):
    transitions = self.replay_memory.sample(self.batch_size)
    batch = Transition(*zip(*transitions))
    a_batch = torch.cat(batch.action).to(self.device)         # [BS x 1]
    cs_batch = torch.stack(batch.state).to(self.device)       # [BS x state_size]
    ns_batch = torch.stack(batch.next_state).to(self.device)  # [BS x state_size]
    r_batch = torch.tensor(np.expand_dims(np.array(batch.reward), 1), dtype=torch.float).to(self.device)  # [BS x 1]
    # STEP 2: PREDICTIONS
    pred = self.calc_pred(cs_batch, a_batch)
    # STEP 3: TARGETS
    with torch.no_grad():
        target = self.calc_target(ns_batch, r_batch)
    # STEP 4: LOSS
    loss = self.calc_loss(pred, target)  # [BS x 1]
    loss = loss.mean()
    # STEP 5: TRAIN
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss
```
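For context, `Transition` is not shown above; I'm assuming the standard namedtuple pattern from the PyTorch DQN tutorial, where `Transition(*zip(*transitions))` transposes a list of transitions into one tuple per field:

```python
from collections import namedtuple

# Assumed definition (not shown in my code above); follows the
# standard PyTorch DQN tutorial pattern.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

transitions = [
    Transition(state=0.0, action=1, next_state=0.1, reward=1.0),
    Transition(state=0.2, action=0, next_state=0.3, reward=0.0),
]

# zip(*transitions) transposes the batch: one tuple per field.
batch = Transition(*zip(*transitions))
print(batch.action)  # (1, 0)
print(batch.reward)  # (1.0, 0.0)
```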
With:

```python
def calc_pred(self, cs_batch, a_batch):
    if self.noisy:
        self.network.sample_noise()
    pred = self.network(cs_batch).gather(1, a_batch)
    return pred
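(For clarity: the `.gather(1, a_batch)` call picks, from each row of Q-values, the value of the action that was actually taken. A standalone toy example with made-up numbers:)

```python
import torch

q_values = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])  # [BS x n_actions]
a_batch = torch.tensor([[2], [0]])          # [BS x 1], action indices

# Row 0 takes action 2 -> 3.0; row 1 takes action 0 -> 4.0.
pred = q_values.gather(1, a_batch)          # [BS x 1]
print(pred)  # tensor([[3.], [4.]])
```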
And also:

```python
def calc_target(self, ns_batch, r_batch):
    if self.double:
        if self.noisy:
            self.network.sample_noise()
            self.target_network.sample_noise()
        next_Q = self.network(ns_batch)
        max_a = next_Q.max(1)[1].unsqueeze(1)
        Q_target = self.target_network(ns_batch).gather(1, max_a)
    target = r_batch + (self.gamma ** self.multi_step_n) * Q_target
    return target
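(This is the usual double-DQN rule: the online network selects the argmax action, the target network evaluates it. A standalone sketch of the same computation with toy numbers, assuming `gamma = 0.99` and `multi_step_n = 1`:)

```python
import torch

gamma, multi_step_n = 0.99, 1
next_Q = torch.tensor([[1.0, 3.0], [2.0, 0.5]])    # online net Q(ns), [BS x n_actions]
target_Q = torch.tensor([[0.5, 2.0], [1.5, 1.0]])  # target net Q(ns)
r_batch = torch.tensor([[1.0], [0.0]])             # [BS x 1]

max_a = next_Q.max(1)[1].unsqueeze(1)              # argmax from online net: [[1], [0]]
Q_target = target_Q.gather(1, max_a)               # evaluated by target net: [[2.0], [1.5]]
target = r_batch + (gamma ** multi_step_n) * Q_target
print(target)  # tensor([[2.9800], [1.4850]])
```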
and the loss function:

```python
def calc_loss(self, pred, target):
    loss = F.mse_loss(pred, target, reduction='none')
    return loss
```
This works if I am not using a noisy network. However, when using noisy layers (`self.noisy = True`) it raises the following error:
```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5]] is at version 1003; expected version 1002 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```
However, I do not change any variables in my code when I use the noisy network. I have also used the noisy network in another piece of code where I do not get this error, so I do not think the problem is in that part of the code.
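For comparison, here is a minimal standalone snippet (unrelated to my DQN code) that triggers the same class of error, in case it helps pin down what autograd counts as an in-place modification. A tensor saved during the forward pass is mutated in place before `backward()` runs:

```python
import torch

w = torch.ones(5, requires_grad=True)
eps = torch.randn(5)      # stands in for a noise buffer; no grad required

y = (w * eps).sum()       # forward pass saves eps (needed to compute dL/dw)
eps.normal_()             # in-place "resample" bumps eps's version counter

try:
    y.backward()          # autograd detects eps changed since the forward pass
    failed = False
except RuntimeError as e:
    failed = True
    print(e)              # message mentions an in-place operation and version numbers
```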