How to reduce GPU memory when using `torch.cat`

In my forward function, I have an operation that multiplies a matrix by a vector element-wise along the channel dimension. The vector's channel dim is 1, so in the forward pass I use `torch.cat` to inflate it to match the matrix's channel dim of 64. But I found this approach costs a lot of GPU memory during training. How can I change the code to reduce memory usage while still supporting `DataParallel`?

import torch
from torch.autograd import Variable

batch_size = 2
x = Variable(torch.randn(batch_size, 2, 14).cuda())      # vector
y = Variable(torch.randn(batch_size, 64, 2, 14).cuda())  # matrix, channel dim = 64

# this will use lots of GPU memory
x = torch.cat([x.view(batch_size, 1, 2, 14) for i in range(64)], dim=1)

y = y.mul(x)


Use `expand` instead of `torch.cat`; it returns a view of the original tensor without allocating new memory. Note that you need to insert the singleton channel dimension first, since `expand` can only enlarge dimensions of size 1:

x = x.view(batch_size, 1, 2, 14).expand(batch_size, 64, 2, 14)
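Here is a minimal CPU sketch comparing the two approaches (dropping `Variable` and `.cuda()` for portability; variable names are illustrative). The `cat` version materializes 64 copies of the vector, while the `expand` version only creates a view. In fact, since PyTorch multiplication broadcasts, even the explicit `expand` can be skipped:

```python
import torch

batch_size = 2
x = torch.randn(batch_size, 2, 14)       # vector, implicit channel dim = 1
y = torch.randn(batch_size, 64, 2, 14)   # matrix, channel dim = 64

# Memory-heavy: allocates a full (batch_size, 64, 2, 14) copy of x.
x_cat = torch.cat([x.view(batch_size, 1, 2, 14) for _ in range(64)], dim=1)
out_cat = y.mul(x_cat)

# Memory-light: expand returns a view, no new storage is allocated.
x_exp = x.view(batch_size, 1, 2, 14).expand(batch_size, 64, 2, 14)
out_exp = y.mul(x_exp)

# Lightest: rely on broadcasting; the size-1 channel dim stretches to 64.
out_bcast = y.mul(x.view(batch_size, 1, 2, 14))
```

All three produce the same result; the expanded view shares storage with `x`, which is why it does not increase memory. Since `expand` and broadcasting are ordinary tensor ops, they work unchanged under `DataParallel`.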
