When should one subclass nn.ModuleDict over nn.Module?
for example, here,
import trw.train
import trw.datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import collections

batch_size = 32
latent_size = 64
mnist_size = 28 * 28
hidden_size = 256

class Flatten(nn.Module):
    def forward(self, i):
        return i.view(i.shape[0], -1)

class Generator(nn.Module):
    # ... (snippet truncated in the original post)
Is it better to use nn.ModuleDict whenever combining multiple neural networks? So, between

class Y(nn.ModuleDict):
    def __init__(self):
        super().__init__()
        self['NetA'] = NetA()
        self['NetB'] = NetB()

and

class X(nn.Module):
    def __init__(self):
        super().__init__()
        self.modelone = NetA()
        self.modeltwo = NetB()

which is the preferred way?
ptrblck
September 5, 2019, 3:50pm
#2
It might depend on your use case, but nn.ModuleDict is just a container (dict), which stores modules and is used to register these modules properly inside a parent nn.Module.
Based on your code snippet, I would derive from nn.Module (your X class).
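To illustrate the difference, here is a minimal sketch. NetA and NetB are hypothetical stand-ins for the networks in the question (not from the original post); the point is that nn.ModuleDict only registers submodules, while a plain nn.Module subclass can also define how they interact in forward():

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the NetA/NetB used in the question.
class NetA(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

class NetB(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Deriving from nn.Module: submodules assigned as attributes are
# registered automatically, and forward() defines how they compose.
class X(nn.Module):
    def __init__(self):
        super().__init__()
        self.modelone = NetA()
        self.modeltwo = NetB()

    def forward(self, x):
        return self.modeltwo(self.modelone(x))

# nn.ModuleDict is just a registered container: its parameters are
# visible to the parent, but it defines no forward() of its own.
container = nn.ModuleDict({'NetA': NetA(), 'NetB': NetB()})

model = X()
out = model(torch.randn(3, 4))
print(out.shape)                           # torch.Size([3, 2])
print(len(list(container.parameters())))   # 4 (weight + bias per net)
```

Calling container(x) directly would fail, because nn.ModuleDict has no forward(); you would still need an enclosing nn.Module to decide how the stored networks are wired together, which is why subclassing nn.Module is usually the cleaner choice here.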