Hello PyTorchers!
I am implementing an encoder-decoder style network and my decoder needs to access the output from an intermediate encoder layer. The problem is that I have wrapped this layer in a Sequential container.
Now, I could decompose the Sequential container and retrieve the output in the forward method, but I wanted to avoid that to keep the code consistent.
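To make this concrete, here is a toy version of the situation (the layer sizes are arbitrary placeholders):

import torch.nn as nn

# Toy encoder: the decoder needs the activation after the ReLU,
# but that point is buried inside the Sequential.
encoder = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),        # <- I need the output at this point
    nn.Linear(16, 4),
)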
After some thought, I came up with a simple ‘capture’ layer that stores the last input it was fed:
import torch.nn as nn

class Capture(nn.Module):
    '''Identity module used to capture its input.'''

    def __init__(self):
        super().__init__()
        self.storage = None

    def forward(self, x):
        # Snapshot the input, then pass it through unchanged.
        self.storage = x.clone()
        return x

    def get(self):
        '''Returns the stored capture.'''
        return self.storage
class MyNet(nn.Module):
    def forward(self, x):
        ...
        intermediate_output = self.capture_layer.get()
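Expanding that skeleton, here is a fuller sketch of the wiring I have in mind (sizes and the decoder are arbitrary placeholders, just to show the idea):

import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.capture_layer = Capture()
        # The capture layer sits inside the Sequential as an identity op.
        self.encoder = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            self.capture_layer,   # snapshots the activation at this point
            nn.Linear(16, 4),
        )
        self.decoder = nn.Linear(16 + 4, 8)

    def forward(self, x):
        encoded = self.encoder(x)
        intermediate_output = self.capture_layer.get()
        # The decoder consumes both the final encoding and the captured skip.
        return self.decoder(torch.cat([intermediate_output, encoded], dim=-1))

net = MyNet()
out = net(torch.randn(2, 8))   # out has shape (2, 8)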
Would it be semantically correct to use something like this? It obviously breaks if there is concurrent access to my module, but is that something I should be worried about?
From my understanding, autograd also captures some ‘state’, but is that state tracked through the variable or through the layer itself? (I guess it’s the former.)
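A quick check seems to support the former, since the cloned tensor carries its own grad_fn:

import torch

x = torch.randn(3, requires_grad=True)
cap = Capture()
_ = cap(x)
print(cap.storage.grad_fn)  # e.g. <CloneBackward0 ...>: graph state hangs off the tensor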
Alternatively, if there is an easier way to achieve what I want, then that works too!
Thank you for your time!