Forward function out of memory with torch.cat()

Here is my code description:
My forward function takes “u”, “v”, “feature” and “hidden” as input parameters, where u, v and feature are embeddings. Suppose that “u” and “v” are tensors of size (1, 300) and “feature” is a 2D tensor of size (2000, 300). I want to first expand u and v to (2000, 300) and then concatenate them with feature:

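```python
import torch

# Broadcast the (1, 300) embeddings to one row per feature: (2000, 300).
expand_u = u.expand(feature.size(0), u.size(1))
expand_v = v.expand(feature.size(0), v.size(1))
# Concatenate along the feature dimension: (2000, 900).
fuv_combine = torch.cat((feature, expand_u, expand_v), 1)
```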
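
To see where memory is actually allocated, here is a small standalone check. The sizes come from the question; the random tensors are only stand-ins for the real embeddings. It shows that expand() just creates a view with stride 0, while torch.cat copies everything into a new (2000, 900) tensor.

```python
import torch

# Shapes taken from the question: u, v are (1, 300); feature is (2000, 300).
u = torch.randn(1, 300)
v = torch.randn(1, 300)
feature = torch.randn(2000, 300)

expand_u = u.expand(feature.size(0), u.size(1))
expand_v = v.expand(feature.size(0), v.size(1))

# expand() returns a view: same storage, stride 0 along the expanded dim.
print(expand_u.shape)                        # torch.Size([2000, 300])
print(expand_u.stride())                     # (0, 1)
print(expand_u.data_ptr() == u.data_ptr())   # True

# torch.cat materializes a new contiguous tensor, copying every row.
fuv_combine = torch.cat((feature, expand_u, expand_v), 1)
print(fuv_combine.shape)                     # torch.Size([2000, 900])
print(fuv_combine.data_ptr() == feature.data_ptr())  # False
```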

Then I run a GRU in a for loop, feeding it the input and the hidden state at each step. Inside the loop I need to compute MLP(torch.cat((fuv_combine, h_prev), 1)), where h_prev is the hidden state from the previous step. Since h_prev has shape (1, 300), I also need to expand it to (2000, 300) first.
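
A minimal sketch of the loop described above, just to make the shapes concrete. The GRUCell, the two-layer MLP, the step count T and the per-step GRU inputs are all hypothetical stand-ins, since the post does not show them; only the (1, 300) and (2000, 300) shapes come from the question.

```python
import torch
import torch.nn as nn

# Hypothetical sizes and modules; only N and D come from the question.
N, D, T = 2000, 300, 5                      # T (number of steps) is an assumption
feature = torch.randn(N, D)
u, v = torch.randn(1, D), torch.randn(1, D)
fuv_combine = torch.cat((feature, u.expand(N, D), v.expand(N, D)), 1)  # (2000, 900)

gru = nn.GRUCell(D, D)                      # hypothetical recurrent cell
mlp = nn.Sequential(nn.Linear(4 * D, D), nn.ReLU(), nn.Linear(D, 1))  # hypothetical MLP
inputs = torch.randn(T, 1, D)               # hypothetical per-step GRU input

h_prev = torch.zeros(1, D)
scores = []
for t in range(T):
    # h_prev is (1, 300); expand it to one row per feature before concatenating.
    step_in = torch.cat((fuv_combine, h_prev.expand(N, D)), 1)  # new (2000, 1200) tensor
    scores.append(mlp(step_in))             # (2000, 1)
    h_prev = gru(inputs[t], h_prev)         # (1, 300)
```

Each iteration builds a fresh (2000, 1200) tensor just to feed the MLP, even though 900 of its 1200 columns are identical from step to step and the other 300 are the same in every row.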

The problem is that when the program reaches this part, it runs out of memory (OOM). The traceback mostly points to the MLP(torch.cat((fuv_combine, h_prev), 1)) call. I don't know how to avoid this issue.
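
One likely reason the memory keeps growing: in float32, each step's (2000, 1200) input takes 2000 × 1200 × 4 bytes ≈ 9.6 MB, and because nn.Linear needs its input to compute the weight gradient, autograd typically keeps every step's copy alive until backward() runs. A workaround that is often used, sketched here under the assumption that the MLP starts with an nn.Linear layer (the layer below is hypothetical), is to split that layer's weight by input chunk so the large concatenation is never built. The sketch checks that both forms give the same result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, D = 2000, 300
feature = torch.randn(N, D)
u, v, h_prev = torch.randn(1, D), torch.randn(1, D), torch.randn(1, D)

# Hypothetical first MLP layer acting on the concatenation [feature, u, v, h_prev].
first = nn.Linear(4 * D, D)

# Reference: materialize the full (2000, 1200) input, as in the question.
full = first(torch.cat((feature, u.expand(N, D), v.expand(N, D), h_prev.expand(N, D)), 1))

# Alternative: split the (300, 1200) weight into four (300, 300) blocks,
# one per input chunk, in the same order as the concatenation above.
W_f, W_u, W_v, W_h = first.weight.split(D, dim=1)
feat_part = feature @ W_f.t()                                          # (2000, 300)
row_part = u @ W_u.t() + v @ W_v.t() + h_prev @ W_h.t() + first.bias   # (1, 300)
split = feat_part + row_part                  # broadcasts row_part to (2000, 300)

print(torch.allclose(full, split, atol=1e-4))  # True
```

Because the feature term does not depend on h_prev, feat_part could be computed once before the loop; each step then only needs the (1, 300) row term, which broadcasting adds to every row. This removes the per-step (2000, 1200) copy, although the (2000, 300) activations of the MLP are still stored for backward.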

Hi, I have run into the same problem. Have you found a solution?