Why do most torchvision.transforms inherit from nn.Module?

Looking at the source code, I found that nearly all of the torchvision.transforms inherit from nn.Module.

But in the official tutorial on custom transforms, we can create a custom transform from scratch (without inheriting from nn.Module).

I also found in the source code that even the transforms created from scratch (without inheriting from nn.Module) seem to work well.

So why do most torchvision.transforms inherit from nn.Module? And do we need to inherit from nn.Module when building our own custom transforms?

They inherit from nn.Module for two key reasons:

  1. It makes the transforms runnable on GPUs.
  2. It makes them compatible with TorchScript.

Unless either of these is important to your specific application, you can safely implement your custom transforms without inheriting from nn.Module.
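For instance, a transform that subclasses nn.Module and uses only tensor operations can be compiled with torch.jit.script directly. This is a minimal sketch; `CustomNormalize` is a hypothetical name, not a torchvision class:

```python
import torch
from torch import nn


class CustomNormalize(nn.Module):
    # Per-sample normalization as an nn.Module: tensor-only ops with type
    # annotations make it scriptable, and it can be moved to a device with
    # .to(device) like any other module.
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return (img - img.mean()) / img.std()


transform = CustomNormalize()
scripted = torch.jit.script(transform)  # compiles because it subclasses nn.Module

x = torch.randn(3, 8, 8)
out = scripted(x)
```

The scripted module can then be serialized with `scripted.save(...)` and loaded in a non-Python runtime, which is the main payoff of the nn.Module design.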

import numpy as np
import torch

class MyNormalize:
    def __call__(self, img):
        # Compute per-image statistics with whichever backend matches the input.
        if isinstance(img, np.ndarray):
            mean, std = np.mean(img), np.std(img)
        elif isinstance(img, torch.Tensor):
            mean, std = torch.mean(img), torch.std(img)
        else:
            raise TypeError(f"Unsupported input type: {type(img)}")
        return (img - mean) / std

This is the transform I created for myself. It still works on GPU tensors, so I do not think the first reason is important.