[SOLVED] PyTorch conv2d consumes more GPU memory than TensorFlow

Recently I did some experiments with conv2d. Here is the PyTorch code:

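import torch

# 1 input channel, 256 output channels, 9x9 kernel; the module must be on the GPU too
conv = torch.nn.Conv2d(1, 256, 9).cuda()
xs = torch.rand((400, 1, 11, 36))

print('parameters:', sum(param.numel() for param in conv.parameters()))

xs = xs.cuda()
xs.requires_grad = True

ys = conv(xs)

mb = 1024 * 1024
print('forward, gpu memory:', torch.cuda.memory_allocated()/mb)

# backpropagate from sum(ys); like tf.gradients(ys, xs), this uses an all-ones upstream gradient
ys.sum().backward()

print('backward, gpu memory:', torch.cuda.memory_allocated()/mb)
# peak tensor allocation since the start of the program
print('backward, gpu memory:', torch.cuda.max_memory_allocated()/mb)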


The relevant output is:

parameters: 20992
backward, gpu memory: 3412.8681640625

However, the actual GPU memory consumption is 6992MiB.

I ran the code on a Tesla K40m with 11441MiB of memory.

Here is the equivalent TensorFlow code:

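import tensorflow as tf  # TensorFlow 1.x API (ConfigProto, InteractiveSession)

config = tf.ConfigProto()
# allocate GPU memory on demand instead of reserving almost all of it up front
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)
run_metadata = tf.RunMetadata()

# NHWC layout: (batch, height, width, channels)
xs = tf.random_uniform((400, 11, 36, 1))
# HWIO layout: (kernel height, kernel width, in channels, out channels)
kernel = tf.random_uniform((9, 9, 1, 256))
bias = tf.random_uniform((256,))

ys = tf.nn.conv2d(xs, kernel, [1, 1, 1, 1], padding='VALID') + bias

# gradient of sum(ys) with respect to the input
grads = tf.gradients(ys, xs)

sess.run(grads,
         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE,
                               output_partition_graphs=True),
         run_metadata=run_metadata)

print('parameters:', sess.run(tf.size(kernel) + tf.size(bias)))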


The GPU memory consumption is 4424MiB / 11441MiB, which is less than PyTorch's.

By the way, both models have the same number of parameters: 20992.
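The 20992 figure can be checked by hand: a conv2d with C_in input channels, C_out output channels and a k x k kernel has C_out * C_in * k^2 weights plus C_out biases. The sketch below (the helper names are hypothetical, not part of either library) also computes the size of the float32 output tensor for stride 1 and no padding ('VALID'), to show how small the tensors themselves are next to the reported peak.

```python
def conv2d_param_count(in_channels: int, out_channels: int, kernel_size: int) -> int:
    # weights: out_channels * in_channels * k * k, plus one bias per output channel
    return out_channels * in_channels * kernel_size ** 2 + out_channels


def valid_conv_output_mib(batch: int, out_channels: int, height: int, width: int,
                          kernel_size: int, bytes_per_elem: int = 4) -> float:
    # stride 1, no padding ('VALID'): each spatial dim shrinks by kernel_size - 1
    out_h = height - kernel_size + 1
    out_w = width - kernel_size + 1
    return batch * out_channels * out_h * out_w * bytes_per_elem / (1024 * 1024)


print(conv2d_param_count(1, 256, 9))               # 20992
print(valid_conv_output_mib(400, 256, 11, 36, 9))  # 32.8125
```

The output tensor (400 x 256 x 3 x 28 floats) is only about 33 MiB, so most of the roughly 3.4 GB peak presumably comes from something other than the activations themselves, which is consistent with the cuDNN explanation in the answer.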

So how can this be explained?


I guess you are using cuDNN here?
By default we use the fastest possible cuDNN algorithm, which can consume more memory.
Setting torch.backends.cudnn.deterministic=True will force it to use the default deterministic algorithm, which should consume very little memory.
Note that for best speed you should use torch.backends.cudnn.benchmark=True, which uses cuDNN's built-in benchmarking tool to choose the best algorithm for your task and GPU.
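A minimal sketch for comparing peak memory under these settings, assuming a recent PyTorch that provides torch.cuda.reset_peak_memory_stats. The function name peak_memory_mib is hypothetical; it reruns the forward/backward pass from the question and returns None when no CUDA device is available. The numbers depend on the GPU, the cuDNN version and the PyTorch release, so they may not reproduce the gap reported in this thread.

```python
from typing import Optional

import torch


def peak_memory_mib(deterministic: bool, benchmark: bool) -> Optional[float]:
    """Peak tensor memory (MiB) of one conv2d forward/backward pass, or None without CUDA."""
    # these flags are global: they affect every later cuDNN convolution in the process
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark
    if not torch.cuda.is_available():
        return None

    conv = torch.nn.Conv2d(1, 256, 9).cuda()
    xs = torch.rand((400, 1, 11, 36), device='cuda', requires_grad=True)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # peak restarts from the current allocation
    ys = conv(xs)
    ys.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 * 1024)


for det, bench in [(False, False), (True, False), (False, True)]:
    print(f'deterministic={det}, benchmark={bench}:', peak_memory_mib(det, bench))
```

Because the flags are process-wide, setting them once at the top of a training script is the usual pattern rather than toggling them per layer.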


It’s amazing! Thank you very much!

By the way, where can I find these settings on the PyTorch documentation website?

They may actually be missing from the docs at the moment 😃
It is mentioned in the new docs here, but that's it.


Thank you very much!