Hi everyone,
I have two models, say model_a and model_b.
model_b takes as input the output of model_a as well as some additional input.
For each input to model_a there are many additional inputs to model_b.
I want to take a gradient step on model_b’s weights for each of the additional inputs. Meanwhile, I want to accumulate gradients on model_a’s weights and take a gradient step for model_a only after all additional inputs for model_b have been forwarded.
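Accumulating model_a's gradient across the inner loop amounts to summing the per-input gradients. Using notation introduced here only for illustration ($e$ for the output of model_a, $\theta_a$ for model_a's parameters, $L_j$ for the loss of the $j$-th additional input, $N$ for additional_input_per_example), the chain rule gives:

$$
\nabla_{\theta_a} \sum_{j=1}^{N} L_j
= \sum_{j=1}^{N} \left(\frac{\partial e}{\partial \theta_a}\right)^{\!\top} \frac{\partial L_j}{\partial e}
= \left(\frac{\partial e}{\partial \theta_a}\right)^{\!\top} \sum_{j=1}^{N} \frac{\partial L_j}{\partial e}
$$

Each $\partial L_j / \partial e$ is evaluated with model_b's weights as they are at inner step $j$. Under these assumptions, model_a's accumulated gradient needs only the summed gradient with respect to $e$, followed by a single backward pass through model_a.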
I want to achieve something like this:
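```python
import torch

# model_a, model_b, criterion, optimizer_a and optimizer_b are assumed to be defined elsewhere.
examples = 10
additional_input_per_example = 100
batch_size = 32
hidden_dim = 3

for i in range(examples):
    x1 = torch.randn(batch_size, hidden_dim)
    enc = model_a(x1)  # Assume this takes 1 hour
    for j in range(additional_input_per_example):
        x2 = torch.randn(batch_size, hidden_dim)
        y = torch.randn(batch_size, hidden_dim)
        y_hat = model_b(enc, x2)  # Assume this takes 5 seconds
        loss = criterion(y_hat, y)
        # Without retain_graph=True this raises a RuntimeError on the second inner
        # iteration, because the first call already freed model_a's part of the graph.
        loss.backward()
        optimizer_b.step()   # one step for model_b per additional input
        optimizer_b.zero_grad()
    optimizer_a.step()       # one step for model_a after all additional inputs
    optimizer_a.zero_grad()
```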
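A minimal reproduction of why the inner loop needs retain_graph=True as written: two losses share one `enc`, and the second `backward()` has to pass through model_a's graph again after the first call has freed it. The `nn.Linear` stand-in for model_a and the two toy losses are assumptions chosen only to keep the sketch small.

```python
import torch
import torch.nn as nn

model_a = nn.Linear(3, 3)        # hypothetical stand-in for model_a
x1 = torch.randn(4, 3)
enc = model_a(x1)                # computed once and shared by both losses

loss_1 = enc.pow(2).sum()
loss_1.backward()                # frees the tensors saved in model_a's graph

loss_2 = (enc * 2).sum()         # new graph on top of the same enc
try:
    loss_2.backward()            # must traverse model_a's freed graph again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True
```

Passing `retain_graph=True` to the first call keeps those saved tensors alive, which is why the loop works with that option.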
If I move the computation of “enc” into the inner loop, it works, but it is highly inefficient because “enc” is recomputed for every additional input.
Is using retain_graph=True absolutely essential?
The code works with this option, but in my actual use case it increases memory consumption several times over. That seems unnecessary, since I only need to accumulate gradients on model_a’s weights and I am not using the graph for anything else.
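One possible way to avoid retain_graph=True, sketched under stated assumptions: cut the graph at `enc` with `detach().requires_grad_()`, let each inner `loss.backward()` traverse only model_b's small graph while d(loss)/d(enc) accumulates in the detached tensor's `.grad`, and then call `enc.backward(...)` once with that summed gradient. The `ModelB` class, the `nn.Linear` model_a, `nn.MSELoss`, the SGD optimizers and the reduced loop sizes are all hypothetical placeholders, not the poster's actual setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

examples = 2                      # reduced from the original values to keep the sketch fast
additional_input_per_example = 5
batch_size = 32
hidden_dim = 3


class ModelB(nn.Module):
    """Hypothetical model_b: concatenates enc and x2, then applies a linear layer."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(2 * dim, dim)

    def forward(self, enc, x2):
        return self.net(torch.cat([enc, x2], dim=1))


model_a = nn.Linear(hidden_dim, hidden_dim)  # hypothetical stand-in for model_a
model_b = ModelB(hidden_dim)
criterion = nn.MSELoss()
optimizer_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
optimizer_b = torch.optim.SGD(model_b.parameters(), lr=0.1)

for i in range(examples):
    x1 = torch.randn(batch_size, hidden_dim)
    enc = model_a(x1)                          # expensive forward, done once per example
    enc_leaf = enc.detach().requires_grad_()   # new leaf: the graph is cut here

    for j in range(additional_input_per_example):
        x2 = torch.randn(batch_size, hidden_dim)
        y = torch.randn(batch_size, hidden_dim)
        loss = criterion(model_b(enc_leaf, x2), y)
        loss.backward()          # traverses only model_b's graph;
                                 # d(loss)/d(enc) is summed into enc_leaf.grad
        optimizer_b.step()
        optimizer_b.zero_grad()  # clears model_b's grads only, not enc_leaf.grad

    # A single backward pass through model_a with the accumulated gradient w.r.t. enc
    enc.backward(enc_leaf.grad)
    optimizer_a.step()
    optimizer_a.zero_grad()
```

By the chain rule, this should give model_a the same accumulated gradient as calling `loss.backward(retain_graph=True)` in every inner iteration, while model_a's graph is traversed and freed only once per example. model_a's graph still stays in memory until `enc.backward(...)`, just as it does while `enc` is alive in the original loop.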
Thanks in advance for your help