How to reduce GPU memory usage when using `torch.cat`

In my forward function, I have an operation that multiplies a matrix by a vector element-wise along the channel dimension. The vector's channel dim is 1, so in forward I use torch.cat to inflate it to match the matrix's channel dim of 64. But I found this approach costs a lot of GPU memory during training. How can I change the code to reduce memory usage while still supporting DataParallel?

import torch
from torch.autograd import Variable

batch_size = 2
x = Variable(torch.randn(batch_size, 2, 14).cuda())      # vector
y = Variable(torch.randn(batch_size, 64, 2, 14).cuda())  # matrix, channel dim=64

# this copies x 64 times, so it uses lots of GPU memory
x = torch.cat([x.view(batch_size, 1, 2, 14) for i in range(64)], dim=1)

y = y.mul(x)

Thanks.

Use `expand` instead — it returns a view of the tensor without allocating any new memory. Since `expand` only works on singleton dimensions, add the size-1 channel dim first: `x.view(batch_size, 1, 2, 14).expand(batch_size, 64, 2, 14)`. See http://pytorch.org/docs/master/tensors.html#torch.Tensor.expand
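
For reference, here is a minimal sketch of the question's snippet rewritten with `expand`, assuming the same shapes as above:

import torch
from torch.autograd import Variable

batch_size = 2
x = Variable(torch.randn(batch_size, 2, 14).cuda())      # vector
y = Variable(torch.randn(batch_size, 64, 2, 14).cuda())  # matrix, channel dim=64

# expand returns a view over the same storage: no data is copied,
# so no extra GPU memory is allocated for the 64 channels
x = x.view(batch_size, 1, 2, 14).expand(batch_size, 64, 2, 14)

y = y.mul(x)

This should also work under DataParallel, since the `view`/`expand` happens inside forward on each replica's own slice of the batch.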
