Alright, so I've dug further into this and found some interesting things. TL;DR: I solved it for my use-case, but I might have stumbled onto some bugs. Don't use in-place operations; bad things happen.
I managed to achieve what I was going for by creating the main weight parameter, w, and two intermediate variables:
W, a zeros-filled tensor of the desired final shape, and
M, a mask of the same shape as W, filled with 1s at the elements of W that I want to fill with w. For consistency, the number of 1s in M has to equal the number of elements of w, of course.
I initialize both W and M in the init method of the module, then call W[M] = w during the forward() method, and convolve using the modified W. This works, but for some reason training starts out fast and then progressively slows down (from around 5 batches/s to 0.2 batches/s over the course of the first 1000 batches). It also throws an error about non-leaf variables not yet being serializable when I try to use torch.save, presumably because I'm creating some naughty nodes that I shouldn't be.
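For concreteness, that first version looked roughly like this (a simplified reconstruction with illustrative names; the cuda() calls and module boilerplate are elided):
in init:
k = 3 + 2 * (dilation - 1)                              # spatial size of the dilated kernel
self.M = torch.zeros(n_out, n_in, k, k).byte()
self.M[:, :, ::dilation, ::dilation] = 1                # 1s at the 3x3 taps, spaced `dilation` apart
self.w = Parameter(torch.randn(n_out * n_in * 3 * 3))   # one real weight per 1 in M
self.W = Variable(torch.zeros(n_out, n_in, k, k))
in forward:
self.W[self.M] = self.w                                 # the in-place fill that causes the trouble below
out = F.conv2d(input, weight=self.W, padding=dilation, bias=None)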
My initial suspicion was that I was creating additional subgraphs that weren't being deleted, or not freeing memory appropriately (more on that in a moment), but the memory usage in this case was constant. Investigating further, I found that if I replace W[M] = w with W.masked_copy_(M, w), I get an error after the first batch saying that I need to use "retain_variables=True" if I want to backpropagate a second time through the graph. The error message is a bit confusing here, as I am only calling backward() once.
My intuition is that the above error occurs because I'm calling an in-place method in forward(), which seems to be against best PyTorch practice at the moment, so whatever variables autograd needs to do the backprop aren't getting saved. Calling backward(retain_variables=True) results in the same behavior as using W[M] = w: it works, but it progressively slows down throughout training. I'm still not sure what's causing the slowdown; my best guess is that some part of the graph isn't getting freed appropriately, in such a way that rather than creating multiple subgraphs that take up memory, it's backpropagating through the same graph elements an increasing number of times on each successive iteration.
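For reference, the workaround the error message suggests is just:
loss.backward(retain_variables=True)  # suppresses the error, but training still slows down over time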
I ran into another interesting issue while messing with this: while running Brandon Amos's DenseNet model inside my own training code, if I swapped out standard filters for dilated filters using the dilation keyword in conv2d, I would see a memory explosion that would overflow my GPU within ~50 batches. It turned out this was because I was using saved_loss += loss rather than saved_loss += loss.data, so it was creating multiple copies of subgraphs and not freeing them appropriately. The user error isn't interesting, but the fact that I do not observe this memory explosion when using undilated filters, despite the bad += loss line, is interesting.
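In other words, the culprit was just accumulating the loss Variable instead of its underlying tensor:
# wrong: keeps every batch's graph alive, so memory grows
saved_loss += loss
# right: just accumulates the value and lets each batch's graph get freed
saved_loss += loss.data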
Anyhow, I did manage to get things working for my use-case: rather than trying to do a masked copy or in-place operation, I just instantiate W as a full, dense tensor of the dilated kernel shape and drop W*M into the F.conv2d call. Works great and is about twice as fast as using the dilation parameter (presumably because it allows for the use of the cuDNN backend).
Here’s a code snippet with which I’m currently getting ~80-100% speedup over using the dilation keyword for my use-case. Note that this currently prevents you from saving with torch.save due to a “can’t serialize non-leaf variables yet” dealio.
in init:
# Build the binary mask: 1s at the taps of a 3x3 kernel dilated by `dilation`, 0s elsewhere.
# Shape is (n_out, n_in, k, k) with k = 3 + 2*(dilation-1); the slicing trick assumes dilation > 1.
row = (([1] + [0] * (dilation - 1)) * 3)[:-(dilation - 1)]
plane = (([row] + [[0] * (3 + 2 * (dilation - 1))] * (dilation - 1)) * 3)[:-(dilation - 1)]
self.m = Variable(torch.cuda.FloatTensor([[plane] * n_in] * n_out))
# (I suspect the .cuda() after Parameter() is what makes W a non-leaf Variable and trips torch.save.)
self.W = Parameter(torch.zeros(n_out, n_in, 3 + 2 * (dilation - 1), 3 + 2 * (dilation - 1)), requires_grad=True).cuda()  # requires_grad not necessarily necessary
in forward:
out = F.conv2d(input, weight=self.W * self.m, padding=dilation, bias=None)
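Putting it together, the whole layer looks something like the sketch below. The class name and argument names are just illustrative, I've built the same mask with index slicing instead of the list-comprehension one-liner, and I've moved the .cuda() inside the Parameter() call so that W stays a leaf (which, as far as I can tell, also avoids the torch.save complaint):

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Module, Parameter

class MaskedDilatedConv2d(Module):
    def __init__(self, n_in, n_out, dilation):
        super(MaskedDilatedConv2d, self).__init__()
        self.dilation = dilation
        k = 3 + 2 * (dilation - 1)  # spatial size of the dilated kernel
        # Dense weight of the dilated shape (you'd probably want a proper random init here).
        self.W = Parameter(torch.zeros(n_out, n_in, k, k).cuda())
        # Mask with 1s at the 3x3 taps, spaced `dilation` apart; same pattern as the snippet above.
        m = torch.zeros(n_out, n_in, k, k)
        m[:, :, ::dilation, ::dilation] = 1
        self.m = Variable(m.cuda())

    def forward(self, input):
        # Ordinary dense convolution with the masked weight, so cuDNN can be used.
        return F.conv2d(input, weight=self.W * self.m, padding=self.dilation, bias=None)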
Sorry for the wall of text, but hopefully this will prove enlightening and thorough if other people come along with similar issues with in-place ops. For the record, I'm using the build provided by conda install and am on Python 2.7 (my attempts at building from source crash, sigh). I tested these with CUDA 7.5 on a GTX 980 and CUDA 8.0 on a Titan X.
Best,
Andy