Arbitrary grouped convolutions

I am trying to implement a more general version of grouped convolutions that takes a list split whose sum equals the number of input channels and internally groups the convolutions according to this partition. This is in contrast to PyTorch's grouped convolutions, which only allow groups of the same size. Here is my attempt (the important part is the forward; please ignore the references to DilatedCNN in the __init__):

import torch
import torch.nn as nn

class SplitDilatedCNN(nn.Module):
    def __init__(self, channels=4, depth=5, kernel=5, split=[3, 1]):
        super(SplitDilatedCNN, self).__init__()

        self.channels = channels
        self.depth = depth
        self.kernel = kernel
        self.split = split
        self.channelsAux = 0

        self.dcnn = nn.ModuleList()
        for i in range(len(split)):
            self.dcnn.append(DilatedCNN(split[i],split[i],self.channelsAux,self.depth,self.channels,self.kernel))

    def forward(self, x):
        # split along the channel dimension according to self.split,
        # run each chunk through its own sub-network, then re-concatenate
        x = x.split(self.split, 1)
        x = [self.dcnn[j](x[j]) for j in range(len(self.split))]
        return torch.cat(x, 1)
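
Just to make the intended behaviour concrete (this is only an illustration, not part of the model): split with a list of sizes partitions the channel dimension into arbitrary chunks, whereas nn.Conv2d's groups argument only allows equal-sized groups:

import torch
import torch.nn as nn

x = torch.randn(1, 4, 24, 24)

# arbitrary partition of the 4 input channels into chunks of 3 and 1
chunks = x.split([3, 1], dim=1)
print([c.shape for c in chunks])   # [torch.Size([1, 3, 24, 24]), torch.Size([1, 1, 24, 24])]

# the built-in grouped convolution can only split 4 channels into equal groups,
# e.g. groups=2 gives two groups of 2 channels each
conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=2)
print(conv(x).shape)               # torch.Size([1, 4, 24, 24])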

The forward pass seems to work, but in the backward pass I get an

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

and I have absolutely no idea why. I can't see any in-place operations!?

Your code seems to work if I change DilatedCNN to nn.Conv2d:

class SplitDilatedCNN(nn.Module):
    def __init__(self,channels=4,depth=5,kernel=5,split=[3,1]):
        super(SplitDilatedCNN, self).__init__()

        self.channels = channels
        self.depth = depth
        self.kernel = kernel
        self.split = split
        self.channelsAux = 0

        self.dcnn = nn.ModuleList()
        for i in range(len(split)):
            self.dcnn.append(nn.Conv2d(
                                in_channels=split[i],
                                out_channels=split[i],
                                kernel_size=3,
                                stride=1,
                                padding=1))

    def forward(self, x):
        x = x.split(self.split,1)
        x = [self.dcnn[j](x[j]) for j in range(len(self.split))]
        return torch.cat(x,1)


model = SplitDilatedCNN()
x = torch.randn(1, 4, 24, 24)
output = model(x)
output.mean().backward()
print(model.dcnn[0].weight.grad)

Could you post or check the implementation of DilatedCNN?
Maybe the in-place operation is performed there.
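
In case it helps with the search: autograd raises this error whenever a tensor that is needed to compute a gradient is overwritten before backward runs. A minimal, generic repro (unrelated to your code) would be:

import torch

x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid's backward needs its output y
y += 1                 # in-place add overwrites y before backward
y.sum().backward()     # -> RuntimeError: one of the variables needed for gradient
                       #    computation has been modified by an inplace operation

Typical culprits are methods ending in an underscore (add_, relu_), augmented assignments (+=, *=), indexing assignments (t[0] = value), and layers created with inplace=True.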

Thanks for your reply! Here's the code for DilatedCNN. This code works independently of SplitDilatedCNN, so there shouldn't be any errors there. Could the nested ModuleLists be a problem?

class DilatedCNN(nn.Module):
    def __init__(self,channelsIn,channelsOut,channelsAux=0,depth=5,baseFilters=16,kernel=5):
        super(DilatedCNN, self).__init__()

        self.channelsIn = channelsIn+channelsAux
        self.channelsOut = channelsOut
        self.channelsAux = channelsAux

        self.baseFilters = baseFilters
        self.depth = depth
        self.kernel = kernel

        self.convolutions = nn.ModuleList()
        # first half: double the number of filters at every layer
        for i in range(self.depth):
            if i == 0:
                self.convolutions.append(ConvSame(self.channelsIn, self.baseFilters, self.kernel, i+1, False))
            else:
                self.convolutions.append(ConvSame(int((2**(i-1))*self.baseFilters), int((2**i)*self.baseFilters), self.kernel, i+1, False))
        # second half: halve the number of filters back down to channelsOut
        for i in reversed(range(self.depth)):
            if i == 0:
                self.convolutions.append(ConvSame(self.baseFilters, self.channelsOut, self.kernel, i+1, False))
            else:
                self.convolutions.append(ConvSame(int((2**i)*self.baseFilters), int((2**(i-1))*self.baseFilters), self.kernel, i+1, False))

    def forward(self, x):
        # ReLU after every convolution except the last one
        for i, module in enumerate(self.convolutions):
            if i < len(self.convolutions) - 1:
                x = module(x, True)
            else:
                x = module(x, False)
        return x

And since DilatedCNN references ConvSame:

import torch.nn.functional as F

class ConvSame(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, dilation, grouped):
        super(ConvSame, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.multiplier = self.out_channels / self.in_channels
        if grouped: self.groups = self.in_channels if self.multiplier > 1 else self.out_channels
        else:       self.groups = 1
        # padding = dilation*(kernel-1)//2 keeps the spatial size unchanged
        # for stride 1 and odd kernel sizes
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel, stride=1,
                              padding=dilation*(kernel-1)//2,
                              dilation=dilation, groups=self.groups)

    def forward(self, x, addRelu=True):
        if addRelu: return F.relu(self.conv(x))
        else:       return self.conv(x)
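
In case it matters: the padding in ConvSame is meant to keep the spatial size unchanged, so that the outputs of the split branches can be concatenated later. With stride 1 the output size is H + 2*padding - dilation*(kernel-1), and padding = dilation*(kernel-1)//2 cancels the second term for odd kernels. A quick sanity check with the ConvSame class above (the sizes here are arbitrary):

import torch

x = torch.randn(1, 3, 24, 24)
for d in range(1, 6):
    layer = ConvSame(3, 8, kernel=5, dilation=d, grouped=False)
    print(d, layer(x).shape)   # stays (1, 8, 24, 24) for every dilation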

I tried it with the complete code and it still works:

model = SplitDilatedCNN()
x = torch.randn(1, 4, 24, 24)
output = model(x)
output.mean().backward()
print(model.dcnn[0].convolutions[0].conv.weight.grad)

Which PyTorch version are you using?

Yes, I just tried the same and, in isolation, your example works for me, too. However, this is just a small part of my computation graph. Weirdly, the backward pass through the entire computation graph works if I replace this submodule with a conventional convolution. At the same time, the submodule itself also doesn’t seem to be the problem…

Edit: Even with DilatedCNN the backward pass through the entire graph works fine, but as soon as I switch to SplitDilatedCNN it crashes with the originally mentioned error. Are there any weird interactions between what happens inside a module and what happens before/after in the computation graph?

Edit2: Even replacing the forward pass in SplitDilatedCNN with

def forward(self, x):
    x = x.split(self.split, 1)
    x = torch.cat(x, 1)
    return x

throws the error. It’s super strange!

[PyTorch 0.4.0]

I tried it with 0.4.0 and a master build and both work in this isolated example.
Since it’s still not working in your code, could you post the code or link to a repo so that I could have a look?

I really appreciate your effort! Unfortunately, I am not allowed to share more than this snippet (wish I could). Let me know if you can think of any circumstances that would cause this behavior (I know that any suggestion can't be more than a guess without the full code)…

OK, I understand.
Does the code run successfully without the split and cat from your last post?

Yes, a simple pass-through works:

def forward(self, x):
    return x

If you just return x without any operations in SplitDilatedCNN, your code works?
This seems strange, as x does not require gradients and you didn't create any other tensors that might require gradients.
Your code should throw an error like:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Could you check it?
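
Something like this minimal sketch (with a made-up pass-through module) is what I mean:

import torch
import torch.nn as nn

class PassThrough(nn.Module):
    def forward(self, x):
        return x

x = torch.randn(1, 4, 24, 24)   # requires_grad defaults to False
out = PassThrough()(x)
out.mean().backward()           # -> RuntimeError: element 0 of tensors does not
                                #    require grad and does not have a grad_fn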

It does throw that error in the isolated example that you posted above but not in the full framework. But I think I’ve identified the issue. What I’m actually doing outside of SplitDilatedCNN resembles this:

class Dummy(nn.Module):

    def __init__(self,sdcnn):
        super(Dummy, self).__init__()
        self.sdcnn = sdcnn

    def foo(self,a):
        b = self.sdcnn(a)
        a[:,:,10,10] = b[:,:,10,10]
        return a

    def forward(self,x):
        x = self.foo(x)
        return x

x = torch.randn(1, 4, 24, 24,device='cuda:0')
f = Dummy(SplitDilatedCNN()).to(device='cuda:0')
y = f(x)
y.mean().backward()

However, now x == y. What's happening? Surely, a and b are different, so I would at least expect x[:,:,10,10] != y[:,:,10,10].
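
Edit: A quick check on the CPU (reusing the Dummy and SplitDilatedCNN classes from above) seems to explain at least the x == y part:

import torch

x = torch.randn(1, 4, 24, 24)
f = Dummy(SplitDilatedCNN())
y = f(x)

# foo returns the very tensor it was given, so y and x are the same object
print(y is x)                          # True
print(x.data_ptr() == y.data_ptr())    # True, same underlying storage

# so a[:,:,10,10] = b[:,:,10,10] writes straight into x, which is both why
# x == y afterwards and an in-place modification as far as autograd is
# concerned; cloning first (a = a.clone() inside foo) would presumably
# avoid both effects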