I am trying to train a model that consumes "mini-batches of batches" (something like the rows and columns of a table): each mini-batch is a row, and the batch within that mini-batch is the set of column values for that row. The error should be back-propagated once per mini-batch (row), so the losses for all columns within a row need to be summed and then back-propagated together.
With this design, I observe that PyTorch requires `retain_graph=True` from one mini-batch to the next (row to row), which eventually leads to a GPU out-of-memory error. How can I track down which tensors (Variables) are forcing `retain_graph=True`? How can I circumvent this issue?
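In case it helps clarify the question, here is a minimal sketch of the training loop I have in mind. The model, shapes, and the `carried` state are hypothetical stand-ins; the pattern is: accumulate the column losses for one row, call `backward()` once per row, and `detach()` any state carried into the next row so its graph can be freed (otherwise PyTorch demands `retain_graph=True`):

```python
import torch
import torch.nn as nn

# Hypothetical model and data: 3 mini-batches (rows), 5 items (columns) each.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

data = torch.randn(3, 5, 4)
targets = torch.randn(3, 5, 1)

carried = torch.zeros(1)  # example of state carried from row to row

for row in range(data.size(0)):
    optimizer.zero_grad()
    total_loss = torch.zeros(())
    for col in range(data.size(1)):
        out = model(data[row, col]) + carried
        total_loss = total_loss + loss_fn(out, targets[row, col])
    total_loss.backward()  # one backward pass per row; no retain_graph needed
    optimizer.step()
    carried = out.detach()  # keep the value, drop its graph before the next row
```

Without the final `detach()`, the second row's `backward()` would try to traverse the first row's (already freed) graph through `carried`, which is exactly the situation that raises the "specify retain_graph=True" error.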