I’ve been trying to build a U-Net lately, and I wanted to stick with the nn.Module way of thinking when making the skip-connection modules.
After some thinking, I came up with the following snippet. It solved my troubles, so I felt like sharing it.
It’s a general-purpose Module that wraps Python lambdas:
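```python
import types

import torch
import torch.nn as nn


class LambdaModule(nn.Module):
    """Wraps a Python lambda so it can be used like any other nn.Module."""

    def __init__(self, lambd):
        super().__init__()
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
```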
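As a quick sanity check of what the wrapper buys you, here is a minimal sketch (not from the original post) that drops a reshape lambda into an `nn.Sequential`. The layer sizes are arbitrary assumptions, and the class is repeated so the snippet runs on its own. For plain flattening, recent PyTorch versions also ship `nn.Flatten`; the lambda is just the simplest thing to demonstrate with.

```python
import types

import torch
import torch.nn as nn


class LambdaModule(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


# Arbitrary example sizes: 8x8 single-channel inputs, 10 outputs.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),                 # (N, 1, 8, 8) -> (N, 4, 6, 6)
    LambdaModule(lambda x: x.view(x.size(0), -1)),  # -> (N, 144)
    nn.Linear(4 * 6 * 6, 10),                       # -> (N, 10)
)

y = model(torch.randn(5, 1, 8, 8))
print(y.shape)  # torch.Size([5, 10])
```

The wrapper itself owns no parameters; it only lets arbitrary tensor code sit in the middle of a `Sequential` chain.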
The way I used it was the following:
```python
# Concatenate the input with the lower path's output along the channel dim (1)
middle_lambda = lambda x: torch.cat([x, self.lower(x)], 1)
middle = LambdaModule(middle_lambda)
block = [left, middle, right]
self.block = nn.Sequential(*block)
```
where `left` and `right` are the modules to the left and right of the skip-connection module.
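Putting the pieces together, the sketch below shows one hypothetical U-Net level built this way. `SkipLevel`, the channel counts, and the plain `nn.Conv2d` layers standing in for `left`, `lower`, and `right` are all assumptions for illustration; in a real U-Net the lower path would downsample and then upsample back so the concatenation lines up.

```python
import types

import torch
import torch.nn as nn


class LambdaModule(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class SkipLevel(nn.Module):
    """Hypothetical single level: left -> cat([x, lower(x)]) -> right."""

    def __init__(self, channels):
        super().__init__()
        # Placeholder layers; padding=1 keeps the spatial size unchanged.
        left = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        # Assigning lower to self registers its parameters with SkipLevel;
        # the lambda below only holds a reference to it.
        self.lower = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        middle = LambdaModule(lambda x: torch.cat([x, self.lower(x)], 1))
        # The concatenation doubles the channel count.
        right = nn.Conv2d(2 * channels, 3, kernel_size=1)
        self.block = nn.Sequential(left, middle, right)

    def forward(self, x):
        return self.block(x)


net = SkipLevel(channels=8)
out = net(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 3, 16, 16])
```

One consequence of this design: `lower` must be assigned as an attribute of the parent (`self.lower`) for its weights to show up in `net.parameters()`, since the lambda is not inspected for submodules. Also, standard `pickle` cannot serialize lambdas, so saving `net.state_dict()` works while pickling the whole module does not.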