ReLU + Dropout inplace

(Vadim Kantorov) #1

I’ve tried to chain ReLU and Dropout, both in place:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 1, 1)
        self.relu = nn.ReLU(inplace=True)        # overwrites the conv output and saves it for backward
        self.dropout = nn.Dropout(inplace=True)  # then overwrites that same saved tensor

    def forward(self, x):
        return self.dropout(self.relu(self.conv(x))).sum()

# The model must live on the same device as the input.
model = Net().cuda()

x = torch.autograd.Variable(torch.FloatTensor(1, 3, 16, 16).cuda().uniform_())
model(x).backward()

This fails with: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
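To see what the check is looking at, here is a small sketch for recent PyTorch releases (where Variable is no longer needed). It reads `Tensor._version`, an internal version counter that in-place ops bump; that attribute is an implementation detail, not a public API. The last part shows one variant that should not trip the check: ReLU in place, Dropout out of place. That is only safe here because nothing upstream saved the tensor ReLU overwrites (the convolution saves its input and weight, not its output).

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, requires_grad=True)
h = x * 2                                   # non-leaf, so in-place ops on it are allowed
y = F.relu(h, inplace=True)                 # ReLU's backward keeps a reference to its output
v0 = y._version                             # version recorded when that output was saved
F.dropout(y, p=0.5, training=True, inplace=True)  # overwrites the saved output
dirty = y._version > v0
print(dirty)

try:
    y.sum().backward()
    failed = False
except RuntimeError as err:                 # "...modified by an inplace operation"
    failed = True
    print(str(err).splitlines()[0])

# Variant that avoids the error: ReLU in place, Dropout out of place.
x2 = torch.randn(8, requires_grad=True)
y2 = F.dropout(F.relu(x2 * 2, inplace=True), p=0.5, training=True)
y2.sum().backward()
print(x2.grad is not None)
```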

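It seems one could still compute the gradient of ReLU even if Dropout was applied in place afterwards. Dropout just multiplies each element by a non-negative factor (0 or 1/(1-p)), so it never turns a zero into a positive value, and wherever it does zero an element, the gradient flowing back through dropout is zero anyway. The ReLU gating mask read off the final output therefore still gives the correct gradient.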
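A quick numerical check of that argument, as a sketch assuming inverted dropout (each element multiplied by 0 or 1/(1-p)) with a fixed mask so both paths see the same randomness. It compares autograd's gradient through out-of-place ReLU and dropout with a gradient rebuilt only from the post-dropout output `y`:

```python
import torch

torch.manual_seed(0)
p = 0.5
x = torch.randn(1000, requires_grad=True)

# Fixed inverted-dropout factors: 0 (dropped) or 1 / (1 - p) (kept).
keep = (torch.rand(1000) >= p).float() / (1 - p)

# Reference gradient from autograd, using out-of-place ops only.
y = torch.relu(x) * keep
grad_out = torch.randn(1000)
(grad_ref,) = torch.autograd.grad(y, x, grad_out)

# Gradient recovered from the post-dropout output alone: gate on y > 0
# and rescale by 1 / (1 - p), as a fused in-place kernel would have to.
with torch.no_grad():
    grad_fused = grad_out * (y > 0).float() / (1 - p)

print(torch.allclose(grad_ref, grad_fused))  # True
```

The gate `y > 0` differs from the ReLU gate `x > 0` exactly at the dropped positions, and there the factor `keep` is already zero, so the two gradients agree.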

One can of course write a simple module that does both in a combined way, but I was wondering what your thoughts are on expressing this in PyTorch (say, by disabling dirty checking for a module marked with some special attribute), and on the possibility of fusing the two once the JIT arrives.

(Riddhiman Dasgupta) #2

@vadimkantorov I just tried this, and it ran without any errors for me on both CPU and GPU. I copied your code over verbatim. I am on PyTorch version 0.3.0.post4.

(Vadim Kantorov) #3

Mine is 0.4.0a0+4ae0579

(Vadim Kantorov) #4

@apaszke Is it a regression then?

(Adam Paszke) #5

So the check is triggered because we don’t consider those “special cases”, and I don’t think we will want to: it would complicate the logic too much and slow autograd down. I’m not sure why it wasn’t failing in 0.3; it could be a regression, or it could be a necessary check that was added only later.

(Vadim Kantorov) #6

Sure, supporting constructs like ReLU + Dropout case-by-case is not worth it, especially if it slows everything down. I was thinking of a generic Module base class or module attribute that would disable dirty checking within that subgraph if a user wishes so.

(Adam Paszke) #7

That’s an interesting idea, but it really is a gamble. You shouldn’t assume anything about the state that library functions retain for backward, so your code could work just fine in one version and be silently broken in another. I think it’s safer to implement such things yourself as a “fused” autograd function. You don’t want to waste weeks of experimentation only to discover bugs like these later.
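A minimal sketch of such a fused function for recent PyTorch, assuming inverted dropout with keep probability 1 - p. `FusedReLUDropout` and `ReLUDropout` are hypothetical names, not PyTorch APIs. The forward overwrites its input in place, declares that with `ctx.mark_dirty`, and saves only the final output; backward gates on `out > 0` and rescales, following the reasoning in the first post. Because the function owns its own backward, no library-internal saved state is involved. The input must be a non-leaf tensor (for example a conv output), since autograd rejects in-place changes to leaves that require grad.

```python
import torch
import torch.nn as nn


class FusedReLUDropout(torch.autograd.Function):
    """Hypothetical fused ReLU + inverted dropout, computed in place."""

    @staticmethod
    def forward(ctx, x, p, training):
        x.relu_()                                   # ReLU in place
        scale = 1.0
        if training and p > 0:
            scale = 1.0 / (1.0 - p)
            keep = torch.empty_like(x).bernoulli_(1.0 - p)
            x.mul_(keep).mul_(scale)                # dropout in place
        ctx.scale = scale
        ctx.mark_dirty(x)           # tell autograd that x was modified in place
        ctx.save_for_backward(x)    # x now holds the final output
        return x

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        # Nonzero gradient exactly where ReLU passed AND dropout kept: out > 0.
        grad_in = grad_out * (out > 0).to(grad_out.dtype) * ctx.scale
        return grad_in, None, None                  # no grads for p, training


class ReLUDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p

    def forward(self, x):
        return FusedReLUDropout.apply(x, self.p, self.training)


# Usage: stands in for nn.ReLU(inplace=True) followed by nn.Dropout(inplace=True).
conv = nn.Conv2d(3, 1, 1)
act = ReLUDropout(p=0.5)
loss = act(conv(torch.rand(1, 3, 16, 16))).sum()
loss.backward()   # backward reads only the final output, so nothing is "dirty"
```

Saving only the output keeps the memory benefit of working in place; the price is that this backward is now your code to test, for instance by comparing it against an out-of-place ReLU followed by an explicit dropout mask.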