Custom autograd function with parameters

Hello, I am writing a custom function that inherits from torch.autograd.Function. This function uses a parameter (like pdrop for dropout, for example) which is needed to compute the gradients.

import torch

class DummyFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, p):
        ctx.save_for_backward(input)
        return p * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output, p): # not sure I can put p in the inputs
        input, = ctx.saved_tensors
        # the number of returned gradients must match the number of
        # forward inputs, even for inputs that don't require grad
        return grad_output * 3 * p * (5 * input ** 2 - 1), None


dummy_function = DummyFunction.apply

Then I write a module for this function (I assume this is the only way to access p in the backward call without saving p in the context?):

import torch.nn as nn

class DummyModule(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, input):
        return dummy_function(input, self.p)

    def backward(self, grad_outputs):
        return DummyFunction.backward(grad_outputs, self.p) # not sure

Should I call DummyFunction.backward(grad_outputs, self.p) here? It seems legit, but I keep wondering if I will break some torch logic. After all, the documentation pages insist on using DummyFunction.apply for the forward call instead of DummyFunction.forward.

Wouldn’t storing p in the ctx via ctx.p = p work (assuming p is not a tensor)?
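Something along these lines should work, reusing your DummyFunction (just a sketch of the ctx.p idea):

import torch

class DummyFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, p):
        ctx.save_for_backward(input)
        ctx.p = p  # fine as long as p is a plain Python number, not a tensor
        return p * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):  # backward only receives ctx and grad_output
        input, = ctx.saved_tensors
        p = ctx.p
        # one gradient per forward input; None for the non-tensor p
        return grad_output * 3 * p * (5 * input ** 2 - 1), None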

Hmm, I hadn't considered that. It could be a solution. Thank you, I will try it.
Is that the way pdrop is accessed during the backprop of torch.nn.functional.dropout? I looked through the backward graph and discovered that dropout saves only a boolean mask, so it is not clear where it gets pdrop to normalize this mask with 1 / (1 - pdrop) during backward. Does it really just get it from ctx.pdrop?
Also, what if I have several calls to dummy_function during the forward pass? Will each ctx object be different, so that later calls don't overwrite my p parameter?

PyTorch scales with 1 / (1 - p) during training here, which allows it to avoid any scaling during inference.
This can also be seen in the outputs:

import torch
import torch.nn as nn

x = torch.ones(2, 2)
drop = nn.Dropout(p=0.5)
print(1 / (1 - 0.5))
# 2.0

out = drop(x)
print(out)
# tensor([[0., 2.],
#         [2., 0.]])

drop.eval()
out = drop(x)
print(out)
# tensor([[1., 1.],
#         [1., 1.]])

drop = nn.Dropout(p=0.2)

print(1 / (1 - 0.2))
# 1.25

out = drop(x)
print(out)
# tensor([[1.2500, 1.2500],
#         [0.0000, 0.0000]])

drop.eval()
out = drop(x)
print(out)
# tensor([[1., 1.],
#         [1., 1.]])

If you want to recompute the mask instead of storing it, you would also need to take care of the seed that was used; otherwise resampling would generate a new mask.
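E.g. a rough sketch of that recompute-from-seed idea (this is not how PyTorch implements dropout; RecomputedDropout and everything in it is made up for illustration, assuming 0 <= p < 1):

import torch

class RecomputedDropout(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, p):
        # draw a fresh seed and remember it, so backward can regenerate
        # the identical mask instead of storing it
        seed = torch.randint(0, 2**31 - 1, (1,)).item()
        gen = torch.Generator(device=input.device).manual_seed(seed)
        mask = torch.rand(input.shape, generator=gen, device=input.device) >= p
        ctx.seed = seed
        ctx.p = p
        return input * mask / (1.0 - p)

    @staticmethod
    def backward(ctx, grad_output):
        # re-seeding the generator reproduces the same mask bit for bit
        gen = torch.Generator(device=grad_output.device).manual_seed(ctx.seed)
        mask = torch.rand(grad_output.shape, generator=gen, device=grad_output.device) >= ctx.p
        return grad_output * mask / (1.0 - ctx.p), None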

Yes, this should be the case as seen here:

import torch

class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        ctx.save_for_backward(input)
        ctx.p = p
        print(f"p in forward {p}")
        return input * p
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        p = ctx.p
        print(f"p in backward {p}")
        return grad_output * p, None  # d(input * p) / d(input) = p; None for the non-tensor p
    

fun = MyFun.apply

x = torch.randn(2, 2, requires_grad=True)
out1 = fun(x, 0.1)
# p in forward 0.1

out2 = fun(x, 0.5)
# p in forward 0.5

out1.mean().backward()
# p in backward 0.1

out2.mean().backward()
# p in backward 0.5
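Each call to apply creates its own ctx object, so the p stored in the first call is not overwritten by the second one.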

Thanks, your answers were very helpful! I was able to solve my problem.
Only one thing about dropout is still unclear. The scaling, as you said, is performed during the forward pass, which means that to calculate gradients we have to multiply the incoming gradient by 1 / (1 - pdrop) after applying the mask. Where does that pdrop come from during backward?

I believe the backward formula is implemented in derivatives.yaml:

- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
  input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))"
  result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t"
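So p is baked into the backward formula when native_dropout is recorded, which is why only the mask needs to be saved. A quick sanity check (assuming training mode and p < 1):

import torch
import torch.nn.functional as F

x = torch.ones(4, requires_grad=True)
p = 0.2
out = F.dropout(x, p=p, training=True)
out.sum().backward()
# kept elements receive 1 / (1 - p) = 1.25, dropped ones receive 0
print(x.grad)
# e.g. tensor([1.2500, 0.0000, 1.2500, 1.2500])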

Thank you for all the answers!