I was trying to wrap layers with a generic wrapper and then perform whatever the layer operation was, like here:
```python
from torch.nn import Module


class TorchModuleWrapper(Module):
    def __init__(self, torch_object: Module):
        super().__init__()
        # save the layer in _wrapped_object
        self._wrapped_object = torch_object

    def __call__(self, x):
        # feed the input forward through the layer with the new weights
        return self._wrapped_object(x)
```
The problem happened when I tried to apply a torch.autograd.Function to the layer weights inside the forward function: the function I need requires learnable parameters, and those learnable parameters lost their gradients.
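To illustrate the symptom, here is a minimal standalone reproduction (the names are placeholders, not my actual model); a fresh Parameter is, like `.detach()`, a new leaf tensor, so the graph back to the scale is cut:

```python
import torch

s = torch.ones(1, requires_grad=True)   # stands in for the learnable scale
w = torch.randn(3, 4)                   # stands in for the layer weight

q = torch.round(w * s) / s              # quantized weight, still linked to s
leaf = q.detach().requires_grad_()      # a fresh leaf, like a new Parameter:
                                        # the link back to s is gone

leaf.sum().backward()
print(s.grad)  # None — the scale never receives a gradient
```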
One solution, which I left commented out in the example, is to apply the operation between the input and the weights, but then it would no longer be a generic wrapper.
So, is there any way to apply the function to the weights without losing the gradients of the learnable parameters located outside the layer?
```python
import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn import Module, Parameter


class PACT(Function):
    @staticmethod
    def forward(ctx, x, s):
        return torch.round(x * s) / s

    @staticmethod
    def backward(ctx, grad_output):
        dLdx = grad_output
        # reduce to the shape of s, otherwise autograd rejects the gradient
        dLds = grad_output.sum().reshape(1)
        return dLdx, dLds


class TorchModuleWrapper(Module):
    def __init__(self, torch_object: Module):
        super().__init__()
        self.bits = 8
        torch_object.bias = None
        self._wrapped_object = torch_object
        self.scale = Parameter(Tensor(1))

    def __call__(self, x):
        weight = self._wrapped_object.weight
        quantized_weight = PACT.apply(weight, self.scale)
        # re-wrapping as a Parameter creates a new leaf tensor,
        # which cuts the graph back to self.scale
        self._wrapped_object.weight = Parameter(quantized_weight)
        return self._wrapped_object(x)

    def __repr__(self):
        return "Wrapped" + self._wrapped_object.__repr__()
```
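For reference, this is the workaround I mean: applying the quantized weight functionally (here via `F.linear`, assuming the wrapped layer is a Linear) keeps the graph to the scale intact, but it ties the code to one layer type instead of staying a generic wrapper:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Linear, Parameter


class PACT(Function):
    @staticmethod
    def forward(ctx, x, s):
        return torch.round(x * s) / s

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimate; reduce the scale gradient to shape (1,)
        return grad_output, grad_output.sum().reshape(1)


layer = Linear(4, 3, bias=False)
scale = Parameter(torch.ones(1))

x = torch.randn(2, 4)
qw = PACT.apply(layer.weight, scale)  # quantize without touching layer.weight
out = F.linear(x, qw)                 # apply the weight functionally: graph intact
out.sum().backward()

print(scale.grad is not None)         # True — the scale keeps its gradient
print(layer.weight.grad is not None)  # True
```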
Colab example with the unwanted workaround: