How to implement a forward pre-hook whose behavior depends on a flag in an nn.Module

Hello!!

I have the following hook, similar to weight normalization. It applies a linear transformation to the parameters of the network based on some pre-trained weights (the source).
This is the hook:

class AffineSource(object):
    """
    Forward pre-hook that rebuilds a module parameter from source weights.

    The parameter ``name`` is replaced by three registered parameters:
    ``name_source`` (frozen), ``name_delta`` and ``name_gamma`` (trainable).
    Before every forward pass the attribute ``name`` is recomputed as
    ``delta + gamma * source``, or as ``delta`` alone when the module's
    ``only_delta`` flag is true.
    """

    def __init__(self, name: str, source_param: torch.Tensor):
        self.name = name
        self.source_param = source_param

    @staticmethod
    def _only_delta(module) -> bool:
        """Return the module's ``only_delta`` flag, defaulting to False if unset."""
        return getattr(module, 'only_delta', False)

    def integrate_source_weights(self, module, only_delta):
        """
        Compute the effective weight for ``module``.

        Args:
            module: module holding the ``_source``, ``_delta`` and ``_gamma``
                parameters registered by :meth:`apply`.
            only_delta: if true, ignore the source and return only the delta.

        Returns:
            A plain ``torch.Tensor`` (never an ``nn.Parameter``) that is still
            connected to ``delta`` (and ``gamma``) in the autograd graph.
        """
        source = getattr(module, self.name + '_source')
        delta = getattr(module, self.name + '_delta')
        gamma = getattr(module, self.name + '_gamma')
        if only_delta:
            # Returning `delta` itself would hand an nn.Parameter to setattr(),
            # which registers it in module._parameters under `self.name`.
            # A later assignment of a plain tensor to that name then raises
            # "cannot assign ... as parameter". clone() yields a plain tensor
            # through which gradients still flow back to `delta`.
            computation = delta.clone()
        else:
            computation = delta + (gamma * source)
        return computation

    @staticmethod
    def apply(module: Module, name: str, source_param):
        """
        Install the hook on ``module`` for the parameter called ``name``.

        Args:
            module: module whose parameter is re-parameterized.
            name: name of the existing parameter (e.g. ``'weight'``).
            source_param: tensor stored as the frozen ``name_source`` parameter.

        Returns:
            The registered :class:`AffineSource` hook instance.

        Raises:
            RuntimeError: if an ``AffineSource`` hook for ``name`` is already
                registered on ``module``.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, AffineSource) and hook.name == name:
                raise RuntimeError("Cannot register two source params hooks on "
                                   "the same parameter {}".format(name))

        fn = AffineSource(name, source_param)
        weight = getattr(module, name)

        # Drop the original parameter so that `name` can be re-bound to a
        # plain (recomputed) tensor attribute.
        del module._parameters[name]
        module.register_parameter(name + '_source', Parameter(source_param, requires_grad=False))
        module.register_parameter(name + '_delta', Parameter(torch.zeros_like(weight), requires_grad=True))
        module.register_parameter(name + '_gamma', Parameter(torch.ones_like(weight), requires_grad=True))
        setattr(module, name, fn.integrate_source_weights(module, AffineSource._only_delta(module)))
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module: Module) -> None:
        """
        Collapse the re-parameterization back into a single parameter ``name``.

        The current effective weight is computed, the three helper parameters
        are deleted, and ``name`` is re-registered as a regular parameter.
        This method does not unregister the hook from ``module``.
        """
        weight = self.integrate_source_weights(module, self._only_delta(module))
        delattr(module, self.name)
        del module._parameters[self.name + '_source']
        del module._parameters[self.name + '_delta']
        del module._parameters[self.name + '_gamma']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        """Recompute ``name`` on ``module`` before each forward pass."""
        setattr(module, self.name, self.integrate_source_weights(module, self._only_delta(module)))

When applying it to the model weights, I create the flag only_delta in every module of the model with:

def get_delta_model(model, mode=True):
    """
    Set the ``only_delta`` attribute to ``mode`` on every module of ``model``.

    The attribute is written on each object yielded by ``model.modules()``;
    the same ``model`` object is returned.
    """
    for submodule in model.modules():
        setattr(submodule, 'only_delta', mode)
    return model

After activating/deactivating module.only_delta during training, when feeding the data:

# Set only_delta=True on every module, then run a forward pass.
# Presumably AffineSource hooks are already registered on `model`; each hook
# reads module.only_delta right before the forward pass.
model = get_delta_model(model, True)
outputs = model(img)

# Flip the flag back to False and run a second forward pass on the same input.
model = get_delta_model(model, False)
ouputs2 = model(img)

it gives me this error:

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Without the only_delta option in the hook class, it works. The error only appears when I activate and later deactivate only_delta in every module.

I think the problem is the return value of integrate_source_weights; I guess the output should be a parameter. (I am not sure; I followed the code of weight normalization: pytorch/weight_norm.py at master · pytorch/pytorch · GitHub.)

I would appreciate any help! :slight_smile:

I managed to solve it by using computation = torch.clone(delta) in integrate_source_weights. However, I am not entirely sure why it doesn’t work when just returning delta.