I am working on a big project which for which I need to call manual_backward
and optimizer.step
inside a loop for every batch.
Here is some reference code for a training_function that works, and another that doesn’t:
def loss_fn_working(self, batch: Any, batch_idx: int):
env = self.envs[self.p]
actions = None
prev_log_rewards = torch.empty(env.done.shape[0]).type_as(env.state)
prev_forward_logprob = None
loss = torch.tensor(0.0, requires_grad=True)
TERM = env.terminal_index
while not torch.all(env.done):
active = ~env.done
forward_logprob, back_logprob = self.forward(env)
log_rewards = -self.get_rewards()
if actions is not None:
error = log_rewards - prev_log_rewards[active]
error += back_logprob.gather(1, actions[actions != TERM, None]).squeeze(1)
error += prev_forward_logprob[active, -1]
error -= forward_logprob[:, -1].detach()
error -= (
prev_forward_logprob[active]
.gather(1, actions[actions != TERM, None])
.squeeze(1)
)
loss = loss + F.huber_loss(
error,
torch.zeros_like(error),
delta=1.0,
reduction="none",
)
loss = loss * log_rewards.softmax(0)
loss = loss.mean(0)
actions = self.sample_actions(forward_logprob, active, TERM)
env.step(actions)
# save previous log-probs and log-rewards
if prev_forward_logprob is None:
prev_forward_logprob = torch.empty_like(forward_logprob)
prev_forward_logprob[active] = forward_logprob
prev_log_rewards[active] = log_rewards
return loss, log_rewards
def calculate_loss(
self,
loss,
log_rewards,
prev_log_rewards,
back_log_prob,
actions,
stop_prob,
prefix, # Added for debugging
prev_stop_prob=None,
prev_forward_log_prob=None,
):
error = torch.tensor(0.0, requires_grad=True) + log_rewards - prev_log_rewards # [B]
error = error + (back_log_prob).gather(1, actions.unsqueeze(1)).squeeze(1) # P_B(s|s')
error = error - stop_prob.detach() # P(s_f|s')
if prev_stop_prob is not None and prev_forward_log_prob is not None:
error = error + prev_stop_prob.detach() # P(s_f|s)
error = error - (prev_forward_log_prob).gather(
1, actions.unsqueeze(1)
).squeeze(1)
loss = loss + F.huber_loss( # accumulate losses
error,
torch.zeros_like(error),
delta=1.0,
reduction="none",
)
loss = loss * log_rewards.softmax(0)
return loss.mean(0)
def loss_fn_not_working(self, batch, batch_size, prefix, batch_idx):
gfn_opt, rep_opt = self.optimizers()
# some code here
losses = []
rep_losses = []
prev_forward_log_prob = None
prev_stop_prob = torch.zeros(batch_size, device='cuda')
loss = torch.tensor(0.0, requires_grad=True, device='cuda')
active = torch.ones((batch_size,), dtype=bool, device='cuda')
graph = torch.diag_embed(torch.ones(batch_size, self.n_dim)).cuda()
while active.any():
graph_hat = graph[active].clone()
adj_mat = graph_hat.clone()
rep_loss, latent_var = self.rep_model(torch.cat((adj_mat, next_id.unsqueeze(-1)), axis = -1))
rep_loss_tensor = torch.tensor(0.0, requires_grad=True) + rep_loss
forward_log_prob, Fs_masked, back_log_prob, next_prob, stop_prob = (
self.gfn_model(latent_var)
)
with torch.no_grad():
actions = self.sample_actions(Fs_masked)
graph = self.update_graph(actions)
#######################
log_rewards = -self.energy_model(graph_hat, batch, False, self.current_epoch)
if counter==0:
loss = self.calculate_loss(loss, log_rewards, prev_log_rewards, back_log_prob, actions, stop_prob, prefix)
else:
loss = self.calculate_loss(loss, log_rewards, prev_log_rewards, back_log_prob, actions, stop_prob, prefix, prev_stop_prob[active], prev_forward_log_prob[active])
losses.append(loss.item())
rep_losses.append(rep_loss.item())
if prefix == 'train':
rep_opt.zero_grad()
self.manual_backward(rep_loss_tensor, retain_graph=True)
rep_opt.step()
gfn_opt.zero_grad()
self.manual_backward(loss)
self.clip_gradients(gfn_opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm") # NEEDED??
gfn_opt.step()
with torch.no_grad():
active[indices_to_deactivate] = False #active updated appropriately
indices = indices[~current_stop]
# active_indices = ~current_stop # Not being used?
next_id = F.one_hot(indices, num_classes=self.n_dim)
prev_log_rewards = log_rewards[~current_stop]
counter += 1
if prev_forward_log_prob is None:
prev_forward_log_prob = torch.empty_like(forward_log_prob)
prev_forward_log_prob[active] = forward_log_prob[~current_stop]
prev_stop_prob[active] = stop_prob[~current_stop]
return losses, graph, log_rewards, counter, rep_losses
Here, the main variable of importance is prev_forward_log_prob
in loss_fn_not_working
. The loss is being calculated using calculate_loss() function.
I have kept manual_optimization as True.
When using loss_fn_not_working, and keeping retain_graph as false for loss
, I get the following error:
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
If I do keep retain_graph as True for loss
(i.e. the loss for the second optimizer), I get the following error instead:
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 10]], which is output 0 of AsStridedBackward0, is at version 3; expected version 1 instead.
If I use loss_fn_working, there is no problem. So, I understand that the problem arises when using backward calls inside the loop. I am not really making any in-place operations, so how can I make the 2nd loss_fn work? I tried cloning the relevant variables but it doesn’t work until we detach them, which I can’t do.