import torch.nn as nn


class CustomNetwork(nn.Module):
    def __init__(self, loss_function, architecture, latent_dim, output_dim, pre_trained=True):
        super(CustomNetwork, self).__init__()
        self.loss_function = loss_function
        self.architecture = architecture
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        # design_network is another method of this class (not shown in this snippet)
        self.frontend = self.design_network(self.architecture)
        # 1x1 convolution mapping latent_dim channels to output_dim channels
        self.output_layer = nn.Conv2d(self.latent_dim, self.output_dim, kernel_size=1)
I pass the loss function to the CustomNetwork constructor as a parameter, and I wrap the network with the nn.DataParallel module. However, computing the loss in the forward pass fails with:

AttributeError: 'DataParallel' object has no attribute 'loss_function'
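For context, nn.DataParallel keeps the wrapped network in its `.module` attribute and does not forward custom attributes such as `loss_function` to it, so the lookup only fails when it goes through the wrapper. The sketch below reproduces this, assuming the loss is looked up on the wrapped model (for example `model.loss_function(...)`) rather than inside the network's own `forward`. `TinyNet` is a hypothetical, simplified stand-in for `CustomNetwork`, since `design_network` is not shown.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for CustomNetwork that stores a loss function as an attribute."""

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


# DataParallel expects the module on the first GPU when GPUs exist; on CPU it just calls the module.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = TinyNet(nn.MSELoss()).to(device)
model = nn.DataParallel(net)

# The wrapper holds the original network in `.module`; its custom attributes are not forwarded.
print(hasattr(model, "loss_function"))         # False
print(hasattr(model.module, "loss_function"))  # True

x = torch.randn(8, 4, device=device)
target = torch.randn(8, 2, device=device)
output = model(x)
# Reach the loss function through the wrapped module instead of the wrapper.
loss = model.module.loss_function(output, target)
```

Accessing attributes through `model.module` keeps the rest of the training code unchanged, whether or not the network is wrapped.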