Hi,
When I try to time the following code, the line that combines view and gather keeps allocating memory on every iteration until I hit an OOM. I have no clue why this is happening. If I move that line outside the for loop it runs fine (see the variant after the snippet below), but that is not what I want, since I would like to time everything.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Variables
B, C, H, W = (10, 3, 640, 256)
x = torch.randn((B, C, H, W), device=device)
# Source pixel coordinates, flattened into linear indices over the H*W plane
indices_y_src = torch.randint(H, size=(2, 5), device=device)
indices_x_src = torch.randint(W, size=(2, 1), device=device)
indices_src = (indices_y_src * W + indices_x_src).view(1, 1, -1).expand(B, C, -1)
# Target pixel coordinates, flattened the same way
indices_y_tgt = torch.randint(H, size=(2, 5), device=device)
indices_x_tgt = torch.randint(W, size=(2, 1), device=device)
indices_tgt = (indices_y_tgt * W + indices_x_tgt).view(1, 1, -1).expand(B, C, -1)
# 3x3 convolution producing the tensor that is gathered from inside the loop
weights = torch.nn.Parameter(torch.ones([3, 3, 3, 3])).to(device)
x_out = torch.nn.functional.conv2d(x, weights, padding=1)
times = []
for _ in range(5000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    values = x_out.view(B, C, -1).gather(-1, indices_src).view(B, C, 2, 5)  # This causes OOM
    mean_val = torch.mean(values, -1, True).expand(-1, -1, -1, 5).reshape(B, C, -1)
    x_out = torch.scatter(x_out.view(B, C, -1), -1, indices_tgt, mean_val).view(B, C, H, W)
    end.record()
    # Waits for everything to finish running
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))
print(f"{torch.mean(torch.tensor(times))} +- {torch.std(torch.tensor(times))}")