Why does adding just sub() and abs() before mul() increase CUDA memory consumption 10x in requires_grad mode?

I am trying to replace convolution ops with the similarity ops defined by Deep SimNets, which change the convolution

$$y = \sum_i w_i x_i$$

into

$$y = \sum_i w_i \, |x_i - t_i|$$

It simply adds an L1 distance |x - t| to the multiplication in the convolution. t has the same size as w.
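For a single flattened patch, the two operations can be compared directly. A minimal sketch with made-up numbers (the values of x, w and t are illustrative only, not from the question):

```python
import torch

# One flattened patch x with its weight w and template t (illustrative values)
x = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([0.5, -1.0, 2.0])
t = torch.tensor([0.0, 3.0, 1.0])

conv_out = (w * x).sum()              # convolution: sum_i w_i * x_i
sim_out = (w * (x - t).abs()).sum()   # similarity:  sum_i w_i * |x_i - t_i|

print(conv_out.item(), sim_out.item())  # 4.5 3.5
```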

It looks like this in code:

```python
# N, F, C, H, W, HH, WW, H_out, W_out are:
# batch_size, out_channels, in_channels, in_height, in_width, kernel_height, kernel_width, out_height, out_width
x = Im2Col.apply(x, self.kernel_size, self.dilation, self.padding, self.stride)  # x.shape = N, C*HH*WW, H_out*W_out
x = x.unsqueeze(1)                      # x.shape = N, 1, C*HH*WW, H_out*W_out
w = self.weight.view(1, F, C*HH*WW, 1)  # w.shape = 1, F, C*HH*WW, 1
t = self.tamplate.view_as(w)            # t.shape = w.shape
```

where x, w and t are the input, weight and template in the equation above.
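A self-contained sketch of the same shape bookkeeping, assuming the custom Im2Col.apply behaves like torch.nn.functional.unfold (that equivalence is an assumption; the sizes are arbitrary, and Fo stands for F to avoid clashing with the functional module):

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 8, 8    # batch, in_channels, height, width (arbitrary)
Fo, HH, WW = 4, 3, 3       # out_channels (F in the post), kernel height, kernel width
padding, stride, dilation = 1, 1, 1

x = torch.randn(N, C, H, W)
weight = torch.randn(Fo, C, HH, WW)
template = torch.randn(Fo, C, HH, WW)   # same size as weight

# unfold returns the im2col matrix with shape (N, C*HH*WW, H_out*W_out)
cols = F.unfold(x, (HH, WW), dilation=dilation, padding=padding, stride=stride)
H_out = (H + 2 * padding - dilation * (HH - 1) - 1) // stride + 1
W_out = (W + 2 * padding - dilation * (WW - 1) - 1) // stride + 1

xs = cols.unsqueeze(1)                   # (N, 1, C*HH*WW, H_out*W_out)
w = weight.view(1, Fo, C * HH * WW, 1)   # (1, F, C*HH*WW, 1)
t = template.view_as(w)                  # (1, F, C*HH*WW, 1)

print(xs.shape, w.shape, t.shape)
```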

The plain convolution looks like this:

```python
x = w.mul(x).sum(-2)
x = x.view(N, F, H_out, W_out)
```

which takes 700 MB under torch.no_grad() and 1200 MB in the training phase, similar to torch.nn.Conv2d (although torch.nn.Conv2d uses FFT while I use im2col here).

But the similarity version looks like this:

```python
x = x.sub(t).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```

which takes 1500 MB under torch.no_grad() and over 10000 MB (10 GB!) in the training phase.
Why? I only added sub(t).abs() before mul(w). Does that really produce 10x as many gradients?
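A back-of-the-envelope check of where the extra memory could come from, based on broadcasting: x has shape (N, 1, C*HH*WW, H_out*W_out) and t has shape (1, F, C*HH*WW, 1), so x.sub(t) already has shape (N, F, C*HH*WW, H_out*W_out), F times larger than x, and abs() and mul(w) then produce two more tensors of that size. The layer sizes below are hypothetical (64 to 64 channels, 3x3 kernel, 56x56 output, one float32 sample), chosen only to illustrate the arithmetic:

```python
import torch

# Hypothetical ResNet-18-like layer: 64 -> 64 channels, 3x3 kernel, 56x56 output, one sample
N, C, Fo, HH, WW, H_out, W_out = 1, 64, 64, 3, 3, 56, 56
K, L = C * HH * WW, H_out * W_out

x_shape = (N, 1, K, L)    # unfolded input after unsqueeze(1)
t_shape = (1, Fo, K, 1)   # template, same shape as the weight

sub_shape = torch.broadcast_shapes(x_shape, t_shape)   # shape of x.sub(t)
print(sub_shape)  # torch.Size([1, 64, 576, 3136])

def mib(shape, bytes_per_elem=4):
    """Size in MiB of a tensor with the given shape (float32 by default)."""
    numel = 1
    for s in shape:
        numel *= s
    return numel * bytes_per_elem / 2**20

print(f"x:        {mib(x_shape):8.1f} MiB")
print(f"x.sub(t): {mib(sub_shape):8.1f} MiB")
```

In the plain conv only the product w.mul(x) reaches that size. If this reading is right, the factor F applies to every replaced layer, and F reaches 512 in the deeper ResNet-18 stages.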

BTW, the network I use is ResNet-18, with all Conv2d layers replaced by Similarity layers.


  • As far as I know, nn.Conv2d does not use FFT; it uses im2col.
  • I would be very surprised if your handmade conv took exactly as much memory as the Conv2d module; you might want to check your test code. Your code creates several (~10) intermediate results that all need to be stored, while the Conv2d module optimizes these intermediate results away and doesn't save them.
  • Given the number of extra intermediate results you add and the size of the net you're considering, I wouldn't be surprised if it blows up the memory. In grad mode you are increasing memory consumption by roughly 10x (ignoring the fc layers at the end).
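The point about stored intermediate results can be checked with torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch versions), which calls a pack hook for every tensor autograd saves for the backward pass. A small sketch that totals the bytes saved by F.conv2d versus an unfold-based conv like the one in the question (F.unfold stands in for the custom Im2Col, which is an assumption; sizes are arbitrary):

```python
import torch
import torch.nn.functional as F
from torch.autograd.graph import saved_tensors_hooks

def saved_bytes(fn, *args):
    """Run fn(*args) with grad enabled and total the bytes autograd saves for backward."""
    total = 0
    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t
    with saved_tensors_hooks(pack, lambda t: t):
        fn(*args)
    return total

def builtin_conv(x, weight):
    return F.conv2d(x, weight, padding=1)

def unfold_conv(x, weight):
    Fo = weight.shape[0]
    cols = F.unfold(x, tuple(weight.shape[2:]), padding=1).unsqueeze(1)  # (N, 1, K, L)
    w = weight.view(1, Fo, -1, 1)                                        # (1, F, K, 1)
    return w.mul(cols).sum(-2)                                           # (N, F, L)

x = torch.randn(2, 16, 8, 8, requires_grad=True)
weight = torch.randn(32, 16, 3, 3, requires_grad=True)

conv_saved = saved_bytes(builtin_conv, x, weight)
unfold_saved = saved_bytes(unfold_conv, x, weight)
print("conv2d saves:", conv_saved, "bytes")
print("unfold conv saves:", unfold_saved, "bytes")
```

The unfold version keeps the whole im2col matrix alive for backward, while the built-in op only needs its input and weight.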

Thank you for your reply.
But my handwritten conv op does not blow up the memory:

```python
x = w.mul(x).sum(-2)
x = x.view(N, F, H_out, W_out)
```

It works just fine. It uses more memory than nn.Conv2d, yes, but not that much more: testing again, it's 940 MB vs. 1500 MB for nn.Conv2d and my conv, respectively.
The results differ by about 1e-4: (sim(x) - conv(x)).sum() < 1e-4 * decay_rate for x = torch.rand(2, 64, 300, 300) * decay_rate.
So my conv does differ slightly from nn.Conv2d (I don't know why), but the results are close.
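Small discrepancies like this are typical when the same sum is accumulated in a different order in float32. A sketch comparing an unfold-based conv with F.conv2d in float32 and float64 (F.unfold again stands in for Im2Col as an assumption; stride 1, and the tolerances in the check are chosen loosely):

```python
import torch
import torch.nn.functional as F

def unfold_conv(x, weight, padding=1):
    """im2col convolution, stride 1: (N, C, H, W) -> (N, F, H_out, W_out)."""
    N, _, H, W = x.shape
    Fo, _, HH, WW = weight.shape
    cols = F.unfold(x, (HH, WW), padding=padding).unsqueeze(1)  # (N, 1, K, L)
    w = weight.view(1, Fo, -1, 1)                              # (1, F, K, 1)
    H_out = H + 2 * padding - HH + 1
    W_out = W + 2 * padding - WW + 1
    return w.mul(cols).sum(-2).view(N, Fo, H_out, W_out)

torch.manual_seed(0)
x = torch.rand(2, 16, 12, 12)
weight = torch.randn(8, 16, 3, 3)

max_diff = {}
for dtype in (torch.float32, torch.float64):
    ref = F.conv2d(x.to(dtype), weight.to(dtype), padding=1)
    out = unfold_conv(x.to(dtype), weight.to(dtype))
    max_diff[dtype] = (out - ref).abs().max().item()
    print(dtype, max_diff[dtype])
```

If the gap shrinks by many orders of magnitude in float64, that points to rounding rather than a bug in the handwritten conv.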

And I only added two operators to the code above QAQ!!

```python
x = x.sub(t).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```


Aha, I just changed sub(t) to sub(0) there:

```python
x = x.sub(0).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```

It works just fine, with memory similar to my conv.
So it must be sub(t) that causes this problem.
BTW, t is a parameter with the same size as w.
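This experiment is consistent with broadcasting: x.sub(0) keeps the shape (N, 1, K, L) of x, whereas x.sub(t) expands to (N, F, K, L), so abs() and mul(w) then save full-size tensors for backward. A sketch using torch.autograd.graph.saved_tensors_hooks to total the bytes autograd keeps for both variants (the shapes are arbitrary and assume the unfolded layout from the question):

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def saved_bytes(fn):
    """Total bytes autograd saves for backward while running fn(), plus the output shape."""
    total = 0
    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t
    with saved_tensors_hooks(pack, lambda t: t):
        out = fn()
    return total, out.shape

N, Fo, K, L = 2, 64, 27, 100
x = torch.randn(N, 1, K, L, requires_grad=True)   # unfolded input
w = torch.randn(1, Fo, K, 1, requires_grad=True)  # weight
t = torch.randn(1, Fo, K, 1, requires_grad=True)  # template, same size as w

bytes_0, shape_0 = saved_bytes(lambda: x.sub(0).abs().mul(w).sum(-2))
bytes_t, shape_t = saved_bytes(lambda: x.sub(t).abs().mul(w).sum(-2))
print(bytes_0, bytes_t, bytes_t / bytes_0)
```

Both variants return the same output shape, but with sub(t) the saved tensors are about F times larger. Under torch.no_grad() nothing is saved, yet two full-size temporaries can still be alive at once, which plausibly accounts for the roughly 2x increase (700 MB to 1500 MB) reported above.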