Hello, I am writing a custom function that inherits from torch.autograd.Function. This function uses some parameter (like pdrop for dropout) which is needed to compute the gradients.
class DummyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        ctx.save_for_backward(input)
        return p * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output, p):  # not sure I can put p in the inputs
        input, = ctx.saved_tensors
        # the number of returned gradients must match the number of inputs,
        # even for inputs that do not carry the "requires_grad" flag
        return grad_output * 3 * p * (5 * input ** 2 - 1), None

dummy_function = DummyFunction.apply
Then I write a module for this function (I assume this is the only way to access p in the backward call without saving p in the context?):
class DummyModule(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, input):
        return dummy_function(input, self.p)

    def backward(self, grad_outputs):
        return DummyFunction.backward(grad_outputs, self.p)  # not sure
Should I call DummyFunction.backward(grad_outputs, self.p) here? It seems legitimate, but I keep wondering whether I will break some torch logic this way. After all, the documentation insists on using DummyFunction.apply for the forward call instead of DummyFunction.forward.
Wouldn’t storing p in the ctx via ctx.p = p work (assuming p is not a tensor)?
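For example, a minimal sketch of that approach applied to the DummyFunction from the question (same names and formula as above):

import torch

class DummyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        ctx.save_for_backward(input)
        ctx.p = p  # a plain Python number can be stored directly on ctx
        return p * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx [p * (5x^3 - 3x)] = 3p * (5x^2 - 1); None for the input p
        return grad_output * 3 * ctx.p * (5 * input ** 2 - 1), None

dummy_function = DummyFunction.apply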
Hmm, I hadn’t considered that. This could be a solution. Thank you, I will try it.
Is that how pdrop is accessed during the backward pass of torch.nn.functional.dropout? I looked through the backward graph and discovered that dropout saves only a boolean mask, so it is not clear where it gets pdrop to scale this mask by 1 / (1 - pdrop) during backward. Does it really just read it from ctx.pdrop?
And what if I have several calls to this dummy_function during forward? Each ctx object will be different, so later calls will not overwrite my p parameter?
PyTorch scales with 1 / (1 - p) during training, which allows it to avoid the scaling during inference. This can also be seen in the outputs (note that which elements are zeroed out is random):
import torch
import torch.nn as nn

x = torch.ones(2, 2)
drop = nn.Dropout(p=0.5)
print(1 / (1 - 0.5))
# 2.0
out = drop(x)
print(out)
# tensor([[0., 2.],
#         [2., 0.]])

drop.eval()
out = drop(x)
print(out)
# tensor([[1., 1.],
#         [1., 1.]])

drop = nn.Dropout(p=0.2)
print(1 / (1 - 0.2))
# 1.25
out = drop(x)
print(out)
# tensor([[1.2500, 1.2500],
#         [0.0000, 0.0000]])

drop.eval()
out = drop(x)
print(out)
# tensor([[1., 1.],
#         [1., 1.]])
If you want to recompute the mask instead of storing it, you would also need to take care of the seed that was used; otherwise resampling would generate a new mask.
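A minimal sketch of that idea, assuming CPU tensors and the default CPU generator (RecomputedDropout is a hypothetical example, not a PyTorch API):

import torch

class RecomputedDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        ctx.p = p
        # snapshot the CPU RNG state before sampling the mask
        ctx.rng_state = torch.get_rng_state()
        mask = (torch.rand_like(input) >= p).to(input.dtype)
        return input * mask / (1 - p)

    @staticmethod
    def backward(ctx, grad_output):
        # replay the saved RNG state so rand_like reproduces the same mask
        current = torch.get_rng_state()
        torch.set_rng_state(ctx.rng_state)
        mask = (torch.rand_like(grad_output) >= ctx.p).to(grad_output.dtype)
        torch.set_rng_state(current)  # leave the RNG as we found it
        return grad_output * mask / (1 - ctx.p), None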
Yes, this should be the case as seen here:
import torch

class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        ctx.p = p  # non-tensor arguments can be stored directly on ctx
        print(f"p in forward {p}")
        return input * p

    @staticmethod
    def backward(ctx, grad_output):
        p = ctx.p
        print(f"p in backward {p}")
        # d(input * p) / d(input) = p; return None for the non-tensor input p
        return grad_output * p, None

fun = MyFun.apply

x = torch.randn(2, 2, requires_grad=True)
out1 = fun(x, 0.1)
# p in forward 0.1
out2 = fun(x, 0.5)
# p in forward 0.5
out1.mean().backward()
# p in backward 0.1
out2.mean().backward()
# p in backward 0.5
Thanks, your answers were very helpful! I was able to solve my problem. One thing is still unclear about dropout, though. The scaling, as you said, is performed during forward, which means that to calculate the gradients we have to multiply the incoming gradient by 1 / (1 - pdrop) after applying the mask. Where does that pdrop come from during backward?
I believe the backward formula is implemented in derivatives.yaml:
- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
  input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))"
  result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t"
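A quick way to see that scaling from Python (a sanity check, not the C++ code path itself): the gradient of dropout with respect to its input is mask / (1 - p), so every entry of x.grad below is either 0 or 1 / (1 - p):

import torch
import torch.nn.functional as F

x = torch.ones(6, requires_grad=True)
out = F.dropout(x, p=0.25, training=True)
out.sum().backward()
print(x.grad)
# each entry is either 0.0000 or 1 / (1 - 0.25) ≈ 1.3333 (positions are random)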
Thank you for all the answers!