Wrapping a DDP module inside a simple Module

Let’s say I want to create a model-agnostic wrapper like this:

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

class Classifier(Module):
    def __init__(self, model: Module, criterion: Module, optimizer: Optimizer):
        super().__init__()
        self.epoch = 0
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def forward(self, input: Tensor) -> Tensor:
        return self.model(input)

    def loss(self, output: Tensor, target: Tensor) -> Tensor:
        return self.criterion(output, target)

    def fit(self, input: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
        # one training step: forward, loss, backward, optimizer update
        self.optimizer.zero_grad()
        output = self(input)
        loss = self.loss(output, target)
        loss.backward()
        self.optimizer.step()
        return output, loss

    def evaluate(self, input: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
        # forward pass and loss only, no parameter update
        output = self(input)
        loss = self.loss(output, target)
        return output, loss
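
For context, this is roughly how I use it for single-device training today (a rough sketch; MyNet, device, and train_loader are placeholders, not part of the class above):

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyNet().to(device)  # any plain nn.Module
clf = Classifier(model, CrossEntropyLoss(), SGD(model.parameters(), lr=0.01))

for input, target in train_loader:
    output, loss = clf.fit(input.to(device), target.to(device))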

However, suppose I now need distributed training. How should I approach this with DistributedDataParallel? I have read that DDP does not play well with custom methods on a module, but here I could simply pass a DDP-wrapped instance instead of a plain module as my classifier’s main network. Would that be OK? Or should I wrap the whole Classifier in DDP? Or should I remove the Module inheritance from Classifier entirely? A sketch of what I mean follows below.
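
Concretely, something like this (a minimal sketch of the first option; it assumes the process group is already initialized with torch.distributed.init_process_group, and MyNet, rank, and dataloader are placeholders):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# option 1: wrap only the inner network in DDP, then hand it to my wrapper
net = MyNet().to(rank)
ddp_net = DDP(net, device_ids=[rank])

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ddp_net.parameters(), lr=0.01)

clf = Classifier(ddp_net, criterion, optimizer)

# option 2 (the alternative I'm asking about): wrap the whole Classifier instead?
# clf = DDP(Classifier(net, criterion, optimizer).to(rank), device_ids=[rank])

for input, target in dataloader:
    output, loss = clf.fit(input.to(rank), target.to(rank))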