ReLU + Dropout inplace

I’ve tried to chain ReLU and Dropout, both in place:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 1, 1)
        self.relu = nn.ReLU(inplace = True)
        self.dropout = nn.Dropout(inplace = True)

    def forward(self, x):
        return self.dropout(self.relu(self.conv(x))).sum()

model = Net()
model.cuda()
model.train()

model(torch.autograd.Variable(torch.FloatTensor(1, 3, 16, 16).cuda().uniform_())).backward()

This fails with: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

It seems one could still compute the gradient of ReLU even if Dropout was applied in place afterwards, since dropout is just a multiplication by a positive number and doesn’t change the ReLU gating mask.
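(A quick numerical check of this claim, not part of the original post: emulating dropout with an explicit keep mask shows that the gradient of dropout(relu(x)) w.r.t. x is just the combined “kept and positive” mask scaled by 1/(1-p), so a single saved mask would suffice for backward.)

import torch

torch.manual_seed(0)
p = 0.5
x = torch.randn(8, requires_grad=True)

relu_out = torch.relu(x)
keep = (torch.rand_like(x) > p).float()       # explicit dropout keep mask
y = relu_out * keep / (1. - p)                # dropout(relu(x)) with that mask
y.sum().backward()

expected = (x > 0).float() * keep / (1. - p)  # one combined mask is all backward needs
assert torch.allclose(x.grad, expected)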

One can of course write a simple module that does it in a combined way, but I was wondering about your thoughts on expressing this in PyTorch (say, by disabling dirty checking if a module is marked with some special attribute), and about the possibility of fusion once the JIT arrives?

@vadimkantorov I just tried this, and it ran without any errors for me on both CPU and GPU. I copied your code over verbatim. I am on PyTorch version 0.3.0.post4

Mine is 0.4.0a0+4ae0579

@apaszke Is it a regression then?

So the check is triggered because we don’t consider those “special cases”, and I don’t think we will want to. It would complicate the logic too much and slow autograd down. I’m not sure why it wasn’t failing in 0.3: it could be a regression, or it could be a necessary check that was only added later.

Sure, supporting constructs like ReLU + Dropout case by case is not worth it, especially if it slows everything down. I was thinking of a generic Module base class or a module attribute that would disable dirty checking within that subgraph if the user opts in.

That’s an interesting idea, but it really is a gamble. You shouldn’t assume anything about the state library functions retain for backward, so your code could work just fine in one version and be silently broken in another. I think it’s safer to implement such things as a “fused” autograd function yourself. You don’t want to waste weeks of experimentation only to discover bugs like these later.

@apaszke Does the following look like a correct implementation?

If yes, do you think the JIT / TorchScript could bring further performance improvements on CUDA (e.g. by fusing this with Conv2d or BatchNorm1d)? I saw @ngimel used torch.rand instead of torch.bernoulli for better fusability, but does this apply to a custom autograd function as well?

import torch

class ReLUDropoutInplace(torch.nn.Module):
    def __init__(self, p):
        super(ReLUDropoutInplace, self).__init__()
        self.p = p

    def forward(self, input):
        if self.training:
            # out-of-place reference implementation:
            #p1m = 1. - self.p
            #mask = torch.rand_like(input) < p1m
            #mask *= (input > 0)
            #return mask.type_as(input) * input * (1./p1m)
            return ReLUDropoutInplace.ReLUDropoutInplaceFunction.apply(input, self.p)
        else:
            # dropout is the identity at eval time, so only the in-place ReLU remains
            return input.clamp_(min = 0)

    class ReLUDropoutInplaceFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, p):
            ctx.p1m = 1. - p
            # single combined mask: kept by dropout AND positive after ReLU
            mask = torch.rand_like(input) < ctx.p1m
            mask *= (input > 0)
            # complement of the mask, saved once and reused by masked_fill_ in backward
            # (on newer PyTorch, where comparisons return bool tensors, use ~mask instead of 1 - mask)
            one_minus_mask = 1 - mask
            ctx.save_for_backward(one_minus_mask)

            # zero out dropped / negative entries in place and rescale the survivors
            input = input.contiguous().masked_fill_(one_minus_mask, 0)
            input *= (1. / ctx.p1m)

            return input

        @staticmethod
        def backward(ctx, grad):
            # the gradient needs exactly the same mask and scale, nothing else
            one_minus_mask, = ctx.saved_tensors
            grad = grad.contiguous().masked_fill_(one_minus_mask, 0)
            grad *= 1. / ctx.p1m
            return grad, None

If I try the following, I’m getting RuntimeError: a leaf Variable that requires grad has been used in an in-place operation. at the masked_fill_ operation:

if self.training:
    p1m = 1. - self.p
    mask = torch.rand_like(input) < p1m
    mask *= (input > 0)
    one_minus_mask = 1 - mask
    input.masked_fill_(one_minus_mask, 0)   # <- fails here when input is a leaf that requires grad
    input *= (1. / p1m)
    return input

...
if __name__ == '__main__':
    x = torch.rand(2, 16, 64, 128)
    x = (x + 1).requires_grad_()
    m = ReLUDropoutInplace(p = 0.2)
    m.train()

    y = m(x)
    y.sum().backward()
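(A minimal reproduction of why that error appears, added for context rather than taken from the thread: autograd forbids in-place operations on leaf tensors that require grad, and the x constructed in the __main__ block above is exactly such a leaf; the same masked_fill_ is accepted on a non-leaf tensor.)

import torch

leaf = torch.rand(3, requires_grad=True)        # a leaf tensor that requires grad
# leaf.masked_fill_(leaf < 0.5, 0)              # would raise the "leaf Variable that requires grad" RuntimeError

non_leaf = leaf * 2.                            # output of an op, not a leaf
non_leaf.masked_fill_(non_leaf < 0.5, 0)        # in-place is fine here, autograd records it
non_leaf.sum().backward()                       # gradients still flow back to `leaf`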

Just play around with various implementations, and use .graph_for to check whether and how they get fused. Some patterns are recognized better than others. It would also be very helpful for us if you could document the various failed attempts, because then we can try to fix those.
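(For reference, a minimal sketch of that workflow, not a verbatim snippet from the thread: script a rand_like-based relu_dropout function, warm it up, then print the optimized graph the JIT used for those inputs via .graph_for. The function body is a reconstruction of the 1 - mask / torch.rand variant discussed above; behavior of .graph_for may differ across PyTorch versions.)

import torch

@torch.jit.script
def relu_dropout(input, p: float, training: bool):
    if training:
        p1m = 1. - p
        # rand_like-based keep mask combined with the ReLU mask
        mask = (torch.rand_like(input) < p1m) * (input > 0)
        return mask.type_as(input) * input * (1. / p1m)
    return input.clamp_(min=0.)

x = torch.rand(2, 16, 64, 128, device='cuda')   # the fuser targets CUDA tensors
for _ in range(3):                               # warm up so the optimization passes run
    relu_dropout(x, 0.2, True)
print(relu_dropout.graph_for(x, 0.2, True))      # inspect whether a prim::FusionGroup was formed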


First bug filed 🙂 https://github.com/pytorch/pytorch/issues/22124

For the case with 1 - mask and with @ngimel’s solution:

graph(%input : Float(*, *, *, *),
      %p : float,
      %2 : bool):
  %5 : int = prim::Constant[value=0]() # foo.py:13:30
  %4 : float = prim::Constant[value=1]() # foo.py:11:19
  %3 : Scalar? = prim::Constant()
  %6 : Tensor = prim::If(%2) # foo.py:10:9
    block0():
      %p1m : float = aten::sub(%4, %p) # foo.py:11:19
      %8 : Float(*, *, *, *) = aten::rand_like(%input) # foo.py:12:20
      %mask.1 : Byte(*, *, *, *) = aten::lt(%8, %p1m) # foo.py:12:20
      %10 : Byte(*, *, *, *) = aten::gt(%input, %5) # foo.py:13:22
      %mask : Tensor = aten::mul_(%mask.1, %10) # foo.py:13:13
      %14 : float = aten::div(%4, %p1m) # foo.py:14:51
      %17 : Tensor = prim::FusionGroup_0(%input, %mask, %14)
      -> (%17)
    block1():
      %16 : Tensor = aten::clamp_(%input, %5, %3) # foo.py:16:20
      -> (%16)
  return (%6)
with prim::FusionGroup_0 = graph(%3 : Float(*, *, *, *),
      %6 : Tensor,
      %1 : float):
  %7 : Tensor = aten::type_as(%6, %3) # foo.py:14:28
  %5 : Tensor = aten::mul(%3, %7) # foo.py:14:20
  %2 : Tensor = aten::mul(%5, %1) # foo.py:14:20
  return (%2)

So if in-place code generation is supported by the fuser in the future, this would be pretty nice.

Also, not being able to pass a parameter like p to the constructor of a torch.jit.ScriptModule subclass and then use it in a scripted method is quite strange (the error suggests adding it to __constants__, but it’s not really a “constant”).
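(A minimal sketch of the __constants__ workaround being referred to, using the old-style ScriptModule API of that era; the class name and the stub body are mine, not the fused implementation from this thread:)

import torch

class ScriptedReLUDropout(torch.jit.ScriptModule):
    # at the time, TorchScript required listing p here before self.p could be
    # read inside a script method, even though p is really just a hyperparameter
    __constants__ = ['p']

    def __init__(self, p=0.2):
        super(ScriptedReLUDropout, self).__init__()
        self.p = p

    @torch.jit.script_method
    def forward(self, input):
        # stand-in body: stock relu followed by stock dropout, not the fused version
        return torch.dropout(torch.relu(input), self.p, self.training)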

Since this thread has a lot of views: for people looking for an implementation, my current one is here: https://gist.github.com/vadimkantorov/360ece06de4fd2641fa9ed1085f76d48