Hello!!
I have the following hook, similar to weight normalization: it applies an affine transformation to the parameters of the network based on some pre-trained weights (the source). Concretely, the effective weight becomes weight = delta + (gamma * source) element-wise, or just weight = delta when the only_delta flag is set.
This is the hook:
import torch
from torch.nn import Module, Parameter


class AffineSource(object):
    """
    Combine the current network parameters with the source information
    before every forward pass.
    """
    def __init__(self, name: str, source_param: torch.Tensor):
        self.name = name
        self.source_param = source_param

    def integrate_source_weights(self, module, only_delta):
        """
        Element-wise combination of the current parameters with the source.
        """
        source = getattr(module, self.name + '_source')
        delta = getattr(module, self.name + '_delta')
        gamma = getattr(module, self.name + '_gamma')
        if only_delta:
            computation = delta
        else:
            computation = delta + (gamma * source)
        return computation

    @staticmethod
    def apply(module: Module, name: str, source_param):
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, AffineSource) and hook.name == name:
                raise RuntimeError("Cannot register two source params hooks on "
                                   "the same parameter {}".format(name))
        fn = AffineSource(name, source_param)
        weight = getattr(module, name)
        # replace the original parameter with the three components
        # of the affine combination
        del module._parameters[name]
        module.register_parameter(name + '_source', Parameter(source_param, requires_grad=False))
        module.register_parameter(name + '_delta', Parameter(torch.zeros_like(weight), requires_grad=True))
        module.register_parameter(name + '_gamma', Parameter(torch.ones_like(weight), requires_grad=True))
        setattr(module, name, fn.integrate_source_weights(module, module.only_delta))
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module: Module) -> None:
        weight = self.integrate_source_weights(module, module.only_delta)
        delattr(module, self.name)
        del module._parameters[self.name + '_source']
        del module._parameters[self.name + '_delta']
        del module._parameters[self.name + '_gamma']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        # recompute the weight before every forward pass
        setattr(module, self.name, self.integrate_source_weights(module, module.only_delta))
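For reference, here is a minimal sketch of how I attach the hook to a single layer (the layer and tensor names are just illustrative; note that only_delta has to exist on the module before apply runs, since apply reads it):

layer = torch.nn.Linear(4, 4)
layer.only_delta = False  # apply() reads this flag, so it must exist first
source = torch.randn_like(layer.weight)
AffineSource.apply(layer, 'weight', source)
out = layer(torch.randn(1, 4))  # the pre-hook recomputes 'weight' before forward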
When applying it to the model weights, I create the only_delta flag in every module of the model with:
def get_delta_model(model, mode=True):
    """
    Create a flag in every module to use a different behavior of the hook
    """
    for module in model.modules():
        module.only_delta = mode
    return model
Then, during training, when feeding the data, I activate/deactivate module.only_delta:
model = get_delta_model(model, True)
outputs = model(img)

model = get_delta_model(model, False)
outputs2 = model(img)
This gives me the following error:
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
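If it helps, the TypeError itself is easy to reproduce on a plain module: nn.Module.__setattr__ refuses a raw tensor for an attribute name that is still registered in module._parameters. A minimal sketch:

import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
# 'weight' is still registered in lin._parameters, so assigning a plain
# tensor (instead of an nn.Parameter) raises the same TypeError:
#   cannot assign 'torch.FloatTensor' as parameter 'weight'
#   (torch.nn.Parameter or None expected)
lin.weight = torch.zeros(2, 2)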
Without the only_delta option in the hook class it works; the error only appears when I activate and later deactivate only_delta in every module.
I think the problem is the return value of integrate_source_weights; I guess the output should be a Parameter (I am not sure, I followed the code of weight normalization: pytorch/weight_norm.py at master · pytorch/pytorch · GitHub).
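For comparison, this is (roughly, from memory) the corresponding method in the weight_norm implementation I followed; it always returns the result of a tensor operation rather than one of the stored Parameters:

# paraphrased from torch/nn/utils/weight_norm.py;
# _weight_norm is an internal torch function
def compute_weight(self, module):
    g = getattr(module, self.name + '_g')
    v = getattr(module, self.name + '_v')
    return _weight_norm(v, g, self.dim)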
I would appreciate any help!