A small snippet for lambda modules

Hi all,
I’ve been trying to build a U-Net lately, and I wanted to stick with the nn.Module way of thinking when writing the skip-connection modules.
After some thinking, I came up with the following snippet, which solved my problem, and I felt like sharing it.
It’s a general-purpose Module that wraps Python lambdas:

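```python
import types

import torch.nn as nn


class LambdaModule(nn.Module):
    """Wraps a Python lambda so it can be used like any other nn.Module."""

    def __init__(self, lambd):
        super().__init__()
        assert isinstance(lambd, types.LambdaType), "expected a lambda"
        self.lambd = lambd

    def forward(self, x):
        # Apply the wrapped function to the input.
        return self.lambd(x)
```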
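One detail about the check in `LambdaModule.__init__`: `types.LambdaType` is just an alias of `types.FunctionType`, so any plain `def` function passes it too, while other callables such as built-ins (e.g. `abs`) or `functools.partial` objects are rejected. A small stdlib-only sketch of that behaviour (`accepts` and `double` are hypothetical helpers for illustration):

```python
import functools
import types


def accepts(fn):
    """Hypothetical helper: True if fn would pass the LambdaType check."""
    return isinstance(fn, types.LambdaType)


def double(x):
    return 2 * x


print(types.LambdaType is types.FunctionType)  # True: same type object
print(accepts(lambda x: 2 * x))                # True
print(accepts(double))                         # True: plain def functions pass too
print(accepts(abs))                            # False: built-in function
print(accepts(functools.partial(pow, 2)))      # False: partial object
```

If you want to wrap arbitrary callables, a `callable(lambd)` check would be the looser alternative.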

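Here is how I used it:

```python
# Inside the enclosing module's __init__; self.lower is the inner sub-block.
# Concatenate the input with lower(x) along dim 1 (the channel dim for NCHW).
middle_lambda = lambda x: torch.cat([x, self.lower(x)], 1)
middle = LambdaModule(middle_lambda)
block = [left, middle, right]
self.block = nn.Sequential(*block)
```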

where `left` and `right` are the modules before and after the skip-connection module.
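To see the pieces together, here is a self-contained sketch of one skip block. The class name `SkipBlock` and all channel sizes are hypothetical, chosen only for illustration. Assigning `lower` to `self.lower` registers it as a submodule, so its parameters show up in `.parameters()` and move with `.to(device)`; the lambda itself holds no parameters.

```python
import types

import torch
import torch.nn as nn


class LambdaModule(nn.Module):
    """Wraps a lambda so it can sit inside nn.Sequential."""

    def __init__(self, lambd):
        super().__init__()
        assert isinstance(lambd, types.LambdaType), "expected a lambda"
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class SkipBlock(nn.Module):
    """Hypothetical U-Net-style block: left -> cat([x, lower(x)]) -> right."""

    def __init__(self, left, lower, right):
        super().__init__()
        # Registered as a submodule, so its weights are tracked by the model.
        self.lower = lower
        # The lambda looks up self.lower at call time and concatenates on dim 1.
        middle = LambdaModule(lambda x: torch.cat([x, self.lower(x)], 1))
        self.block = nn.Sequential(left, middle, right)

    def forward(self, x):
        return self.block(x)


# Toy sizes: padding=1 with 3x3 kernels keeps the spatial size unchanged.
left = nn.Conv2d(3, 8, kernel_size=3, padding=1)
lower = nn.Conv2d(8, 8, kernel_size=3, padding=1)
right = nn.Conv2d(16, 4, kernel_size=1)  # 8 + 8 concatenated input channels
model = SkipBlock(left, lower, right)

x = torch.randn(2, 3, 16, 16)
out = model(x)
print(out.shape)  # torch.Size([2, 4, 16, 16])
```

One trade-off worth knowing: a lambda cannot be pickled, so saving the whole model object with `torch.save(model, ...)` fails, whereas saving `model.state_dict()` works because it contains only tensors.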


Very useful, thanks!