Combining torch.view and torch.gather throws OOM

Hi,

When I try to time the following code, the line combining torch.view and torch.gather keeps allocating memory on each iteration until it runs out of memory (OOM). I have no clue why this is happening. If I move that line outside the for loop it works fine, but that is not what I want, since I would like to time everything.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Variables
B, C, H, W = (10, 3, 640, 256)
x = torch.randn((B, C, H, W), device=device)
indices_y_src = torch.randint(H, size=(2, 5), device=device)
indices_x_src = torch.randint(W, size=(2, 1), device=device)
indices_src = (indices_y_src * W + indices_x_src).view(1, 1, -1).expand(B, C, -1)
indices_y_tgt = torch.randint(H, size=(2, 5), device=device)
indices_x_tgt = torch.randint(W, size=(2, 1), device=device)
indices_tgt = (indices_y_tgt * W + indices_x_tgt).view(1, 1, -1).expand(B, C, -1)
weights = torch.nn.Parameter(torch.ones([3, 3, 3, 3])).to(device)

x_out = torch.nn.functional.conv2d(x, weights, padding=1)
times = []
for _ in range(5000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    values = x_out.view(B, C, -1).gather(-1, indices_src).view(B, C, 2, 5)  # This causes OOM
    mean_val = torch.mean(values, -1, True).expand(-1, -1, -1, 5).reshape(B, C, -1)
    x_out = torch.scatter(x_out.view(B, C, -1), -1, indices_tgt, mean_val).view(B, C, H, W)
    end.record()
    # Waits for everything to finish running
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))
print(f"{torch.mean(torch.tensor(times))} +- {torch.std(torch.tensor(times))}")

x_out is created in a differentiable way as the output of F.conv2d, which means it’s attached to a computation graph. Inside the loop you are also using differentiable operations to recompute x_out, so the computation graph keeps growing and with it the memory usage.
You could avoid it by detaching the newly computed x_out via:

x_out = torch.scatter(x_out.view(B, C, -1), -1, indices_tgt, mean_val).view(B, C, H, W).detach()

which keeps the memory usage constant, but will prevent gradients from being computed through all iterations back to weights.
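
For illustration, here is a rough sketch (assuming the setup from your snippet) to verify the memory behavior; with .detach() the allocated memory should stay flat, while removing it reproduces the growth:

# Rough memory check: with .detach() the autograd graph is cut every iteration,
# so torch.cuda.memory_allocated() should stay roughly constant.
for i in range(100):
    values = x_out.view(B, C, -1).gather(-1, indices_src).view(B, C, 2, 5)
    mean_val = torch.mean(values, -1, True).expand(-1, -1, -1, 5).reshape(B, C, -1)
    x_out = torch.scatter(x_out.view(B, C, -1), -1, indices_tgt, mean_val).view(B, C, H, W).detach()
    if i % 10 == 0:
        print(f"{torch.cuda.memory_allocated() / 1024**2:.1f} MB allocated")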

Hey,

Thanks for the answer. What I am doing is equivalent in result to this other example; however, this one does not throw an OOM. Is that because I am modifying x_out in place?


for _ in range(5000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    values = x_out[:, :, indices_y_src, indices_x_src]
    mean_val = torch.mean(values, -1, True)
    x_out[:, :, indices_y_tgt, indices_x_tgt] = mean_val
    end.record()
    # Waits for everything to finish running
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))
print(f"{torch.mean(torch.tensor(times))} +- {torch.std(torch.tensor(times))}")

This might be the case, but are you seeing any “disallowed inplace operation” errors during the backward pass?

No. For example:


for _ in range(20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    values = x_out[:, :, indices_y_src, indices_x_src]
    mean_val = torch.mean(values, -1, True)
    x_out[:, :, indices_y_tgt, indices_x_tgt] = mean_val
    loss = torch.nn.MSELoss()(x_out, x)
    loss.backward(retain_graph=True)
    end.record()
    # Waits for everything to finish running
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))
print(f"{torch.mean(torch.tensor(times))} +- {torch.std(torch.tensor(times))}")

Correct me if I am wrong, but I think it is because in values = x_out[:, :, indices_y_src, indices_x_src] the RHS implicitly copies those memory locations into a new contiguous tensor, which is then assigned to the LHS. Any in-place change of x_out happening afterwards therefore does not affect the computation graph. If I replace that line with values = x_out.view(B, C, -1).gather(-1, indices_src).view(B, C, 2, 5), it throws RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation, as I believe gather does not internally clone anything.
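
For example, this quick standalone check (using the tensors from above) shows that the indexed values live in their own buffer:

# Standalone check: advanced (fancy) indexing returns a copy, so an in-place
# write back into the source tensor does not change `values`.
tmp = x_out.detach().clone()
values = tmp[:, :, indices_y_src, indices_x_src]
snapshot = values.clone()
tmp[:, :, indices_y_src, indices_x_src] = 0.0     # in-place write to the same positions
print(torch.equal(values, snapshot))              # True: `values` is an independent copy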

I am doing all the intricate gather/scatter/view operations just to avoid the aten::index_put_ happening in x_out[:, :, indices_y_tgt, indices_x_tgt] = mean_val. The backward pass for that operation is really slow compared with gather/scatter. I am using this “trick” in a bigger model and it does speed it up by about 1.08x. Do you know of a way to bypass aten::index_put_ that does not require gather/scatter?
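
For context, this is roughly how the ops can be inspected with torch.profiler (a sketch, not the exact script I used for the numbers):

# Rough profiling sketch: the indexing variant routes through aten::index_put_,
# whose backward should show up as comparatively expensive in the profiler table,
# while the gather/scatter variant should show gather/scatter backward kernels instead.
from torch.profiler import profile, ProfilerActivity

x_out = torch.nn.functional.conv2d(x, weights, padding=1)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    values = x_out[:, :, indices_y_src, indices_x_src]
    mean_val = torch.mean(values, -1, True)
    x_out[:, :, indices_y_tgt, indices_x_tgt] = mean_val
    loss = torch.nn.MSELoss()(x_out, x)
    loss.backward()
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))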

Could you post the shapes you have used to profile gather/scatter vs. index_put_, please?

You mean the shapes of the tensors I am working with?

x_out.shape = (10, 3, 640, 256)
indices_x_src.shape = indices_x_tgt.shape = (2, 1)
indices_y_src.shape = indices_y_tgt.shape = (2, 5)
indices_src.shape = indices_tgt.shape = (10, 3, 10)

With indexing (index_put_ in backward) I get an E2E time of 61.06 ms
With gather/scatter I get an E2E time of 49.21 ms

@ptrblck Any updates on this?

No, since I didn’t spend time writing a minimal and executable code snippet as a reference for profiling.