In my forward function I have an operation that multiplies a matrix by a vector element-wise along the channel dimension. The vector's channel dim is 1, so in forward I use torch.cat to inflate it to match the matrix's channel dim of 64. But I found this method costs a lot of GPU memory during training. How can I change the code to reduce the memory usage and still support DataParallel at the same time?
```python
import torch
from torch.autograd import Variable

batch_size = 2
x = Variable(torch.randn(batch_size, 2, 14).cuda())      # vector, channel dim=1
y = Variable(torch.randn(batch_size, 64, 2, 14).cuda())  # matrix, channel dim=64

# this will use lots of GPU memory: it materializes 64 copies of x
x = torch.cat([x.view(batch_size, 1, 2, 14) for i in range(64)], dim=1)
y = y.mul(x)
```
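Would something like expand or implicit broadcasting avoid the copy? A minimal sketch of what I have in mind, assuming my PyTorch version supports broadcasting for mul (0.2+):

```python
# Add a singleton channel dim, then repeat it as a view instead of a copy.
x2 = x.view(batch_size, 1, 2, 14)   # shape (2, 1, 2, 14)

# expand_as returns a view over the same storage, so no extra memory
y = y * x2.expand_as(y)

# or equivalently, rely on implicit broadcasting of the size-1 dim:
y = y * x2
```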
Thanks.