For loop causes CUDA out of memory

Hi all,

I have a function that uses a for loop to modify some values in my tensor. However, after some debugging I found that the for loop causes the GPU to use a lot of memory. Any idea why the for loop uses so much memory? Or is there a way to vectorize the troublesome loop?

Many Thanks

import torch

def process_feature_map_2(dm):
    """dm should be a (N,C,D,D) tensor; D in my use case is 14, N is 4, C is 80.
    `a` and `b` are (N,1,D,D) tensors.
    `c` is the same shape as `dm`.
    Let's say `dm` is a (1,3,2,2) tensor and the value of the last two dims of `b` is
    [[0,1],
     [1,2]] and `a` is 
    [[1,2],
     [3,4]]
    This function will create `c` such that it is 
    [[[1,0],
      [0,0]], 
     [[0,2],
      [3,0]],
     [[0,0],
      [0,4]]]
    In plain English, I want to separate the values in `a` into different channels,
    where the channel indices are stored in `b`.
    """
    a = dm.sum(1, keepdim=True)
    b = dm.argmax(1, keepdim=True)
    c = torch.zeros(dm.shape, device=dm.device)
    for n in range(c.shape[0]):
        for i in range(c.shape[1]):
            c[n][i][b[n][0] == i] = a[n][b[n] == i]
    return c

I tried to comment out the for loop and just run the following single line, and it still runs out of memory :(
c[0][0][b[0][0] == 0] = a[0][b[0] == 0]

It turns out that indexing with a boolean array requires a lot of memory (https://github.com/pytorch/pytorch/issues/57515). Is there a solution for this?
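For reference, here is a minimal sketch (assuming a CUDA device is available; exact numbers will vary by setup) that measures peak allocator usage around that single masked assignment using torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated():

import torch

if torch.cuda.is_available():
    dm = torch.randn(4, 80, 14, 14, device="cuda")
    a = dm.sum(1, keepdim=True)
    b = dm.argmax(1, keepdim=True)
    c = torch.zeros_like(dm)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    # boolean-mask indexing goes through nonzero() under the hood,
    # which allocates temporary index tensors on the GPU
    c[0][0][b[0][0] == 0] = a[0][b[0] == 0]
    torch.cuda.synchronize()
    print(torch.cuda.max_memory_allocated())  # peak allocated bytes since the reset (existing tensors plus temporaries)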

I’m not sure this solves your memory usage issue, as I didn’t experience one (perhaps you are doing this operation many times, which grows the autograd graph?), but what you are writing in the for loop looks like exactly the scatter operation:

import torch

def process_feature_map_2(dm):
    """dm should be a (N,C,D,D) tensor, D is my use case is 14, N is 4, C is 80
    `a` and `b` are (N,1,D,D) tensor
    `c` is same shape as `dm`
    Lets say `dm` is (1,3,2,2) Tesnor and the value of last two dim of `b` is
    [[0,1],
     [1,2]] and `a` is
    [[1,2],
     [3,4]]
    This function will create `c` such that it is
    [[[1,0],
      [0,0]],
     [[0,2],
      [3,0]],
     [[0,0],
      [0,4]]]
    In plain English, I want to separate the value in `a` into different
channels
    and the channel indexes are stored in `b`
    """
    a = dm.sum(1, keepdim=True)
    b = dm.argmax(1, keepdim=True)
    c = torch.zeros(dm.shape, device=dm.device)
    for n in range(c.shape[0]):
        for i in range(c.shape[1]):
            c[n][i][b[n][0] == i] = a[n][b[n] == i]
    return c

def process_feature_map_3(dm):
    a = dm.sum(1, keepdim=True)     # per-pixel values to place, shape (N,1,D,D)
    b = dm.argmax(1, keepdim=True)  # target channel per pixel, shape (N,1,D,D)
    c = torch.zeros(dm.shape, device=dm.device)
    return c.scatter(1, b, a)       # c[n, b[n,0,h,w], h, w] = a[n,0,h,w]

inp = torch.randn(4, 80, 14, 14)
out = process_feature_map_2(inp)
out2 = process_feature_map_3(inp)
print(torch.equal(out, out2))
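As a quick sanity check (a small sketch built from the docstring's (1,3,2,2) example, not part of the original code), scatter along dim=1 places each value of `a` into the channel named by `b`:

import torch

b = torch.tensor([[[[0, 1],
                    [1, 2]]]])      # channel index per pixel, shape (1,1,2,2)
a = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])    # values to place, shape (1,1,2,2)
c = torch.zeros(1, 3, 2, 2).scatter(1, b, a)  # c[0, b[0,0,h,w], h, w] = a[0,0,h,w]
print(c[0, 0])  # [[1,0],[0,0]]
print(c[0, 1])  # [[0,2],[3,0]]
print(c[0, 2])  # [[0,0],[0,4]]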

Hi, thanks for your answer. Just a follow-up question: is this function non-differentiable because argmax() is involved? Although no error pops up during training, I still suspect that the dm -> argmax() -> b step, and the use of b in scatter(), will cause some problems.

I believe that depends on how argmax() is being used. If it is used in an indexing operation, as appears to be the case here, the model itself should still be backprop-able, because no gradients need to flow through the argmax itself. For example, consider a trivial model sketch that looks like:

import torch

a = torch.randn(5, requires_grad=True)  # some tensor
b = torch.randn(5, requires_grad=True)  # another tensor
c = torch.argmax(a)                     # some index from some calculation
target = torch.tensor(0.0)
loss = abs(a[c] + b[c] - target)
loss.backward()                         # works: gradients reach a and b through the indexed entries

Here, no gradients actually flow through c: even though it is used in the model to describe the computation (it selects which entries of a and b enter the loss), it does not take part in the differentiable computation itself.
On the other hand, if the computation looks like:

a = torch.randn(5, requires_grad=True)  # some tensor
b = torch.randn(5, requires_grad=True)  # another tensor
c = torch.argmax(a)                     # some index from some calculation
target = torch.tensor(0.0)
loss = abs(a.sum() + b.sum() + c - target)
loss.backward()                         # runs, but no gradient flows back through c

Then we would indeed have an issue as we cannot backprop through c itself.
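One quick way to confirm this (a minimal check, not from the original discussion): the output of torch.argmax() is an integer tensor with no gradient history, so anything it contributes to the loss is invisible to autograd:

import torch

a = torch.randn(5, requires_grad=True)
c = torch.argmax(a)
print(c.dtype, c.requires_grad, c.grad_fn)  # torch.int64 False None: c is detached from the graph

loss = a.sum() + c   # c enters the loss value directly
loss.backward()
print(a.grad)        # all ones, coming from a.sum() alone; the argmax path contributes nothing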