Multitask Learning Autograd Out of Memory

I’m working on a text attribute classification problem. There are 42 attributes, each a binary prediction in [0, 1], and I train with BCE (binary cross-entropy) loss.

[03/21/2018 10:49:51 PM] INFO: Model(
  (drop): Dropout(p=0.2)
  (encoder): LSTM(100, 512, dropout=0.2)
  (out): Linear(in_features=512, out_features=42, bias=True)
  (embed): Embedding(108926, 100)
  (normalize): Softmax()
)
num of parameters: 12193634
THCudaCheck FAIL file=/pytorch/torch/lib/THC/generic/ line=58 error=2 : out of memory

The final representation after the LSTM is H: (batch_size, sequence_length, hidden_size).
I simply want to apply a multi-task attention to this tensor, meaning I want a task query q_t of shape (hidden_size), and there are 42 of these queries.

So in my code I wrote:

task_queries = nn.Parameter(torch.randn(hidden_size, 42))
keys = torch.matmul(H, task_queries)  # (batch, seq_len, 42)
task_specific_list = []
for t_n in range(42):
    # weight H by the t_n-th task's scores and sum over the sequence dimension
    mix = torch.sum(H * keys[:, :, t_n].unsqueeze(-1), dim=1)  # (batch, hidden)
    task_specific_list.append(mix)
task_specific_mix = torch.stack(task_specific_list, dim=1)  # (batch, 42, hidden)
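As a side note, the same computation can be written as a single contraction, which should avoid keeping one weighted copy of H per task alive for the backward pass. A minimal runnable sketch (toy sizes, random tensors, assuming H is (batch, seq_len, hidden) and keys is (batch, seq_len, num_tasks)):

```python
import torch

batch, seq_len, hidden, num_tasks = 4, 10, 8, 42  # toy sizes for illustration
H = torch.randn(batch, seq_len, hidden)
keys = torch.randn(batch, seq_len, num_tasks)     # scores, e.g. H @ task_queries

# Loop version: one (batch, seq_len, hidden) intermediate per task.
loop_mix = torch.stack(
    [torch.sum(H * keys[:, :, t].unsqueeze(-1), dim=1) for t in range(num_tasks)],
    dim=1,
)  # (batch, num_tasks, hidden)

# Vectorized version: a single contraction over the sequence dimension.
einsum_mix = torch.einsum('bsh,bst->bth', H, keys)

print(torch.allclose(loop_mix, einsum_mix, atol=1e-5))  # True
```

Both produce a (batch, num_tasks, hidden) tensor; the einsum form only materializes the output, not 42 intermediate products.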

It seems that by doing this, PyTorch hits an OOM error during back-prop. I wonder if this is caused by PyTorch effectively keeping a gradient tensor of shape (batch_size, sequence_length, hidden_dim, num_tasks); since my text is very long (~900), this tensor can hold around 620 million elements (32 x 900 x 512 x 42).
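A quick back-of-the-envelope check of that tensor size, using the numbers from the post:

```python
batch, seq_len, hidden, num_tasks = 32, 900, 512, 42

elements = batch * seq_len * hidden * num_tasks
print(elements)              # 619315200, i.e. ~620 million elements

gib = elements * 4 / 2**30   # 4 bytes per float32 element
print(round(gib, 2))         # ~2.31 GiB for a single such tensor
```

So even one extra tensor of that shape, kept alive for the backward pass, is a multi-GiB allocation.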

Is this the problem, or is something else going on? The parameter increase from this “multi-task attention” is almost negligible, and I really don’t know what caused the OOM. It would also help if I knew which line in the forward pass is causing the issue… any help/suggestions?
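One way I know of to narrow down the offending line: PyTorch exposes torch.cuda.memory_allocated(), so one can log allocated memory between steps of the forward pass and see which step causes the jump. A hedged sketch (the helper name is my own, and it simply reports 0 when no GPU is present):

```python
import torch

def log_cuda_mem(tag):
    """Print currently allocated CUDA memory; returns MiB (0.0 without a GPU)."""
    mib = torch.cuda.memory_allocated() / 2**20 if torch.cuda.is_available() else 0.0
    print(f"{tag}: {mib:.1f} MiB allocated")
    return mib

# Usage inside forward(), e.g.:
#   log_cuda_mem("after embedding")
#   log_cuda_mem("after LSTM")
#   log_cuda_mem("after attention mix")
```

Calling it after each stage of forward() shows where the allocation spikes, which is usually enough to pinpoint the line responsible for the OOM.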