Hi all,
I’ve been building a U-Net lately, and I wanted to stick with the nn.Module way of thinking by making the skip connections modules too.
After some thought, I came up with the following snippet. It solved my problem, so I felt like sharing it.
It’s a general-purpose Module that wraps Python lambdas:
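import types

import torch.nn as nn


class LambdaModule(nn.Module):
    """Wraps a function so it can be used like any other module."""

    def __init__(self, lambd):
        super().__init__()
        # Fail early if lambd is not a plain Python function.
        assert type(lambd) is types.LambdaType
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)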
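A side note on the type check, in case it helps anyone: `types.LambdaType` is the same object as `types.FunctionType`, so the assertion accepts any plain Python function, not only lambdas, and it rejects other callables such as `functools.partial` objects or built-ins. Here is a small standard-library sketch of that behaviour (the helper `double` is just a made-up example):

```python
import types
from functools import partial


def double(x):
    # Hypothetical helper, only here to have a def-style function to test.
    return 2 * x


# LambdaType is an alias of FunctionType, not a separate type.
same_type = types.LambdaType is types.FunctionType

# What the assertion in LambdaModule would see for different callables.
accepts_lambda = type(lambda x: x) is types.LambdaType
accepts_def = type(double) is types.LambdaType
accepts_partial = type(partial(double)) is types.LambdaType
accepts_builtin = type(abs) is types.LambdaType

print(same_type, accepts_lambda, accepts_def, accepts_partial, accepts_builtin)
```

If you want to allow every callable, replacing the check with `callable(lambd)` would be one option, but that is a design choice and not part of the snippet I posted.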
The way I used it was the following:
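# This runs inside the parent module's __init__; self.lower must already be assigned there.
# Concatenate the input with self.lower(x) along dim 1 (the channel dim for NCHW tensors).
middle_lambda = lambda x: torch.cat([x, self.lower(x)], 1)
middle = LambdaModule(middle_lambda)
block = [left, middle, right]
self.block = nn.Sequential(*block)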
Here, left and right are the modules that come immediately before and after the skip connection module.
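For anyone who wants to try this end to end, here is a minimal self-contained sketch of the usage above. The `SkipBlock` class, the choice of `nn.Conv2d` layers and the channel counts are assumptions made up for illustration, not my actual U-Net. Note that `right` has to accept twice the channels, because the skip connection concatenates along dim 1.

```python
import torch
import torch.nn as nn


class LambdaModule(nn.Module):
    # Same idea as the module above, without the type check for brevity.
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class SkipBlock(nn.Module):
    # Hypothetical one-level block; layer types and sizes are illustrative.
    def __init__(self, channels):
        super().__init__()
        left = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Assigned on self so its parameters are registered on this module;
        # the lambda only closes over it.
        self.lower = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        middle = LambdaModule(lambda x: torch.cat([x, self.lower(x)], 1))
        # The concatenation doubles the channel count.
        right = nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1)
        self.block = nn.Sequential(left, middle, right)

    def forward(self, x):
        return self.block(x)


block = SkipBlock(channels=4)
x = torch.randn(2, 4, 16, 16)
out = block(x)
n_param_tensors = len(list(block.parameters()))
print(out.shape, n_param_tensors)
```

The LambdaModule itself holds no parameters. The weights of `self.lower` are picked up by the optimizer only because `self.lower` is an attribute of the parent module, so keep that assignment if you adapt this.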