From the source code, I found that nearly all of the torchvision.transforms inherit from nn.Module.
But the official tutorial on custom transforms creates them from scratch, without inheriting from nn.Module.
Looking at the source code, transforms created this way (without inheriting from nn.Module) still seem to work well.
So why do most torchvision.transforms inherit from nn.Module? And do we need to inherit from nn.Module when we build our own custom transforms?
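To make the two styles in the question concrete, here is a minimal sketch that assumes only PyTorch is installed. `AddOffsetPlain` and `AddOffsetModule` are hypothetical names chosen for illustration. The first is a plain callable class like the tutorial's transforms. The second subclasses nn.Module and implements forward(), as torchvision's transforms do. Both give the same result when called.

```python
import torch
import torch.nn as nn


class AddOffsetPlain:
    """Hypothetical plain callable transform: no nn.Module base class."""

    def __init__(self, offset: float):
        self.offset = offset

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        return img + self.offset


class AddOffsetModule(nn.Module):
    """Hypothetical nn.Module-based transform, in the torchvision style."""

    def __init__(self, offset: float):
        super().__init__()
        self.offset = offset

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # nn.Module.__call__ dispatches to forward()
        return img + self.offset


x = torch.zeros(2, 3)
plain_out = AddOffsetPlain(1.5)(x)
module_out = AddOffsetModule(1.5)(x)
print(torch.equal(plain_out, module_out))  # True
```

From the caller's side the two are interchangeable. The difference only shows up when you need nn.Module features.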
They inherit from nn.Module for two key reasons:
- It makes the transforms runnable on GPUs.
- It makes them compatible with TorchScript.
Unless either of these is important to your specific application, you can safely implement your custom transforms without inheriting from nn.Module.
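The TorchScript point can be illustrated with a short sketch. It assumes a PyTorch version that still ships torch.jit.script. `Scale` and `Shift` are hypothetical transforms written for this example. Because both subclass nn.Module with type-annotated forward() methods, they can be chained in nn.Sequential and compiled into a single TorchScript program.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    # Hypothetical transform: multiplies the input by a constant factor.
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return img * self.factor


class Shift(nn.Module):
    # Hypothetical transform: adds a constant offset to the input.
    def __init__(self, offset: float):
        super().__init__()
        self.offset = offset

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return img + self.offset


# Both are nn.Modules, so they can be chained with nn.Sequential
# and the whole pipeline compiled with TorchScript.
pipeline = nn.Sequential(Scale(2.0), Shift(1.0))
scripted = torch.jit.script(pipeline)

x = torch.arange(4, dtype=torch.float32)
print(scripted(x))  # tensor([1., 3., 5., 7.])
```

nn.Sequential only accepts nn.Module instances, so a plain callable class could not be dropped into this pipeline as-is.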
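import numpy as np
import torch

class Standardize:  # class name assumed; the original snippet only showed __call__
    def __call__(self, img):
        # Compute per-image statistics with the library matching the input type
        if isinstance(img, np.ndarray):
            mean, std = np.mean(img), np.std(img)
        elif isinstance(img, torch.Tensor):
            mean, std = torch.mean(img), torch.std(img)
        else:
            raise TypeError(f"Unsupported input type: {type(img)}")
        # Zero-mean, unit-variance normalization
        return (img - mean) / std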
This is the transform I created for myself, and it still works on GPUs without inheriting from nn.Module. So I don't think the first reason is important.
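A quick way to check the GPU claim is the sketch below, which assumes PyTorch and NumPy are installed. It falls back to the CPU when no CUDA device is present. `Standardize` is a hypothetical name for the transform from the comment, which has no nn.Module base class. The torch operations run on whatever device the input tensor lives on, so the output stays on that device.

```python
import numpy as np
import torch


class Standardize:
    # Hypothetical name for the commenter's transform; no nn.Module base class.
    def __call__(self, img):
        if isinstance(img, np.ndarray):
            mean, std = np.mean(img), np.std(img)
        elif isinstance(img, torch.Tensor):
            mean, std = torch.mean(img), torch.std(img)
        else:
            raise TypeError(f"Unsupported input type: {type(img)}")
        return (img - mean) / std


# Use the GPU when one is available; the torch ops run where the tensor lives.
device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.rand(3, 32, 32, device=device)

out = Standardize()(img)
print(out.device.type == device)  # True
```

Here the device comes from the input tensor, not from the transform's base class.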