I am doing metric learning and a single iteration of the training loop looks as follows:

```python
batch = get_new_batch()
batchwise_descriptors = learnable_network_1(batch)
total_loss = 0
for i in range(batch.shape[0]):
    for j in range(i + 1, batch.shape[0]):
        pseudo_distances = learnable_network_2(
            batchwise_descriptors[i],
            batchwise_descriptors[j],
        )
        loss = loss_fn(pseudo_distances)
        total_loss += loss
total_loss.backward()
optim.step()
```

Or in words: I am performing batched operations (a CNN) in learnable_network_1 and then unbatched pairwise operations in learnable_network_2 and loss_fn. This unbatched step is the memory bottleneck because the pseudo-distance matrices are very big, so I end up having to run with a batch size of 2 and run multiple backward passes per single optim.step(). This is highly suboptimal because my gradient estimates benefit from the quadratic behavior of the pairwise loop: the bigger the batch, the better the gradient estimates per single run of the CNN.

One way to address that would be to call backward on each loss individually instead of only on total_loss at the end; however, that would mean backpropagating through learnable_network_1 quadratically many times.

An alternative, which I don’t know how to implement, is to backpropagate in two steps: in the loop, I would call some backpropagate_until_batchwise_descriptors(loss) which treats batchwise_descriptors as a leaf variable and just accumulates the gradient in it (without going into the CNN), and then, after the pairwise iteration is over, I’d call backpropagate_through_cnn(batchwise_descriptors.grad) to find the CNN gradients. This approach would mean I only need to store a single pseudo_distances matrix in memory at any given time, greatly reducing the memory strain of this operation.

Unless I’m very mistaken this two-step backpropagation is mathematically correct and I think it should be possible to implement, though I have no idea how difficult it is and at which point I’d need to integrate with autograd. Could I ask for some input on that?
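The two-step idea is indeed mathematically correct: it is just the chain rule applied with an explicit cut at the intermediate tensor. Here is a minimal sketch on toy functions (standing in for the two networks, not the real code) that verifies the split backward produces the same gradient as a single backward pass:

```python
import torch

x = torch.randn(5, requires_grad=True)

# Reference: one-shot backward through both stages
y = (x * 2).sin().sum()
y.backward()
ref_grad = x.grad.clone()
x.grad = None

# Two-step: cut the graph at the intermediate tensor, accumulate its
# gradient, then backprop that gradient through the first stage.
h = x * 2                          # stand-in for learnable_network_1 output
h_leaf = h.detach().requires_grad_()
loss = h_leaf.sin().sum()          # stand-in for learnable_network_2 + loss_fn
loss.backward()                    # fills h_leaf.grad only, never touches x
h.backward(h_leaf.grad)            # now backprop through stage 1

assert torch.allclose(x.grad, ref_grad)
```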

My first question would be: why can’t you run the second network in a batched manner? You could create all the pairwise combinations of the batch by doing something like batchwise_descriptors.unsqueeze(0).expand(full_size) for the first argument and batchwise_descriptors.unsqueeze(1).expand(full_size) for the second one, where full_size is (batch_size, batch_size, your_other_dims...).
That would definitely be the most efficient in terms of speed.
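A small sketch of that expand trick with toy shapes (the descriptor dimension and the squared-distance "network" are placeholders, not the author's actual code):

```python
import torch

batch_size, dim = 4, 8
batchwise_descriptors = torch.randn(batch_size, dim, requires_grad=True)

full_size = (batch_size, batch_size, dim)
a = batchwise_descriptors.unsqueeze(0).expand(full_size)  # a[i, j] = desc[j]
b = batchwise_descriptors.unsqueeze(1).expand(full_size)  # b[i, j] = desc[i]

# Stand-in for a batched pairwise network: all pairs computed in one shot
pseudo_distances = (a - b).pow(2).sum(-1)  # shape (batch_size, batch_size)
```

Note that expand creates views without copying memory; only the downstream pairwise computation materializes the full (batch_size, batch_size, ...) tensor.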

Now, if you have hard constraints that prevent you from doing that (or memory issues), you can do the following (I haven’t tested the code so there might be a typo, but this is the spirit):

```python
batch = get_new_batch()
optim.zero_grad()  # Make sure all .grads are 0
batchwise_descriptors = learnable_network_1(batch)
# Create a new leaf Tensor that will accumulate all the grads
# coming from net 2
net_2_in = batchwise_descriptors.detach().requires_grad_()
total_loss = 0
for i in range(batch.shape[0]):
    for j in range(i + 1, batch.shape[0]):
        pseudo_distances = learnable_network_2(
            net_2_in[i],
            net_2_in[j],
        )
        local_loss = loss_fn(pseudo_distances)
        # Accumulate gradients in learnable_network_2 and net_2_in
        local_loss.backward()
        # Just used for bookkeeping, no gradient should flow back
        total_loss += local_loss.detach()
# Now backprop the whole thing through net 1
batchwise_descriptors.backward(net_2_in.grad)
optim.step()
```

Why can’t I do it in a batched way? Because I hid a lot of complexity in this example: in reality I extract a random number of descriptors from each batch element, so the pseudo-distance matrices have different shapes from each other (which would require a huge amount of padding), and even if that worked, loss_fn uses external, non-batchable operations.

Ok sounds good!
Then the second example should do what you want.
Note that you can play with the tradeoff between memory and speed by accumulating a bunch of losses in the double for loop before calling backward. Doing so will use more memory but will speed things up a little bit.
For example, accumulate the loss over the inner j loop and call backward once at the end of it.
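A sketch of that chunked variant, again with a toy squared-distance loss standing in for the real learnable_network_2 + loss_fn: one backward per outer iteration instead of one per pair, at the cost of keeping one row of pairwise graphs alive at a time.

```python
import torch

batch_size, dim = 4, 8
net_2_in = torch.randn(batch_size, dim, requires_grad=True)

for i in range(batch_size):
    inner_loss = 0
    for j in range(i + 1, batch_size):
        # Stand-in for learnable_network_2 + loss_fn on the pair (i, j)
        inner_loss = inner_loss + (net_2_in[i] - net_2_in[j]).pow(2).sum()
    if torch.is_tensor(inner_loss):
        # One backward per i; gradients accumulate in net_2_in.grad
        inner_loss.backward()
```

Because gradients accumulate additively in .grad, the result is identical to calling backward on every pair individually (or on the full sum); only the memory/speed profile changes.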