Using two GPUs connected via NVLink in PyTorch

Hello there,

I am training an RNN seq2seq model for NLP with a copynet mechanism (Puduppully 2019). This model increases GPU memory usage really fast: for up to 500 output words in the copy decoder, the model already takes up more than 10 GB of GPU memory. I am not sure whether this is due to a memory leak in the code. I have access to two GPUs. Is it possible to use them with pooled memory (so 40 GB instead of 20 GB) when they are connected via an NVLink bridge?
This would help train the model without drastically decreasing the parameter count.

Any hints on where to start searching for memory leaks in an RNN seq2seq decoder? Is the growth in used memory related to the predictions and decoder outputs collected over the iterations, or can it be related to something else?
As rough estimates, here is some information:
embed_size / hidden_size: 200
bidirectional encoder, batch_size is 1, decoder input is [1, 50, 400].
The target output sequence has a maximum length of 500. This already uses up more than 16 GB of GPU memory with just 1 batch element and 500 iterations.

I appreciate any hints and advice. Thank you in advance!

You cannot directly use both devices as one pooled memory, but you could use e.g. pipeline parallel approaches or FSDP (FullyShardedDataParallel) to reduce the memory usage on each device.
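A minimal sketch of the simplest variant, plain model parallelism: a hypothetical toy encoder lives on the first GPU, the decoder on the second, and only the activations are copied between them, so each device holds roughly its own share of parameters and activations. The model, its sizes (vocabulary 100, embedding and hidden size 200, so the decoder input is [1, 50, 400] as in the question) and the CPU fallback for machines with fewer than two GPUs are assumptions for illustration. Real pipeline parallelism additionally splits each batch into micro-batches so both GPUs work at the same time, and FSDP instead shards parameters, gradients and optimizer state across processes; neither is shown here.

```python
import torch
import torch.nn as nn

# Fall back to the CPU so the sketch also runs without two visible GPUs.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")


class TwoDeviceSeq2Seq(nn.Module):
    """Hypothetical toy model: embedding + encoder on dev0, decoder + output layer on dev1."""

    def __init__(self, vocab=100, embed=200, hidden=200):
        super().__init__()
        self.embed = nn.Embedding(vocab, embed).to(dev0)
        self.encoder = nn.GRU(embed, hidden, batch_first=True, bidirectional=True).to(dev0)
        self.decoder = nn.GRU(2 * hidden, 2 * hidden, batch_first=True).to(dev1)
        self.out = nn.Linear(2 * hidden, vocab).to(dev1)

    def forward(self, src):
        enc_out, _ = self.encoder(self.embed(src.to(dev0)))  # [batch, seq, 2 * hidden]
        # Only the activations cross the device boundary; autograd tracks the copy.
        dec_out, _ = self.decoder(enc_out.to(dev1))
        return self.out(dec_out)


model = TwoDeviceSeq2Seq()
src = torch.randint(0, 100, (1, 50))
logits = model(src)  # lives on dev1
logits.sum().backward()  # gradients flow back across both devices
print(logits.shape, logits.device)
```

How much such a split helps depends on where the memory goes: if most of it is decoder activations, the cut has to go through the decoder itself.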

If the GPU memory is increasing in each iteration, eventually causing an OOM, you should check whether you are appending tensors that are still attached to a computation graph, e.g. to a list.
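A minimal sketch of that pattern, using a toy nn.Linear model with random data (all names are hypothetical): appending the loss tensor itself keeps a reference to its grad_fn and therefore to its graph, while .item() (or .detach(), if a tensor is needed) stores only the value. After backward() the saved activations of a graph are freed, but tensors from a forward pass whose backward is never called (e.g. validation without torch.no_grad()) keep all of their activations alive for as long as the list exists.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

attached_log = []  # anti-pattern: every entry references its computation graph
safe_log = []      # plain Python floats, no graph references

for step in range(5):
    x, y = torch.randn(4, 10), torch.randn(4, 1)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    attached_log.append(loss)     # keeps loss.grad_fn (and the graph behind it) alive
    safe_log.append(loss.item())  # use loss.detach() instead if you need a tensor

print(attached_log[-1].grad_fn is not None, type(safe_log[-1]))
```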

In the decoder, after each call to the step function, which returns the decoder output and its predictions, I collect these in two lists, but I think this is necessary to keep them as part of the computation graph in order to compute the gradients and backpropagate correctly, isn’t it? Should I detach what the step function returns from the graph? I can’t imagine that this is due to some variable in the step function, because these are the same in every iteration.

I just commented out all lines responsible for collecting the decoder step outputs in lists, so that effectively no results are stored, and the memory still builds up at the same pace, reaching about 11 GB for 500 decoder steps. May I post the code of the decoder? I tried a standard decoder without all the computations needed for the copy mechanism, and it works like a charm… So I think it is somehow caused by the computations in the step function.

You shouldn’t detach outputs if you need them to call .backward(). In this case it also isn’t a “memory leak” but expected memory usage.
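This would also explain why removing the lists did not change much: each step’s hidden state is the input of the next step, so the graph of all previous steps stays reachable from the final output, and autograd keeps every tensor it saved for backward until .backward() is called. A minimal sketch, assuming a toy tanh RNN step (not the copy decoder itself), uses torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch versions) to add up the unique bytes saved for backward. Nothing is appended to a list, yet the total grows with the number of steps; a step function with more or larger intermediate tensors saves more per step, so its memory grows faster for the same output size.

```python
import torch


def saved_bytes(num_steps, hidden=200, batch=1):
    """Unique bytes autograd keeps for backward after `num_steps` toy RNN steps."""
    torch.manual_seed(0)
    W = torch.randn(hidden, hidden, requires_grad=True)  # stand-in recurrent weight
    x = torch.randn(batch, hidden)                        # stand-in step input
    saved = {}  # data_ptr -> bytes, so W (saved again every step) is counted once

    def pack(t):
        saved[t.data_ptr()] = t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    h = torch.zeros(batch, hidden)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        for _ in range(num_steps):
            # Only h is carried over; it links this step's graph to all earlier ones.
            h = torch.tanh(x + h @ W)
    total = sum(saved.values())
    h.sum().backward()  # backward releases the saved activations
    return total


for steps in (10, 100, 500):
    print(steps, saved_bytes(steps))
```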

Thank you very much for your response and help! That’s what I thought: I need the outputs for backprop. What I am wondering about is that a plain RNN decoder as used in the PyTorch tutorials uses far less memory than the copynet decoder, although the dimensions of the output tensors that I collect are identical. Does that mean that memory consumption depends on the number of computations and operations executed in the step function because, I assume, the computation graph is larger? If I post a minimal working example of the two decoders, would you be willing to have a look and tell me whether this memory behaviour is correct or whether I am missing something?

I just need to make sure I use as little memory as possible, but my gut feeling tells me that the fast increase in memory usage cannot be explained by just a few more computations.

Yes, you would see an increase in memory usage for each stored forward activation, which is needed to compute the gradients.
E.g., take a look at the derivative for relu in tools/autograd/derivatives.yaml, which shows that the result is needed for the gradient computation.
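To make this concrete, a sketch of a hand-written ReLU as a torch.autograd.Function (SketchReLU is a hypothetical name; this is not how PyTorch implements relu internally): like the relu entry in derivatives.yaml, its backward only needs the forward result, so that tensor has to be stored with ctx.save_for_backward and stays in memory until backward runs. Stacking N such calls therefore keeps N outputs alive.

```python
import torch


class SketchReLU(torch.autograd.Function):
    """Illustrative ReLU whose backward only needs the forward output."""

    @staticmethod
    def forward(ctx, x):
        result = x.clamp(min=0)
        ctx.save_for_backward(result)  # kept alive until backward runs (or the graph is freed)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # Pass the gradient through only where the output was positive.
        return grad_output * (result > 0).to(grad_output.dtype)


torch.manual_seed(0)
x = torch.randn(8, requires_grad=True)
y = SketchReLU.apply(x)
y.sum().backward()
print(torch.equal(y, torch.relu(x)), torch.equal(x.grad, (x > 0).float()))
```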

Even though relu does not have any parameters (it’s a pure functional call), you can see the increase in memory usage, e.g., with this code snippet:

import torch
import torch.nn as nn

print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 0.00MB allocated

class Model(nn.Module):
    def __init__(self, nb_relus):
        super().__init__()
        self.relus = nn.ModuleList([nn.ReLU() for _ in range(nb_relus)])
        
    def forward(self, x):
        for relu in self.relus:
            x = relu(x)
        return x

# use 10 relus
device = "cuda"
x = torch.randn(1024, 1024, requires_grad=True, device=device)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 4.00MB allocated

model = Model(nb_relus=10)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 4.00MB allocated

out = model(x)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 44.00MB allocated

del x, model, out
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 0.00MB allocated

# use 100 relus
x = torch.randn(1024, 1024, requires_grad=True, device=device)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 4.00MB allocated

model = Model(nb_relus=100)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 4.00MB allocated

out = model(x)
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 404.00MB allocated

out.mean().backward()
print("{:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))
# 12.00MB allocated

Here you can see that each relu call stores its output tensor (1024 × 1024 float32 values, i.e. 4MB) and thus increases the memory usage. The backward() call then computes the gradients and frees the intermediate forward activations. The last statement shows 12MB of allocated memory, as x, x.grad, and out take 4MB each.
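A quick back-of-the-envelope check of these numbers in plain Python (tensor_mib is a hypothetical helper): a float32 tensor of shape [1024, 1024] takes 1024 · 1024 · 4 bytes = 4 MiB, and torch.cuda.memory_allocated() / 1024**2 reports MiB.

```python
def tensor_mib(*shape, bytes_per_elem=4):
    """Size of a dense tensor in MiB (4 bytes per element for float32)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 1024**2


act = tensor_mib(1024, 1024)  # one relu output
print(act)                    # 4.0
print(act + 10 * act)         # x plus 10 stored outputs: 44.0
print(act + 100 * act)        # x plus 100 stored outputs: 404.0
print(3 * act)                # x, x.grad and out after backward: 12.0
```

The same estimate works for any tensor that autograd keeps for backward: size from shape and dtype, multiplied by how many copies are retained (e.g. one per decoder step).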

So I tried using torch.cuda.memory_allocated() to identify the computations in the step function that consume the most memory. Between steps E and E2 the memory jumps by 20 MB, and between E2 and E3 it jumps by another 20 MB.

print("step E {:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))

all_scores = torch.cat((gen_scores, copy_scores), dim=1)
probabilities = torch.softmax(all_scores, dim=1)

gen_probs = probabilities[:,:self.d_gen_vocab]
copy_probs = probabilities[:,self.d_gen_vocab:]

enc_input_idx = torch.tensor(linearized_input_records, dtype=torch.long)

enc_input_idx = enc_input_idx.unsqueeze(2)
one_hot_template = torch.zeros([batch_size, seq_length*self.d_record, self.d_gen_vocab+(seq_length*self.d_record)]).to(device)

print("step E2 {:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))

seq_to_vocab_one_hot = one_hot_template.scatter(2, enc_input_idx, 1)
seq_to_vocab_one_hot_trimmed = seq_to_vocab_one_hot[:,:, :self.d_gen_vocab]

copy_probs = copy_probs.unsqueeze(1)
print("step E3 {:.2f}MB allocated".format(torch.cuda.memory_allocated() / 1024**2))

Is this related to deep-copying the tensors? PyTorch gives me this warning at one point:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
enc_input_idx = torch.tensor(linearized_input_records, dtype=torch.long)

So could I reduce memory usage if I changed the implementation here to the code proposed in the warning?
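Both torch.tensor(src) and src.clone().detach() copy the (small, integer) index tensor, so switching to the recommended form should mainly silence the warning rather than change memory usage noticeably. The two jumps of about 20 MB are consistent with allocating the dense one_hot_template and then the new tensor that the out-of-place scatter returns. If the trimmed one-hot is afterwards multiplied with copy_probs (an assumption, since that part of the step function isn't shown), autograd has to keep it for backward, and because it is a view it keeps the whole scatter output alive, once per decoder step: 500 steps × ~20 MB would be roughly 10 GB, the same ballpark as the reported usage. The sketch below swaps in a different technique, scatter_add, which adds each source position’s copy probability directly to its vocabulary id without materializing a one-hot matrix, and checks that it matches the dense version on small random inputs. The function names and shapes (copy_probs and source ids both [batch, src_len]) are hypothetical.

```python
import torch


def copy_to_vocab_dense(copy_probs, src_ids, vocab_size):
    """Dense variant modeled on the step function above (the bmm is assumed)."""
    batch, src_len = src_ids.shape
    template = torch.zeros(batch, src_len, vocab_size + src_len, device=copy_probs.device)
    one_hot = template.scatter(2, src_ids.unsqueeze(2), 1)  # new tensor, same size as template
    one_hot_trimmed = one_hot[:, :, :vocab_size]             # view: keeps all of one_hot alive
    # bmm saves one_hot_trimmed for backward, i.e. one dense one-hot per decoder step.
    return torch.bmm(copy_probs.unsqueeze(1), one_hot_trimmed).squeeze(1)


def copy_to_vocab_scatter_add(copy_probs, src_ids, vocab_size):
    """Same result without a one-hot; only small [batch, src_len] tensors are saved."""
    in_vocab = src_ids < vocab_size  # ids beyond the vocabulary are dropped, as by the trim
    safe_ids = torch.where(in_vocab, src_ids, torch.zeros_like(src_ids))
    masked_probs = copy_probs * in_vocab.to(copy_probs.dtype)
    out = copy_probs.new_zeros(copy_probs.shape[0], vocab_size)
    return out.scatter_add(1, safe_ids, masked_probs)


torch.manual_seed(0)
batch, src_len, vocab_size = 2, 6, 5
scores = torch.randn(batch, src_len, requires_grad=True)
src_ids = torch.randint(0, vocab_size + src_len, (batch, src_len))

dense = copy_to_vocab_dense(torch.softmax(scores, dim=1), src_ids, vocab_size)
sparse = copy_to_vocab_scatter_add(torch.softmax(scores, dim=1), src_ids, vocab_size)
print(torch.allclose(dense, sparse))
```

Independently of this, the template could be created directly on the GPU with torch.zeros(..., device=device) instead of .to(device), which avoids the CPU allocation and copy, and since the index does not change between decoder steps it could be built once outside the decoding loop.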