Combining `pytorch` native layers with custom non-`pytorch` ones wrapped by a `torch.nn.Module`

Hello,

As the title says, I am trying to build a torch.nn.Sequential model that stacks native pytorch layers (e.g. torch.nn.Linear) with a custom layer that is not compatible with torch.Tensor, for example because it runs on a device that does not support tensors.

Say, for example, that I have wrapped this custom module:

class CustomModule(torch.nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # convert the inputs, for instance to numpy
        x = x.detach().numpy()
        # here I would actually like to save the gradients before
        # detaching and, in some sense, reattach them to the outputs
        # of the module

        # perform some operations
        x = do_some_non_pytorch_things(x)

        # convert back to pytorch
        return torch.as_tensor(x)

    def backward(self, out_grad: torch.Tensor) -> torch.Tensor:
        # here I would externally calculate the gradients of the
        # non-`pytorch` operations and combine them with the gradients
        # saved above
        gradients = custom_gradients(out_grad)
        return gradients

I would like to be able to run it inside of a model like the following:

model = torch.nn.Sequential(
    torch.nn.Linear(n, k),
    torch.nn.ReLU(),
    CustomModule(),
    torch.nn.Linear(i, j),
)

while preserving the flow of gradients and making CustomModule completely mimic a standard pytorch layer. Is this possible by any means? What would be the best way to achieve it?

This might be related in some part to this other discussion.

Many thanks to anyone who will take time to look into this!

Hi @RupertSciamenna,

You need to define your custom function as a torch.autograd.Function object (docs here), and then you can wrap that custom function object within an nn.Module (like you have above, but without the backward method, as that will be defined in your torch.autograd.Function).

With the ‘non-pytorch’ operations, you’ll need to manually derive the backward formula and define it in the torch.autograd.Function too.
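As a minimal sketch of that pattern (using a trivial numpy multiply-by-two as a stand-in for your non-pytorch operation, with its hand-derived gradient in backward):

```python
import numpy as np
import torch


class ScaleFn(torch.autograd.Function):
    """Hypothetical non-pytorch op: multiply by 2, done in numpy."""

    @staticmethod
    def forward(ctx, x):
        # leave autograd and work on a plain numpy array
        x_np = x.detach().cpu().numpy()
        y_np = 2.0 * x_np  # stand-in for do_some_non_pytorch_things
        return torch.as_tensor(y_np, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        # manually derived derivative of the op above: d(2x)/dx = 2
        return 2.0 * grad_out


class CustomModule(torch.nn.Module):
    def forward(self, x):
        return ScaleFn.apply(x)


model = torch.nn.Sequential(
    torch.nn.Linear(4, 3),
    torch.nn.ReLU(),
    CustomModule(),
    torch.nn.Linear(3, 2),
)

out = model(torch.randn(5, 4))
out.sum().backward()  # gradients flow through CustomModule into the first Linear
```

Even though forward detaches the input, autograd records ScaleFn as a node in the graph and calls its backward during backpropagation, so the Linear layers before it still receive gradients.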


Thanks! So even detaching the input tensor in the forward will not break the flow, and the layers before it will still be trained correctly, right?

Assuming you’ve derived the correct derivative, it should be fine.
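One way to check a hand-derived backward is torch.autograd.gradcheck, which compares your analytical gradient against numerical finite differences (it needs double-precision inputs). A small example with a toy x**2 function:

```python
import torch


class SquareFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        # hand-derived: d(x^2)/dx = 2x
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(SquareFn.apply, (x,))
```

If the derivative in backward were wrong, gradcheck would raise an error pointing at the mismatch.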