Let’s say I want to create a model-agnostic wrapper like this:
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer


class Classifier(Module):
    def __init__(self, model: Module, criterion: Module, optimizer: Optimizer):
        super().__init__()
        self.epoch = 0
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def forward(self, input: Tensor) -> Tensor:
        return self.model(input)

    def loss(self, output: Tensor, target: Tensor) -> Tensor:
        return self.criterion(output, target)

    def fit(self, input: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
        # One optimization step: forward, loss, backward, optimizer update.
        self.optimizer.zero_grad()
        output = self(input)
        loss = self.loss(output, target)
        loss.backward()
        self.optimizer.step()
        return output, loss

    def evaluate(self, input: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
        # Forward pass and loss only, no parameter update.
        output = self(input)
        loss = self.loss(output, target)
        return output, loss
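For context, here is roughly how I use it today in a single-process setup. The model, data, and hyperparameters below are just placeholders to show the intended call pattern:

import torch
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import SGD

# Placeholder model and random data, only to illustrate fit()/evaluate().
net = Linear(784, 10)
clf = Classifier(net, CrossEntropyLoss(), SGD(net.parameters(), lr=0.1))

x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))

output, loss = clf.fit(x, y)           # one training step
with torch.no_grad():
    output, loss = clf.evaluate(x, y)  # forward + loss only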
However, let’s say I need distributed training. How should I approach this with DistributedDataParallel? I found that DataParallel/DDP doesn’t work well with custom methods on a module (only forward is routed through the wrapper). Here, though, I could pass a DDP-wrapped module as the Classifier’s inner model instead of a plain module. Would that be OK? Or should I wrap my whole Classifier in DDP? Or should I remove the Module inheritance from Classifier altogether?
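To make the two options concrete, here is a minimal sketch of what I mean. This assumes a torchrun launch (which sets LOCAL_RANK), and the Linear model, loss, and optimizer are placeholders; I’d only pick one of the two options in practice:

import os
import torch
import torch.distributed as dist
from torch.nn import CrossEntropyLoss, Linear
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

# Assumes launch via torchrun, which provides RANK / WORLD_SIZE / LOCAL_RANK.
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
device_ids = [local_rank] if torch.cuda.is_available() else None

# Option A: wrap only the inner model with DDP and keep Classifier unchanged.
# fit()/evaluate() call self(input) -> Classifier.forward -> self.model(input),
# so the forward pass still goes through DDP's forward.
net_a = Linear(784, 10).to(device)
ddp_net = DDP(net_a, device_ids=device_ids)
clf_a = Classifier(ddp_net, CrossEntropyLoss(), SGD(ddp_net.parameters(), lr=0.1))

# Option B: wrap the whole Classifier in DDP.
# Then only calls like ddp_clf(input) go through DDP's forward; the custom
# fit()/evaluate() methods would have to be reached via ddp_clf.module.fit(...),
# which bypasses the DDP wrapper itself.
net_b = Linear(784, 10).to(device)
clf_b = Classifier(net_b, CrossEntropyLoss(), SGD(net_b.parameters(), lr=0.1))
ddp_clf = DDP(clf_b, device_ids=device_ids)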