Normalize gradient in backward pass in single layer

(Bartosz Ludwiczuk) #1

I am writing an nn.Module with my own layer; say it is just nn.BatchNorm.
I would like to normalize the gradient in the backward pass before it is passed to the lower layers (so that it carries only information about the direction). I would like to normalize the gradient of that single layer only, not the gradients of all layers.
How could I do this without writing my own full backward pass, just by taking the gradients that autograd computes and normalizing them?

I imagine something like this:

import torch.nn as nn
import torch.nn.functional as F


class MyBatchNorm(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.fc1_bn = nn.BatchNorm1d(in_features)

    def forward(self, x):
        # normalize data by BN
        out = self.fc1_bn(x)
        return out

    def backward(self, grad_output):
        # pseudocode: what I would like the backward pass to do
        grad_input, grad_weight, grad_bias = self.fc1_bn.backward(grad_output)
        return F.normalize(grad_input), F.normalize(grad_weight), F.normalize(grad_bias)


Have you figured out how? Does your proposal work?
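
One possible approach, offered as a minimal sketch and not as a confirmed answer from this thread: autograd does not call a backward method defined on an nn.Module, so the normalization can instead be placed in a small identity torch.autograd.Function in front of the BatchNorm. The names NormalizeGrad, lower, inp and hidden are made up for this illustration, and normalizing per sample with dim=1 is an assumption about what "direction" should mean here. Only the gradient flowing to the lower layers is normalized; the gradients of the BatchNorm weight and bias are left as autograd computes them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizeGrad(torch.autograd.Function):
    """Identity in the forward pass; L2-normalizes the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Return a view so the output is a distinct tensor with unchanged values.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumption: scale each sample's gradient to unit L2 norm (dim=1).
        return F.normalize(grad_output, dim=1)


class MyBatchNorm(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.fc1_bn = nn.BatchNorm1d(in_features)

    def forward(self, x):
        # The gradient leaving this layer toward lower layers passes through
        # NormalizeGrad.backward; other layers are not affected.
        return self.fc1_bn(NormalizeGrad.apply(x))


# Usage: a hypothetical lower layer followed by the custom layer.
torch.manual_seed(0)
lower = nn.Linear(4, 3)
layer = MyBatchNorm(3)

inp = torch.randn(8, 4)
hidden = lower(inp)
hidden.retain_grad()  # keep the gradient of this non-leaf tensor for inspection

loss = (layer(hidden) * torch.randn(8, 3)).sum()
loss.backward()

# Per-sample L2 norm of the gradient that reaches the lower layer.
row_norms = hidden.grad.norm(dim=1)
```

If the parameter gradients should be normalized as well, calling register_hook on self.fc1_bn.weight and self.fc1_bn.bias with a function that returns the normalized gradient would be one way to try it.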