Converting a Variable to a Parameter

(Quilby) #1

I’d like to build a CNN whose conv filters depend on the input (similar to dynamic filter networks).

I modified the CIFAR-10 example, but I noticed that if I try to access a conv layer’s weight and copy a Variable into it, I can’t (line 52).
What is the correct way to do this?

(Also notice that I’ve changed the batch size to 1. Is there a way to do this with bigger batches?)

Thanks a lot.

(Adam Paszke) #2

I think it would be much simpler with the functional interface. Just call F.conv2d(input, weight) where weight is generated by some other part of the network. This should work with arbitrary batches, just be careful about providing correct dimensions of the weights.
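As a rough illustration of the functional approach, here is a minimal sketch in which a small sub-network produces the filter bank that is then passed to F.conv2d. The `FilterGenerator` module, its layer sizes, and the 32x32 input are hypothetical choices made for this example, not anything from the original code. It assumes a recent PyTorch, where tensors track gradients directly and no Variable wrapper is needed, and it handles a single image at a time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FilterGenerator(nn.Module):
    """Hypothetical module that maps one image to a conv filter bank."""

    def __init__(self, in_channels=3, out_channels=6, kernel_size=5):
        super().__init__()
        # Shape expected by F.conv2d: (out_channels, in_channels, kH, kW)
        self.weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(
            in_channels, out_channels * in_channels * kernel_size * kernel_size
        )

    def forward(self, image):
        # image: (1, C, H, W) -> per-channel mean of shape (1, C)
        summary = self.pool(image).flatten(1)
        # Reshape the flat prediction into a filter bank
        return self.fc(summary).view(self.weight_shape)


torch.manual_seed(0)
generator = FilterGenerator()
image = torch.randn(1, 3, 32, 32)

weight = generator(image)          # filters depend on the input image
out = F.conv2d(image, weight)      # no padding: 32 - 5 + 1 = 28
out.sum().backward()               # gradients flow back into the generator
```

Because the weight is an ordinary tensor in the autograd graph, nothing has to be copied into a layer's parameter, and the generator is trained through the convolution.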

(Quilby) #3

Thanks a lot, apaszke. I was not aware of these “functionals”.

Is this the correct way to do this convolution in a batched manner? (It runs very slowly…)

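    y = self.pool(F.relu(self.conv1(y)))
    # Output buffer: one 16-channel 10x10 map per sample
    z = Variable(torch.Tensor(x.size(0), 16, 10, 10))

    # Convolve each sample with its own filters, one sample at a time
    for i in range(x.size(0)):
        z[i] = F.conv2d(y[i].unsqueeze(0), x[i]).squeeze(0)

    z = self.pool(F.relu(z))
    z = z.view(-1, 16 * 5 * 5)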

(x contains the convolutional weights, y contains the image I want to convolve over, and z is where I put the result)


(Adam Paszke) #4

Why can’t you compute the convolution with all filters in one go?

(Francisco Massa) #5

I have the impression that he wants to have a different convolutional weight per batch element. @Quilby is that right?

(Quilby) #6

Exactly, that is the point of the dynamic filter network. The conv filters for a certain picture are a function of that picture.

(Adam Paszke) #7

You could still separate the convolutions using groups. It’s not going to be super fast, but you could give that a try.
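One way to read the groups suggestion, sketched under assumptions: fold the batch dimension into the channel dimension and set groups to the batch size, so that each sample's channels are convolved only with that sample's filters. The helper name `per_sample_conv2d` is made up for this example, and the tensor shapes follow the sizes quoted in this thread (y holds the images, x the per-sample filters). It assumes a recent PyTorch without the Variable wrapper.

```python
import torch
import torch.nn.functional as F


def per_sample_conv2d(inputs, weights):
    """Apply a different filter bank to each batch element.

    inputs:  (B, C_in, H, W)
    weights: (B, C_out, C_in, kH, kW)
    returns: (B, C_out, H_out, W_out)
    """
    B, C_in, H, W = inputs.shape
    _, C_out, _, kH, kW = weights.shape
    # Fold the batch into channels: one "image" with B * C_in channels
    stacked_inputs = inputs.reshape(1, B * C_in, H, W)
    # Stack all filter banks along the output-channel axis
    stacked_weights = weights.reshape(B * C_out, C_in, kH, kW)
    # groups=B keeps each sample's channels paired with its own filters
    out = F.conv2d(stacked_inputs, stacked_weights, groups=B)
    return out.reshape(B, C_out, out.shape[-2], out.shape[-1])


torch.manual_seed(0)
y = torch.randn(4, 6, 14, 14)      # images / feature maps
x = torch.randn(4, 16, 6, 5, 5)    # per-sample filters

z = per_sample_conv2d(y, x)

# Reference: the per-sample loop, for comparison
z_loop = torch.stack(
    [F.conv2d(y[i:i + 1], x[i]).squeeze(0) for i in range(y.shape[0])]
)
```

The grouped call should agree with the loop up to floating-point error while issuing a single convolution. Whether it is actually faster depends on the backend and the sizes involved.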

(Quilby) #8

Inspired by this answer on SO, I tried converting my code to use the conv3d functional but it gives me a weird error.

This is the code:

    z = F.conv3d(y.unsqueeze(0), x)

The sizes are:

x: torch.Size([4, 16, 6, 5, 5])
y: torch.Size([4, 6, 14, 14])

The error I get is


(Adam Paszke) #9

Your input has an invalid size. For these weights, laid out as (out_channels, in_channels, kT, kH, kW), the input should be 1x16x6x14x14. You probably swapped the in and out channel dimensions of the weights.
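The shape mismatch can be reproduced with random tensors of the sizes quoted above. This is a minimal sketch assuming a recent PyTorch, where a channel mismatch surfaces as a RuntimeError; the variable names `valid_input` and `mismatch_raised` are introduced only for this example.

```python
import torch
import torch.nn.functional as F

# Weights as posted, read as (out_channels=4, in_channels=16, kT=6, kH=5, kW=5)
x = torch.randn(4, 16, 6, 5, 5)
y = torch.randn(4, 6, 14, 14)

# y.unsqueeze(0) has shape (1, 4, 6, 14, 14): 4 input channels, but x expects 16
try:
    F.conv3d(y.unsqueeze(0), x)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# An input that matches these weights: (N=1, C_in=16, D=6, H=14, W=14)
valid_input = torch.randn(1, 16, 6, 14, 14)
out = F.conv3d(valid_input, x)
out_shape = tuple(out.shape)  # (1, 4, 1, 10, 10): D = 6-6+1, H = W = 14-5+1
```

Note that this only demonstrates which shapes conv3d accepts for these weights. It does not by itself give a different filter bank per batch element.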