Can we wrap existing Tensor functions in a Module class?


Currently, when we define a model, we can use either torch.flatten or torch.nn.Flatten.

A user may choose either of them. I agree that torch.nn.Flatten is a Module class that uses torch.flatten inside its forward call.

The advantage of torch.nn.Flatten is that, being a module, it supports forward hooks.
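For reference, this is roughly what the module route gives you: a minimal sketch (the model layout and the `captured` dict are just illustrative) that registers a forward hook on an nn.Flatten layer and grabs its output.

```python
import torch
import torch.nn as nn

captured = {}

def save_output(module, inputs, output):
    # Forward hooks receive (module, tuple of positional inputs, output)
    captured["flatten"] = output.detach()

# Conv2d(1 -> 4, k=3) on an 8x8 input gives 4 x 6 x 6 = 144 features after flattening
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.Flatten(), nn.Linear(4 * 6 * 6, 10))
handle = model[1].register_forward_hook(save_output)

x = torch.randn(2, 1, 8, 8)
y = model(x)
handle.remove()  # detach the hook once it is no longer needed

print(captured["flatten"].shape)  # torch.Size([2, 144])
```

With a plain torch.flatten call inside forward there is no module object to call register_forward_hook on, which is the gap described here.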

What I would like to do is either register a forward hook on the tensor operation directly, or wrap it in an nn.Module class so that I can register a hook on it.

tensor.flatten() is just an example. Ideally, I would like to wrap every torch_function in an nn.Module.
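One way to do the wrapping by hand is a generic module that stores a callable and its extra arguments. This is only a sketch: `FunctionModule` is a hypothetical helper name, not part of PyTorch, and it only helps where the model code actually instantiates it.

```python
import torch
import torch.nn as nn

class FunctionModule(nn.Module):
    """Hypothetical wrapper: turns any callable (e.g. torch.flatten) into an
    nn.Module so that forward hooks can be registered on it."""

    def __init__(self, fn, *args, **kwargs):
        super().__init__()
        self.fn = fn
        self.args = args      # extra positional args passed after the inputs
        self.kwargs = kwargs  # extra keyword args, e.g. start_dim=1

    def forward(self, *inputs):
        return self.fn(*inputs, *self.args, **self.kwargs)

    def extra_repr(self):
        # Show the wrapped function's name when the module is printed
        return getattr(self.fn, "__name__", repr(self.fn))

flatten = FunctionModule(torch.flatten, start_dim=1)
relu = FunctionModule(torch.relu)

calls = []
flatten.register_forward_hook(lambda m, inp, out: calls.append(("flatten", tuple(out.shape))))
relu.register_forward_hook(lambda m, inp, out: calls.append(("relu", tuple(out.shape))))

x = torch.randn(2, 3, 4)
out = relu(flatten(x))
print(calls)  # [('flatten', (2, 12)), ('relu', (2, 12))]
```

The catch is that the user has to write `FunctionModule(torch.flatten, ...)` instead of calling torch.flatten, so on its own this does not meet the "don't change the model definition" requirement.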

This should not change the way the user defines the model. Any suggestions would be helpful.
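If the model code must stay untouched, a different technique than module hooks may fit better: intercepting calls through the `__torch_function__` protocol with `torch.overrides.TorchFunctionMode`. The sketch below assumes a recent PyTorch (2.x). `RecordingMode` and its `watch` parameter are hypothetical names, and this acts as a hook-like observer, not a real nn.Module forward hook.

```python
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode

class RecordingMode(TorchFunctionMode):
    """Hypothetical hook-like mode: sees every torch function call made while active."""

    def __init__(self, watch):
        super().__init__()
        self.watch = set(watch)  # function names to record, e.g. {"flatten"}
        self.records = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # The mode is disabled while this handler runs, so calling func does not recurse
        out = func(*args, **kwargs)
        name = getattr(func, "__name__", str(func))
        if name in self.watch and isinstance(out, torch.Tensor):
            self.records.append((name, tuple(out.shape)))
        return out

class UserModel(nn.Module):
    # Unchanged user code: calls torch.flatten directly instead of nn.Flatten
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(12, 5)

    def forward(self, x):
        return self.fc(torch.flatten(x, 1))

model = UserModel()
x = torch.randn(2, 3, 4)
mode = RecordingMode(watch={"flatten"})
with mode:
    y = model(x)
print(mode.records)
```

Because the mode sees every call (including tensor methods such as `Tensor.flatten`), filtering by name as above is one way to keep the recording limited to the operations you care about.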

It doesn't feel like this is related to quantization; can you add a different tag?

I might have to repost this, as the forum is not letting me edit the post. The only logical label seems to be "Uncategorized".
