Hi,
given an nn block “B”, an input image “x” is forwarded like so: x = B(x) for T iterations.
This is in contrast to the more common setting x = B_T(…(B_2(B_1(x)))…), where all B_i’s share the architecture, but not the weights.
I am using the former, with an additional constraint: I need access to all mid-outputs B(x), B(B(x)), etc.
I implemented this via a for loop, and noticed that it is slower and consumes more memory during training than the “common case”. How can I alleviate this problem?
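Schematically, my loop looks like this (a simplified sketch; the real block is more complex than the placeholder here):

```python
import torch
import torch.nn as nn

B = nn.Linear(1, 1)  # placeholder for the real block
x = torch.randn(1, 1)

mid_outputs = []  # need all of B(x), B(B(x)), ...
for _ in range(10):  # T iterations
    x = B(x)
    mid_outputs.append(x)
```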
Could you post (pseudo) code, which would show the difference in the performance?
If I understand your use case correctly, you are comparing these approaches:
# approach 1: one block applied repeatedly (shared weights)
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
x = torch.randn(1, 1)
for _ in range(100):
    x = model(x)

# approach 2: distinct blocks with the same architecture but separate weights
models = nn.ModuleList([nn.Linear(1, 1) for _ in range(100)])
x = torch.randn(1, 1)
for m in models:
    x = m(x)