Model:
class Net(ReparamModule):
    """LeNet-5 style classifier whose weights are managed by ReparamModule.

    ``state`` must provide ``dropout``, ``nc`` (number of input channels),
    ``input_size`` and ``num_classes``. With ``num_classes <= 2`` the last
    layer emits a single output; otherwise it emits one output per class.

    Raises:
        ValueError: if ``state.dropout`` is truthy (dropout is not supported).
    """

    # Input spatial sizes this architecture is laid out for.
    supported_dims = {28, 32}

    def __init__(self, state):
        if state.dropout:
            raise ValueError("LeNet doesn't support dropout")
        super(Net, self).__init__()
        # 28x28 inputs get padding=2 so that both 28x28 and 32x32 inputs reach
        # a 16x5x5 feature map after two conv(5x5) + max_pool(2) stages,
        # which is what fc1's input size (16 * 5 * 5) assumes.
        self.conv1 = nn.Conv2d(state.nc, 6, 5, padding=2 if state.input_size == 28 else 0)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1 if state.num_classes <= 2 else state.num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) fc3 outputs, one row per sample in ``x``."""
        out = F.relu(self.conv1(x), inplace=True)
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out), inplace=True)
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out), inplace=True)
        out = F.relu(self.fc2(out), inplace=True)
        out = self.fc3(out)
        return out
Base model:
class PatchModules(type):
    """Metaclass that re-parameterises a module into one flat weight vector.

    After the class's normal ``__init__`` has run, every parameter of every
    submodule is concatenated into a single 1-D ``nn.Parameter`` registered on
    the root module as ``flat_w``. Each original parameter slot is replaced by
    a ``None`` buffer of the same name, so that the weights can later be bound
    to views of any flat vector.
    """

    def __call__(cls, state, *args, **kwargs):
        r"""Called when you call ReparamModule(...)"""
        net = type.__call__(cls, state, *args, **kwargs)

        # Collect the (module, param_name) pairs to flatten. Also record each
        # module's dotted path relative to ``net``. A copy of ``net`` whose
        # __dict__ was copied shallowly (as DataParallel replicas are) shares
        # ``_weights_module_names`` with the original. That tuple references
        # the ORIGINAL submodules. The paths let such a copy resolve its own
        # submodules instead.
        w_modules_names = []
        w_module_paths = []
        for mod_path, m in net.named_modules():
            for n, p in m.named_parameters(recurse=False):
                if p is not None:
                    w_modules_names.append((m, n))
                    w_module_paths.append((mod_path, n))
            for n, b in m.named_buffers(recurse=False):
                if b is not None:
                    logging.warning((
                        '{} contains buffer {}. The buffer will be treated as '
                        'a constant and assumed not to change during gradient '
                        'steps. If this assumption is violated (e.g., '
                        'BatchNorm*d\'s running_mean/var), the computation will '
                        'be incorrect.').format(m.__class__.__name__, n))
        net._weights_module_names = tuple(w_modules_names)
        net._weights_module_paths = tuple(w_module_paths)

        # Put to correct device before we do stuff on parameters
        net = net.to(state.device)

        ws = tuple(m._parameters[n].detach() for m, n in w_modules_names)
        assert len({w.dtype for w in ws}) == 1, 'all parameters must share one dtype'

        # Sizes/shapes needed to split flat_w back into per-parameter views.
        net._weights_numels = tuple(w.numel() for w in ws)
        net._weights_shapes = tuple(w.shape for w in ws)
        with torch.no_grad():
            flat_w = torch.cat([w.reshape(-1) for w in ws], 0)

        # Remove the old parameters and keep their names as (empty) buffers,
        # so that setattr(m, n, tensor) later stores a plain tensor.
        for m, n in net._weights_module_names:
            delattr(m, n)
            m.register_buffer(n, None)

        # Register the single flat parameter that now owns all the weights.
        net.register_parameter('flat_w', nn.Parameter(flat_w, requires_grad=True))
        return net
@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    """Base module whose weights all live in one flat parameter, ``flat_w``.

    The ``PatchModules`` metaclass builds ``flat_w`` at construction time.
    Calling the module binds views of a flat weight vector onto the
    submodules, runs ``forward``, and then unbinds them.
    """

    def _apply(self, *args, **kwargs):
        # Plain pass-through to nn.Module._apply.
        rv = super(ReparamModule, self)._apply(*args, **kwargs)
        return rv

    def get_param(self, clone=False):
        """Return ``flat_w``.

        If ``clone`` is true, return a detached copy instead. The copy has the
        same ``requires_grad`` flag as ``flat_w``.
        """
        if clone:
            return self.flat_w.detach().clone().requires_grad_(self.flat_w.requires_grad)
        return self.flat_w

    def _weight_targets(self):
        """Return the (submodule, attr_name) pairs of THIS instance to bind weights to.

        If ``_weights_module_paths`` is present, the dotted paths are resolved
        against ``self``. This makes a shallow copy of the module (for example
        a DataParallel replica) write into its own submodules rather than the
        original's. Otherwise, fall back to the module references stored in
        ``_weights_module_names``.
        """
        paths = getattr(self, '_weights_module_paths', None)
        if paths is None:
            return self._weights_module_names
        targets = []
        for mod_path, n in paths:
            m = self
            if mod_path:
                for part in mod_path.split('.'):
                    m = m._modules[part]
            targets.append((m, n))
        return targets

    @contextmanager
    def unflatten_weight(self, flat_w):
        """Temporarily bind views of ``flat_w`` as the submodules' weights.

        ``flat_w`` is split according to ``_weights_numels`` and reshaped to
        ``_weights_shapes``. The weights are reset to ``None`` on exit, even if
        the body raises.
        """
        targets = self._weight_targets()
        ws = (t.view(s) for (t, s) in zip(flat_w.split(self._weights_numels), self._weights_shapes))
        try:
            for (m, n), w in zip(targets, ws):
                setattr(m, n, w)
            yield
        finally:
            for m, n in targets:
                setattr(m, n, None)

    def forward_with_param(self, inp, new_w):
        """Run the module on ``inp`` using ``new_w`` as the flat weight vector."""
        with self.unflatten_weight(new_w):
            return nn.Module.__call__(self, inp)

    def __call__(self, inp):
        # Default call path: use the module's own flat_w.
        return self.forward_with_param(inp, self.flat_w)
Wrapping the model with `model = torch.nn.DataParallel(Net, device_ids=device_ids)` does not work!
A similar question: https://github.com/SsnL/dataset-distillation/issues/23