Network in for loop memory usage during training

Given an nn block “B”, an input image “x” is forwarded like so: x = B(x), for T iterations.

This is in contrast to the more common setting x = B_T(…B_2(B_1(x))…), where all B_i’s share the architecture, but not the weights.

I am using the former, with an additional constraint: I need access to all mid-outputs B(x), B(B(x)), etc.

I implemented this via a for loop, and noticed it is slower and consumes more memory during training than the “common case”. How can I alleviate this problem?
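A minimal sketch of this pattern, with a placeholder nn.Linear standing in for the real block B (any nn.Module would do), collecting every mid-output along the way:

```python
import torch
import torch.nn as nn

# Placeholder for the block "B"; substitute the real architecture here.
block = nn.Linear(4, 4)

x = torch.randn(2, 4)
T = 5

mid_outputs = []  # keep every intermediate B(x), B(B(x)), ...
for _ in range(T):
    x = block(x)
    mid_outputs.append(x)
```

Because every appended tensor stays referenced (and its autograd graph with it), all T activations are kept alive until backward, which is where the extra memory goes.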


Thanks, Jonathan

Could you post (pseudo) code that shows the performance difference?
If I understand your use case correctly, you are comparing these approaches:

import torch
import torch.nn as nn

# approach 1: a single module applied repeatedly (shared weights)
model = nn.Linear(1, 1)
x = torch.randn(1, 1)

for _ in range(100):
    x = model(x)

# approach 2: 100 separate modules with the same architecture (unshared weights)
models = nn.ModuleList([nn.Linear(1, 1) for _ in range(100)])
x = torch.randn(1, 1)
for m in models:
    x = m(x)

Yes, this is the case, only with AdaIN-like res-blocks instead of fully-connected layers. Interestingly, approach 1 consumes more memory during training.

How can I implement approach 1 more efficiently? Note again that I need access to all mid-outputs.

Thank you very much, Jonathan
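One standard way to cut the activation memory of such a loop (a suggestion, not something from this thread) is gradient checkpointing via torch.utils.checkpoint: the block's internal activations are recomputed during backward instead of stored, while the mid-outputs themselves remain available. A minimal sketch, again assuming a placeholder nn.Linear in place of the real res-block (the use_reentrant keyword requires a reasonably recent PyTorch):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder block; substitute the real AdaIN-like res-block here.
block = nn.Linear(4, 4)

# requires_grad on the input keeps checkpointing well-behaved.
x = torch.randn(2, 4, requires_grad=True)
T = 5

mid_outputs = []
for _ in range(T):
    # Don't store block's internal activations; recompute them in backward.
    x = checkpoint(block, x, use_reentrant=False)
    mid_outputs.append(x)

loss = mid_outputs[-1].sum()
loss.backward()
```

This trades extra forward compute for memory, so the loop gets somewhat slower but should no longer blow up memory relative to the unrolled "common case".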