Model training on GPU gradually consumes all CPU RAM

Hi,

I am attempting to train a model on a Google Cloud compute instance with access to a GPU and a CPU with ~26 GB of RAM. However, as my model's training loop progresses, it slowly consumes all of the CPU RAM available on the machine. My first thought was that the computation graph was for some reason not being freed by loss.backward(), but the issue persists even when I run the loop under torch.no_grad(). I've copied my training loop below.

@profile
def train_loop(epoch, model, loader):
    model.train()
    i = 0
    results = {'loss': 0, 'counter': 0, 'loss_arr':[]}
    for inputs, targets, _ , _, _, _, max_val in loader:
        optimizer.zero_grad()
        pred = model(inputs.to(device)[:,None,:,:])
        loss = criterion(pred, targets.to(device)[:,None,:,:], torch.Tensor(max_val).to(device))

        loss.backward()
        optimizer.step()

        results['loss'] += loss.item() * len(inputs)
        results['counter'] += len(inputs)
        results['loss_arr'].append(loss.item())
        if i == 0:
            break
        if i % log_every == 0:
            print("Train: Epoch: %d \t Iteration: %d \t loss: %.4f" % (epoch, i, sum(results['loss_arr'])/len(results['loss_arr'])))
            results['loss_arr'] = []
        i += 1

    scheduler.step()
    return
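
(For reference, the torch.no_grad() check mentioned above was essentially this same loop with the backward pass and optimizer step removed; roughly the following, where the function name is just for illustration:)

def check_no_grad_loop(model, loader):
    # Pared-down version of train_loop used only to test whether autograd
    # graph retention is the culprit -- no backward(), no optimizer.step()
    model.eval()
    with torch.no_grad():
        for inputs, targets, _ , _, _, _, max_val in loader:
            pred = model(inputs.to(device)[:,None,:,:])
            loss = criterion(pred, targets.to(device)[:,None,:,:], torch.Tensor(max_val).to(device))
            # RAM still climbs here even though no graph is being built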

I am using memory_profiler to see which lines are contributing to the leak, and I get the following table:

Line # Mem usage Increment Occurrences Line Contents

53  637.031 MiB  637.031 MiB           1   @profile
54                                         def train_loop(epoch, model, loader):
55  637.031 MiB    0.000 MiB           1       model.train()
56  637.031 MiB    0.000 MiB           1       i = 0
57  637.031 MiB    0.000 MiB           1       results = {'loss': 0, 'counter': 0, 'loss_arr':[]}
58  672.133 MiB   35.102 MiB           1       for inputs, targets, _ , _, _, _, max_val in loader:
59                                                 
60  672.434 MiB    0.301 MiB           1           optimizer.zero_grad()
61  845.789 MiB  845.789 MiB           1           pred = model(inputs.to(device)[:,None,:,:])
62  884.059 MiB   38.270 MiB           1           loss = criterion(pred, targets.to(device)[:,None,:,:], torch.Tensor(max_val).to(device))
63                                         
64  953.922 MiB   69.863 MiB           1           loss.backward()
65  989.824 MiB   35.902 MiB           1           optimizer.step()
66                                         
67  989.824 MiB    0.000 MiB           1           results['loss'] += loss.item() * len(inputs)
68  989.824 MiB    0.000 MiB           1           results['counter'] += len(inputs)
69  989.824 MiB    0.000 MiB           1           results['loss_arr'].append(loss.item())
70  989.824 MiB    0.000 MiB           1           if i == 0:
71  989.824 MiB    0.000 MiB           1               break
72                                                 if i % log_every == 0:
73                                                     print("Train: Epoch: %d \t Iteration: %d \t loss: %.4f" % (epoch, i, sum(results['loss_arr'])/len(results['loss_arr'])))
74                                                     results['loss_arr'] = []
75                                                 i +=1
76                                         
77  989.824 MiB    0.000 MiB           1       scheduler.step()
78  989.824 MiB    0.000 MiB           1       return 
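
(For completeness, I am collecting these tables with the standard memory_profiler workflow: decorating the functions of interest with @profile and running the script under the module, roughly as below. train.py is just a stand-in for my actual script name.)

python -m memory_profiler train.py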

I then placed @profile decorators on each function called along the way and, following the breadcrumbs, arrived at this forward() method in one of my model's building blocks:

Line # Mem usage Increment Occurrences Line Contents

325  830.934 MiB  776.305 MiB           8       @profile
326                                             def forward(self, x, h, w):
327  830.934 MiB    0.000 MiB           8           res = x
328                                         
329  830.934 MiB    0.000 MiB           8           x = self.norm1(x)
330  830.934 MiB 6594.449 MiB           8           attn = self.attn(x, h, w)
331  830.934 MiB   12.539 MiB           8           x = res + self.drop_path(attn)
332  830.934 MiB    2.863 MiB           8           x = x + self.drop_path(self.mlp(self.norm2(x)))
333                                         
334  830.934 MiB    0.000 MiB           8           return x

As you can see, a huge amount of memory is allocated at line 330, the call to self.attn. However, the trail goes cold when I profile the attention module's own forward(): the large increment is attributed to the @profile decorator line rather than to any single line inside the function, as the following table shows:

Line # Mem usage Increment Occurrences Line Contents

190  830.934 MiB 6555.223 MiB           8       @profile
191                                             def forward(self, x, h, w):
192                                                 if (
193  830.934 MiB    0.000 MiB           8               self.conv_proj_q is not None
194                                                     or self.conv_proj_k is not None
195                                                     or self.conv_proj_v is not None
196                                                 ):
197  830.934 MiB   27.496 MiB           8               q, k, v = self.forward_conv(x, h, w)
198                                         
199  830.934 MiB    0.000 MiB           8           q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
200  830.934 MiB    0.000 MiB           8           k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
201  830.934 MiB    0.000 MiB           8           v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)
202                                         
203  830.934 MiB    3.031 MiB           8           attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
204  830.934 MiB    8.699 MiB           8           attn = self.softmax(attn_score)
205                                                 #attn = F.softmax(attn_score, dim=-1)
206  830.934 MiB    0.000 MiB           8           attn = self.attn_drop(attn)
207                                         
208  830.934 MiB    0.000 MiB           8           x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
209  830.934 MiB    0.000 MiB           8           x = rearrange(x, 'b h t d -> b t (h d)')
210                                         
211  830.934 MiB    0.000 MiB           8           x = self.proj(x)
212  830.934 MiB    0.000 MiB           8           x = self.proj_drop(x)
213  830.934 MiB    0.000 MiB           8           return x

I am at a loss as to why this is happening, so any help would be greatly appreciated. In case it is helpful, my model is built from building blocks found in the microsoft/CvT repo (https://github.com/microsoft/CvT), the official implementation of CvT: Introducing Convolutions to Vision Transformers.
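
Would it also make sense to instrument the attention call directly, e.g. with psutil, to confirm where the resident memory is actually growing? I have not run this yet; it is just a rough sketch of what I had in mind (the psutil RSS check is my own addition, the surrounding forward() is the block shown above):

import os, psutil

_proc = psutil.Process(os.getpid())

def forward(self, x, h, w):
    res = x

    x = self.norm1(x)
    rss_before = _proc.memory_info().rss  # process resident set size, in bytes
    attn = self.attn(x, h, w)
    rss_after = _proc.memory_info().rss
    print("self.attn grew RSS by %.1f MiB" % ((rss_after - rss_before) / 2**20))
    x = res + self.drop_path(attn)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x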
Lastly, this is my first ever post here, so I apologize if it does not follow the forum etiquette.
Thank you