# Why does adding just sub() and abs() before mul() increase CUDA memory consumption 10x in requires_grad mode?

(Neo Li) #1

I am trying to replace convolution ops with the similarity ops defined by Deep SimNet,
which change
$y[j]=\sum_i x[i]\,w[j-i]$
to
$y[j]=\sum_i \left|x[i]-t[j-i]\right|_{\ell_1} w[j-i]$

It just adds an l1 norm alongside the multiplication in the convolution. The size of t is the same as that of w.
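
To make the two formulas concrete, here is a minimal 1D sketch in plain Python. It assumes "full" convolution indexing, where terms are simply skipped whenever j - i falls outside the kernel; the names `conv1d` and `similarity1d` are hypothetical and not from the post. It also illustrates that with t = 0 and non-negative x, the similarity reduces to the plain convolution.

```python
def conv1d(x, w):
    """y[j] = sum_i x[i] * w[j - i], skipping terms where j - i is outside w."""
    n, k = len(x), len(w)
    return [
        sum(x[i] * w[j - i] for i in range(n) if 0 <= j - i < k)
        for j in range(n + k - 1)
    ]


def similarity1d(x, w, t):
    """y[j] = sum_i |x[i] - t[j - i]| * w[j - i]; t must have the same size as w."""
    if len(t) != len(w):
        raise ValueError("t must have the same size as w")
    n, k = len(x), len(w)
    return [
        sum(abs(x[i] - t[j - i]) * w[j - i] for i in range(n) if 0 <= j - i < k)
        for j in range(n + k - 1)
    ]


print(conv1d([1, 2, 3], [1, 1]))                # [1, 3, 5, 3]
print(similarity1d([1, 2, 3], [1, 1], [1, 1]))  # [0, 1, 3, 2]
```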

It looks like this in code:

```python
# N, F, C, H, W, HH, WW, H_out, W_out are:
# batch_size, out_channels, in_channels, in_height, in_width, kernel_height, kernel_width, out_height, out_width
x = Im2Col.apply(x, self.kernel_size, self.dilation, self.padding, self.stride)
x = x.unsqueeze(1)                      # x.shape = N, 1, C*HH*WW, H_out*W_out
w = self.weight.view(1, F, C*HH*WW, 1)  # w.shape = 1, F, C*HH*WW, 1
t = self.tamplate.view_as(w)            # t.shape = w.shape
```


where x, w and t are the input, weight and template from the equations above.

The plain convolution looks like this:

```python
x = w.mul(x).sum(-2)
x = x.view(N, F, H_out, W_out)
```


which takes 700 MB under torch.no_grad() and 1200 MB in the training phase, similar to torch.nn.Conv2d (although torch.nn.Conv2d uses FFT, while I use im2col here).
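
For a self-contained check of this im2col formulation, here is a minimal sketch. It assumes `torch.nn.functional.unfold` as a stand-in for the post's custom `Im2Col.apply` (unfold produces the same N, C*HH*WW, H_out*W_out column layout); the function name `im2col_conv2d` is hypothetical.

```python
import torch
import torch.nn.functional as F


def im2col_conv2d(x, weight, stride=1, padding=0, dilation=1):
    """Convolution as im2col + broadcast multiply + sum, mirroring the post's shapes."""
    N, C, H, W = x.shape
    Fo, _, HH, WW = weight.shape  # Fo = number of output channels (F in the post)
    H_out = (H + 2 * padding - dilation * (HH - 1) - 1) // stride + 1
    W_out = (W + 2 * padding - dilation * (WW - 1) - 1) // stride + 1
    # Assumption: F.unfold stands in for the post's custom Im2Col.apply.
    cols = F.unfold(x, (HH, WW), dilation=dilation, padding=padding, stride=stride)
    cols = cols.unsqueeze(1)                # N, 1, C*HH*WW, H_out*W_out
    w = weight.view(1, Fo, C * HH * WW, 1)  # 1, F, C*HH*WW, 1
    out = w.mul(cols).sum(-2)               # product broadcasts to N, F, K, L; sum over K
    return out.view(N, Fo, H_out, W_out)


x = torch.randn(2, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
diff = (im2col_conv2d(x, weight, padding=1) - F.conv2d(x, weight, padding=1)).abs().max()
print(f"max abs difference vs F.conv2d: {diff.item():.2e}")
```

Measuring the maximum absolute difference avoids the cancellation that a signed `.sum()` of differences can hide. Small nonzero values are expected, because the summation order differs from the library kernel's.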

But the similarity op looks like this:

```python
x = x.sub(t).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```


which takes 1500 MB under torch.no_grad() and over 10000 MB (10 GB!) in the training phase.
Why? I just added sub(t).abs() before mul(w). Does that really make 10x as many gradients?

BTW, the net I use is ResNet-18, and all of its Conv2d layers are replaced by Similarity layers.
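
A minimal sketch of such a Similarity layer as a module that can stand in for `nn.Conv2d`. The class name `Similarity2d`, the parameter name `template` (the post's attribute is `self.tamplate`), the use of `F.unfold` instead of the post's `Im2Col.apply`, and the initialization are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Similarity2d(nn.Module):
    """Hypothetical nn.Conv2d replacement computing sum |x - t| * w over each patch."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        self.kernel_size = (kernel_size, kernel_size)
        self.stride, self.padding, self.dilation = stride, padding, dilation
        shape = (out_channels, in_channels, kernel_size, kernel_size)
        # Arbitrary initialization (an assumption, not from the post).
        self.weight = nn.Parameter(torch.randn(shape) * 0.1)
        # Template t: a learnable parameter with the same size as the weight.
        self.template = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        N, C, H, W = x.shape
        Fo = self.weight.shape[0]
        HH, WW = self.kernel_size
        H_out = (H + 2 * self.padding - self.dilation * (HH - 1) - 1) // self.stride + 1
        W_out = (W + 2 * self.padding - self.dilation * (WW - 1) - 1) // self.stride + 1
        cols = F.unfold(x, self.kernel_size, dilation=self.dilation,
                        padding=self.padding, stride=self.stride).unsqueeze(1)  # N, 1, K, L
        w = self.weight.view(1, Fo, C * HH * WW, 1)  # 1, F, K, 1
        t = self.template.view_as(w)                 # 1, F, K, 1
        # sub(t) broadcasts to N, F, K, L before abs() and mul(w) run on it.
        out = cols.sub(t).abs().mul(w).sum(-2)       # N, F, L
        return out.view(N, Fo, H_out, W_out)


layer = Similarity2d(3, 4, 3, padding=1)
print(layer(torch.rand(2, 3, 8, 8)).shape)  # torch.Size([2, 4, 8, 8])
```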

(Alban D) #2

Hi,

• nn.Conv2d does not use FFT as far as I know; it uses im2col.
• I would be very surprised if your handmade conv took exactly as much memory as the Conv2d module; you might want to check your test code. Your code creates several (~10) intermediate results that all need to be stored, while the Conv2d module optimizes away all these intermediate results and doesn't save them.
• Given the number of extra intermediate results you add, and the size of the net you're considering, I wouldn't be surprised if it blows up the memory. In grad mode you are getting roughly a 10x increase in memory consumption (ignoring the fc layers at the end).
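
To see which intermediate results are kept, the sketch below counts the bytes autograd saves for backward, using `torch.autograd.graph.saved_tensors_hooks` (available in PyTorch 1.10 and later). The sizes and helper names are hypothetical; the x, w, t shapes follow the post. In the plain conv, mul saves only its inputs x (N, 1, K, L) and w. In the similarity op, abs() saves the broadcast N, F, K, L difference and mul(w) saves the N, F, K, L output of abs(), so in this sketch the saved memory grows by roughly a factor of 2F.

```python
import torch


def saved_bytes(fn, *args):
    """Sum the sizes of all tensors autograd saves for backward while fn(*args) runs."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor unchanged; we only measure it

    def unpack(t):
        return t

    # Requires PyTorch >= 1.10 for torch.autograd.graph.saved_tensors_hooks.
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn(*args)
    return total


def conv_op(x, w, t):
    return w.mul(x).sum(-2)


def sim_op(x, w, t):
    return x.sub(t).abs().mul(w).sum(-2)


# Hypothetical small sizes: N=2, F=8, K=C*HH*WW=27, L=H_out*W_out=100.
N, Fo, K, L = 2, 8, 27, 100
x = torch.randn(N, 1, K, L, requires_grad=True)  # stands in for an activation from a previous layer
w = torch.randn(1, Fo, K, 1, requires_grad=True)
t = torch.randn(1, Fo, K, 1, requires_grad=True)

conv_b = saved_bytes(conv_op, x, w, t)
sim_b = saved_bytes(sim_op, x, w, t)
print(f"conv saves {conv_b} bytes, similarity saves {sim_b} bytes ({sim_b / conv_b:.1f}x)")
```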

(Neo Li) #3

But my handwritten conv op does not blow up the memory:

```python
x = w.mul(x).sum(-2)
x = x.view(N, F, H_out, W_out)
```


It works just fine. It uses more memory than nn.Conv2d, yes, but not that much more.
Testing again, it's 940 MB vs. 1500 MB for nn.Conv2d and my conv, respectively.
And the results differ by about 1e-4: (sim(x)-conv(x)).sum() < 1e-4 * decay_rate if x = torch.rand(2,64,300,300) * decay_rate.
Yes, my conv differs slightly from nn.Conv2d (I don't know why), but the results are similar.

And I only added two operators to the code above QAQ!!

```python
x = x.sub(t).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```


Aha, I just changed sub(t) to sub(0) there:

```python
x = x.sub(0).abs().mul(w).sum(-2)
x = x.view(N, F, H_out, W_out)
```


and it works just fine, with memory similar to my conv.
So it must be sub(t) that causes this problem.
BTW, t is a parameter with the same size as w.
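
The sub(0) experiment is consistent with broadcasting: subtracting a scalar keeps x at N, 1, C*HH*WW, H_out*W_out, while subtracting t (shape 1, F, C*HH*WW, 1) broadcasts the result over the F output channels, so everything downstream of sub(t), including what abs() keeps for backward, is F times larger. A minimal shape check, with hypothetical small sizes:

```python
import torch

# Hypothetical sizes: N=2, F=8, K=C*HH*WW=27, L=H_out*W_out=100.
N, Fo, K, L = 2, 8, 27, 100
x = torch.randn(N, 1, K, L)   # im2col columns after unsqueeze(1)
t = torch.randn(1, Fo, K, 1)  # template reshaped like the weight

no_broadcast = x.sub(0)       # scalar subtraction keeps x's shape
broadcast = x.sub(t)          # tensor subtraction broadcasts over the F dimension

print(tuple(no_broadcast.shape))                  # (2, 1, 27, 100)
print(tuple(broadcast.shape))                     # (2, 8, 27, 100)
print(broadcast.numel() // no_broadcast.numel())  # 8, i.e. F times more elements
```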