I want to add a new layer whose gradient I can control, but I'm not sure whether to use Function or Module. The pseudo code is below:
```
class MyModule(Function or nn.Module???):
    def __init__(self, flag):
        self.flag = flag

    def forward(self, x):
        # some operations on x

    def backward(self, y):
        if self.flag:
            # compute the gradient automatically
        else:
            # apply some operations to this layer's gradient,
            # then continue backpropagation
```
From the code I have read, it seems that if I use nn.Module, I can't write a backward function. Does that mean the flag can't control the gradients?
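An nn.Module does not define its own backward; autograd derives the gradient from the operations recorded in forward. One way a module can still let a flag change the gradient, without writing a custom Function, is a tensor hook via `Tensor.register_hook`. This is a minimal sketch, assuming a hypothetical forward of `x * 2` and a hypothetical "custom operation" that scales the incoming gradient by 0.5; substitute the real operations.

```python
import torch
import torch.nn as nn


class MyHookedModule(nn.Module):
    def __init__(self, flag):
        super().__init__()
        self.flag = flag

    def forward(self, x):
        # Hypothetical forward op; autograd records it and computes its gradient automatically
        out = x * 2
        # Hooks can only be registered on tensors that require grad
        if not self.flag and out.requires_grad:
            # The hook receives the gradient w.r.t. `out` during backward;
            # the tensor it returns replaces that gradient before it flows further back.
            # Scaling by 0.5 is a hypothetical stand-in for "some operations on the gradient".
            out.register_hook(lambda grad: grad * 0.5)
        return out


x = torch.ones(3, requires_grad=True)
MyHookedModule(flag=False)(x).sum().backward()
print(x.grad)
```

With `flag=True` the gradient of `x` is the plain autograd result (2 per element here); with `flag=False` the hook halves it before it reaches `x`.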
If I write a Function instead, can I save my flag variable during training and use it in backward?
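With the static-method `torch.autograd.Function` API, non-tensor arguments such as a Python bool can be passed to `forward` and stored on `ctx`, then read in `backward`. A sketch under stated assumptions: the forward op `x * 2`, the 0.5 gradient scale, and the names `FlaggedScale` and `MyModule` are hypothetical. Note that inside a Function the "automatic" case still has to be written out by hand as the ordinary gradient of the forward op. Wrapping the Function in an nn.Module keeps the flag as a module attribute.

```python
import torch
import torch.nn as nn


class FlaggedScale(torch.autograd.Function):
    """Hypothetical layer: forward doubles x; backward optionally modifies the gradient."""

    @staticmethod
    def forward(ctx, x, flag):
        # Plain Python values can be stored directly on ctx for use in backward
        ctx.flag = flag
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Ordinary gradient of x * 2, i.e. what autograd would compute on its own
        grad_x = grad_output * 2
        if not ctx.flag:
            # Hypothetical custom operation on this layer's gradient
            grad_x = grad_x * 0.5
        # Return one value per forward input; flag is not a tensor, so its gradient is None
        return grad_x, None


class MyModule(nn.Module):
    def __init__(self, flag):
        super().__init__()
        self.flag = flag

    def forward(self, x):
        # Functions are invoked through .apply, not by calling forward directly
        return FlaggedScale.apply(x, self.flag)


x = torch.ones(3, requires_grad=True)
MyModule(flag=False)(x).sum().backward()
print(x.grad)
```

Because the flag is read from `self.flag` on every forward call, changing it on the module between training steps changes which backward branch runs.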