I rewrote the forward method of Net and found that torch.nn.DataParallel can no longer be used. Does anyone have a good solution? Thank you very much!

  1. model:

    import torch.nn as nn
    import torch.nn.functional as F

    # LeNet-style network; the ReparamModule base class (section 2) flattens all of
    # its weights into a single flat parameter.
    class Net(ReparamModule):
        supported_dims = {28, 32}

        def __init__(self, state):
            if state.dropout:
                raise ValueError("LeNet doesn't support dropout")
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(state.nc, 6, 5, padding=2 if state.input_size == 28 else 0)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 1 if state.num_classes <= 2 else state.num_classes)

        def forward(self, x):
            out = F.relu(self.conv1(x), inplace=True)
            out = F.max_pool2d(out, 2)
            out = F.relu(self.conv2(out), inplace=True)
            out = F.max_pool2d(out, 2)
            out = out.view(out.size(0), -1)
            out = F.relu(self.fc1(out), inplace=True)
            out = F.relu(self.fc2(out), inplace=True)
            out = self.fc3(out)
            return out
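
On a single device the rewritten Net works fine. For reference, here is a minimal sketch of how I construct and call it, using the classes in this post; the state object below is only a stand-in config for illustration (its fields are the ones read by Net and by the base model in section 2):

    import types
    import torch

    # Stand-in config for illustration only (hypothetical values).
    state = types.SimpleNamespace(
        dropout=False, nc=1, input_size=28, num_classes=10,
        device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

    net = Net(state)
    x = torch.randn(4, state.nc, 28, 28, device=state.device)
    out = net(x)       # dispatches through ReparamModule.__call__, not nn.Module.__call__
    print(out.shape)   # torch.Size([4, 10])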

  2. base model:

    import logging
    from contextlib import contextmanager

    import torch
    import torch.nn as nn
    from six import add_metaclass

    class PatchModules(type):
        def __call__(cls, state, *args, **kwargs):
            r"""Called when you call ReparamModule(...)"""
            net = type.__call__(cls, state, *args, **kwargs)

            # collect weight (module, name) pairs
            # flatten weights
            w_modules_names = []

            for m in net.modules():
                for n, p in m.named_parameters(recurse=False):
                    if p is not None:
                        w_modules_names.append((m, n))
                for n, b in m.named_buffers(recurse=False):
                    if b is not None:
                        logging.warning((
                            '{} contains buffer {}. The buffer will be treated as '
                            'a constant and assumed not to change during gradient '
                            'steps. If this assumption is violated (e.g., '
                            'BatchNorm*d\'s running_mean/var), the computation will '
                            'be incorrect.').format(m.__class__.__name__, n))

            net._weights_module_names = tuple(w_modules_names)

            # Put to correct device before we do stuff on parameters
            net = net.to(state.device)

            ws = tuple(m._parameters[n].detach() for m, n in w_modules_names)

            assert len(set(w.dtype for w in ws)) == 1

            # reparam to a single flat parameter
            net._weights_numels = tuple(w.numel() for w in ws)
            net._weights_shapes = tuple(w.shape for w in ws)
            with torch.no_grad():
                flat_w = torch.cat([w.reshape(-1) for w in ws], 0)

            # remove old parameters, assign the names as buffers
            for m, n in net._weights_module_names:
                delattr(m, n)
                m.register_buffer(n, None)

            # register the flat one
            net.register_parameter('flat_w', nn.Parameter(flat_w, requires_grad=True))

            return net

    @add_metaclass(PatchModules)
    class ReparamModule(nn.Module):
        def _apply(self, *args, **kwargs):
            rv = super(ReparamModule, self)._apply(*args, **kwargs)
            return rv

        def get_param(self, clone=False):
            if clone:
                return self.flat_w.detach().clone().requires_grad_(self.flat_w.requires_grad)
            return self.flat_w

        @contextmanager
        def unflatten_weight(self, flat_w):
            # temporarily view slices of the flat vector as the per-module weights
            ws = (t.view(s) for (t, s) in zip(flat_w.split(self._weights_numels), self._weights_shapes))
            for (m, n), w in zip(self._weights_module_names, ws):
                setattr(m, n, w)
            yield
            for m, n in self._weights_module_names:
                setattr(m, n, None)

        def forward_with_param(self, inp, new_w):
            with self.unflatten_weight(new_w):
                return nn.Module.__call__(self, inp)

        def __call__(self, inp):
            return self.forward_with_param(inp, self.flat_w)
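
As I understand it, the point of the flat parameter is that the same network can be run with an explicit weight vector. Continuing the single-device sketch from section 1, everything behaves as expected:

    # Only the flat parameter remains after PatchModules has run.
    print([name for name, _ in net.named_parameters()])     # ['flat_w']
    print(net.flat_w.numel() == sum(net._weights_numels))   # True

    w = net.get_param(clone=True)        # detached copy of flat_w that still requires grad
    out = net.forward_with_param(x, w)   # forward pass with the explicit flat weights
    out.sum().backward()                 # gradients land on the cloned flat vector
    print(w.grad.shape == net.flat_w.shape)                  # True

It is only when I wrap the model in DataParallel that things break: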

model = torch.nn.DataParallel(Net, device_ids=device_ids) does not work!

Similar question: https://github.com/SsnL/dataset-distillation/issues/23

I might not fully understand your use case, but nn.DataParallel most likely depends on classes derived from nn.Module, while you seem to be creating a custom base class.
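
If the goal is plain data parallelism over the batch with the module's own flat_w, one thing you could try is a thin wrapper whose forward follows the standard contract that DataParallel expects. This is an untested sketch, not something from the repo: DataParallelWrapper is a hypothetical name, and it deliberately keeps the flat weight as a registered parameter (which the replicate step broadcasts per device) instead of passing it as an input, because DataParallel scatters tensor inputs along dim 0.

    import torch
    import torch.nn as nn

    class DataParallelWrapper(nn.Module):
        """Hypothetical wrapper: exposes a standard forward() so that
        nn.DataParallel's replicate/scatter machinery can drive the model."""

        def __init__(self, reparam_net):
            super(DataParallelWrapper, self).__init__()
            self.reparam_net = reparam_net

        def forward(self, x):
            # Each replica uses its own broadcast copy of flat_w; gradients are
            # accumulated back onto the original flat_w by DataParallel as usual.
            return self.reparam_net.forward_with_param(x, self.reparam_net.flat_w)

    # wrapped = nn.DataParallel(DataParallelWrapper(net), device_ids=device_ids)
    # out = wrapped(x)

Two caveats: this only covers the self.flat_w case, since an arbitrary new_w passed as an input would be chunked along dim 0 like any other tensor argument; and replicate has to cope with the None buffers that PatchModules registers, so whether this works may depend on your PyTorch version.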


Thank you very much for your reply. A custom PyTorch forward function does not work with DataParallel. This problem seems to be difficult to solve.