Capture layer to track intermediate results

Hello pytorchers!

I am implementing an encoder-decoder style network and my decoder needs to access the output from an intermediate encoder layer. The problem is that I have wrapped this layer in a Sequential container.

Now, I could decompose the Sequential container and retrieve my output in the forward method, but I wanted to avoid this for code consistency.

After some thought, I came up with a simple ‘capture’ layer that tracks the last input it was fed:

class Capture(nn.Module):
    ''' Identity module used to capture input '''

    def __init__(self):
        super().__init__()
        self.storage = None

    def forward(self, x):
        self.storage = x.clone()
        return x

    def get(self):
        ''' Returns stored capture '''
        return self.storage

class MyNet(nn.Module):

    def forward(self, x):
        ...
        intermediate_output = self.capture_layer.get()
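For context, this is roughly how I wire it together (layer sizes and names here are just illustrative, not my actual architecture):

```python
import torch
import torch.nn as nn

class Capture(nn.Module):
    ''' Identity module used to capture input '''

    def __init__(self):
        super().__init__()
        self.storage = None

    def forward(self, x):
        self.storage = x.clone()
        return x

    def get(self):
        ''' Returns stored capture '''
        return self.storage

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.capture_layer = Capture()
        self.encoder = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            self.capture_layer,  # records the intermediate activation
            nn.Linear(16, 4),
        )

    def forward(self, x):
        out = self.encoder(x)
        # retrieve what the capture layer saw during this pass
        intermediate = self.capture_layer.get()
        return out, intermediate

net = MyNet()
out, inter = net(torch.randn(2, 8))
print(out.shape, inter.shape)  # torch.Size([2, 4]) torch.Size([2, 16])
```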

Would it be semantically correct to use something like this? This obviously breaks if there is concurrent access on my module, but is that something I should be worried about?

From my understanding, autograd also captures some ‘state’, but is that state tracked through the tensor or through the layer itself? (I guess it’s the former.)

Alternatively, if there is an easier way to achieve what I want, then that works too!

Thank you for your time!

Very likely the .clone() wastes memory; it is only needed if inplace operations immediately follow the capture.

While I have used similar hacks for debugging, I must admit that just spelling out the forward is probably just as easy and more explicit (in the spirit of the Zen of Python).

Also, the capture_layer would hang on to its storage even after the other bits go out of scope, keeping that memory alive. That could be tricky, too.
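One way to soften that, if you stick with the capture idea, is to hand the tensor over exactly once and drop the module’s reference on retrieval. A sketch (simplified, without the .clone()):

```python
import torch
import torch.nn as nn

class Capture(nn.Module):
    ''' Identity layer that hands its capture over exactly once '''

    def __init__(self):
        super().__init__()
        self.storage = None

    def forward(self, x):
        self.storage = x
        return x

    def pop(self):
        # release the module's reference so the tensor can be freed
        out, self.storage = self.storage, None
        return out

cap = Capture()
y = cap(torch.randn(2, 3))
t = cap.pop()
print(t.shape, cap.storage is None)  # torch.Size([2, 3]) True
```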

It probably is fine to do this if you are sure that several passes through the layer cannot interact (no concurrent or nested invocations).

Is it clever? Likely.

But is it good style? I’d say no.
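For completeness: a forward hook is the more standard way to get the same effect without touching the Sequential. A sketch with illustrative layer sizes:

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

captured = {}

def save_output(module, inputs, output):
    # called after the layer's forward; stash the intermediate activation
    captured["relu"] = output

# attach to the layer of interest; keep the handle to detach later
handle = encoder[1].register_forward_hook(save_output)

out = encoder(torch.randn(2, 8))
print(captured["relu"].shape)  # torch.Size([2, 16])
handle.remove()  # detach the hook when no longer needed
```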

Best regards

Thomas

All valid points! To add to that, the resulting invocation is also not exactly a ‘pure’ function. I’ve split it for now since it seems like a better design.

The reason I was a bit hesitant is that the split falls inside my encoder’s feature extractor (Xception), while the intermediate output is required much later in my decoder (DeepLab V3+). I wanted to keep these modules as independent as possible, to avoid having to refer to two papers when reviewing either module in the future.