Add a new layer with a flag that controls gradients?

I want to add a new layer whose gradient I can control, but I'm not sure whether to use Function or nn.Module. Pseudocode is below:

class MyModule(Function or nn.Module???):
  def __init__(self, flag):
    self.flag = flag
  def forward(self, x):
    # some operations on x
  def backward(self, y):
    if self.flag:
      # automatically compute the gradient
    else:
      # some operations on the gradient of this layer, then continue the backward pass

From the code I have read, it seems that if I use nn.Module I can't write a backward function. Does that mean the flag can't control the gradients?

If I write a Function instead, can I save my flag variable during training and use it in backward?

Hi,
I think you want the following:

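import torch.nn as nn
from torch.autograd import Function

class MyModule(nn.Module):
  def __init__(self, flag):
    super().__init__()
    self.flag = flag
  def forward(self, x):
    return MyFunc.apply(x, self.flag)

class MyFunc(Function):
  @staticmethod
  def forward(ctx, x, flag):
    ctx.flag = flag  # Save your flag
    out = x.clone()  # placeholder: replace with your operations on x
    return out
  @staticmethod
  def backward(ctx, grad_out):
    # Note: inside a custom Function nothing is automatic; you write both branches yourself.
    if ctx.flag:
      # the normal gradient of your forward operations
      grad_x = grad_out  # placeholder: correct only for the identity forward above
    else:
      # some operations on the gradient of this layer, then continue the backward pass
      grad_x = grad_out  # placeholder: replace with your modified gradient

    return grad_x, None  # Here you want None because no grads are needed for flag.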
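
For a concrete, runnable version of this pattern, here is a minimal sketch. Everything specific in it is an assumption made for illustration and not taken from the question: the forward operation (y = 2 * x), the clamp used as the "operation on the gradient", and the names ScaleFunc, ScaleLayer, and input_grad.

```python
import torch
from torch import nn
from torch.autograd import Function


class ScaleFunc(Function):
    """Hypothetical op y = 2 * x whose backward depends on a flag."""

    @staticmethod
    def forward(ctx, x, flag):
        ctx.flag = flag  # non-tensor state can be stored directly on ctx
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        grad_x = grad_out * 2.0  # analytic gradient of y = 2 * x
        if not ctx.flag:
            # Assumed example of "some operations on the gradient"
            grad_x = grad_x.clamp(-1.0, 1.0)
        # One return value per forward input; flag needs no gradient
        return grad_x, None


class ScaleLayer(nn.Module):
    def __init__(self, flag):
        super().__init__()
        self.flag = flag

    def forward(self, x):
        return ScaleFunc.apply(x, self.flag)


def input_grad(flag):
    """Run one forward/backward pass and return the gradient at the input."""
    x = torch.ones(3, requires_grad=True)
    ScaleLayer(flag)(x).sum().backward()
    return x.grad


print(input_grad(True))   # unmodified gradient
print(input_grad(False))  # clamped gradient
```

Since flag is read in the module's forward on every call, you should be able to change layer.flag between training iterations, and the next backward pass will use the new value.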

Thanks a lot, it really helps~