F.grid_sample extremely slow

I found that F.grid_sample is extremely slow in my code. For example, the following block takes about 0.9 s on the GPU with PyTorch 1.6.0:

B, H, W, D, C = coords.shape
# Collapse the three spatial dims into a single row of N points.
coords = coords.view(B, 1, -1, C)  # (B, 1, N, 3), N = 262,144
# yz_plane has size (1, 32, 256, 256); tile it once per batch element.
yz_batched = yz_plane.repeat(B, 1, 1, 1)
# Take coordinate columns 1 and 2 and reverse their order for the sampling grid.
yz_grid = coords[..., [1, 2]].flip(-1)
yz_sampled = F.grid_sample(yz_batched, yz_grid, align_corners=True)  # B 32 1 N
sample_yz_feat = yz_sampled.permute(0, 2, 3, 1).squeeze(1)  # B N 32

However, if I add two more F.grid_sample calls, the following block STILL takes about 0.9 s on the GPU. I wonder why F.grid_sample is so slow in my case. Am I missing something, given that the other two F.grid_sample calls barely cost any time?

B, H, W, D, C = coords.shape
# Collapse the three spatial dims into a single row of N points.
coords = coords.view(B, 1, -1, C)  # (B, 1, N, 3), N = 262,144
# yz_plane, xz_plane, xy_plane each have size (1, 32, 256, 256).
# Every plane is tiled per batch element and sampled at the matching pair of
# coordinate columns (reversed), then rearranged to B N 32.
yz_grid = coords[..., [1, 2]].flip(-1)
yz_sampled = F.grid_sample(yz_plane.repeat(B, 1, 1, 1), yz_grid, align_corners=True)
sample_yz_feat = yz_sampled.permute(0, 2, 3, 1).squeeze(1)  # B N 32

xz_grid = coords[..., [0, 2]].flip(-1)
xz_sampled = F.grid_sample(xz_plane.repeat(B, 1, 1, 1), xz_grid, align_corners=True)
sample_xz_feat = xz_sampled.permute(0, 2, 3, 1).squeeze(1)

xy_grid = coords[..., [0, 1]].flip(-1)
xy_sampled = F.grid_sample(xy_plane.repeat(B, 1, 1, 1), xy_grid, align_corners=True)
sample_xy_feat = xy_sampled.permute(0, 2, 3, 1).squeeze(1)

Updated:

This is how I compute the time

import time

coords = torch.randn((1, 64, 64, 64, 3)).cuda()
yz_plane = torch.randn((1, 32, 256, 256)).cuda()
xy_plane = torch.randn((1, 32, 256, 256)).cuda()
xz_plane = torch.randn((1, 32, 256, 256)).cuda()
# CUDA work is queued asynchronously: drain everything already submitted
# (including the transfers above) so it is not charged to the timed region.
# NOTE: the first CUDA operations in a process also pay one-time start-up
# costs; run the block once untimed before measuring to exclude them.
torch.cuda.synchronize()
start_time = time.perf_counter()
# above code block
# Block until the sampling kernels have actually finished; without this the
# clock only captures the time needed to launch them.
torch.cuda.synchronize()
print(time.perf_counter() - start_time)

Benchmarking with torch.utils.benchmark suggests that the run time scales roughly linearly with the number of F.grid_sample calls, as expected.

<torch.utils.benchmark.utils.common.Measurement object at 0x7f02998ded00>
example(coords, yz_plane)
setup: from __main__ import example
  8.28 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f02998dee50>
example(coords, yz_plane)
setup: from __main__ import example
  5.80 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f02998ded00>
example2(coords, yz_plane, xy_plane, xz_plane)
setup: from __main__ import example2
  19.76 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f02998dee50>
example2(coords, yz_plane, xy_plane, xz_plane)
setup: from __main__ import example2
  21.34 ms
  1 measurement, 100 runs , 1 thread

import torch
import torch.nn.functional as F

# Query points laid out as (B, H, W, D, C) = (1, 512, 512, 2, 3),
# i.e. N = H * W * D = 524,288 points with 3 components each.
coords = torch.randn(1, 512, 512, 2, 3).cuda()
# Three feature planes of size (1, 32, 256, 256), drawn in the order yz, xy, xz.
yz_plane, xy_plane, xz_plane = (
    torch.randn(1, 32, 256, 256).cuda() for _ in range(3)
)

def example(coords, yz_plane):
    """Sample ``yz_plane`` at coordinate columns 1 and 2 of every point.

    Args:
        coords: 5-D tensor unpacked as (B, H, W, D, C); the three middle
            dims are flattened into N = H * W * D points. Columns 1 and 2 of
            the last dim are used (in reversed order) as the sampling grid.
        yz_plane: 4-D feature plane, tiled B times along dim 0 before
            sampling. Presumably has batch size 1, e.g. (1, 32, 256, 256).

    Returns:
        The sampled features rearranged to (B, N, channels), assuming
        ``yz_plane`` has batch size 1.
    """
    batch, _, _, _, n_comp = coords.shape
    points = coords.view(batch, 1, -1, n_comp)  # (B, 1, N, C)
    grid = points[..., [1, 2]].flip(-1)  # columns 1, 2 -> order 2, 1
    tiled_plane = yz_plane.repeat(batch, 1, 1, 1)
    sampled = F.grid_sample(tiled_plane, grid, align_corners=True)
    # Move channels last, then drop the singleton dim introduced by the view.
    return sampled.permute(0, 2, 3, 1).squeeze(1)

def example2(coords, yz_plane, xz_plane, xy_plane):
    """Sample three feature planes at every point; return only the yz result.

    The xz and xy samples are computed and then discarded, so calling this
    function costs three ``F.grid_sample`` calls while returning the same
    value as sampling ``yz_plane`` alone.

    Args:
        coords: 5-D tensor unpacked as (B, H, W, D, C); the three middle
            dims are flattened into N = H * W * D points.
        yz_plane: 4-D feature plane sampled at coordinate columns (1, 2).
        xz_plane: 4-D feature plane sampled at coordinate columns (0, 2).
        xy_plane: 4-D feature plane sampled at coordinate columns (0, 1).

    Returns:
        The yz samples rearranged to (B, N, channels), assuming ``yz_plane``
        has batch size 1.
    """
    batch, _, _, _, n_comp = coords.shape
    points = coords.view(batch, 1, -1, n_comp)  # (B, 1, N, C)

    def sample_plane(plane, columns):
        # Tile the plane per batch element, sample it at the chosen coordinate
        # columns (reversed), and move channels last: (B, N, channels).
        grid = points[..., columns].flip(-1)
        sampled = F.grid_sample(plane.repeat(batch, 1, 1, 1), grid, align_corners=True)
        return sampled.permute(0, 2, 3, 1).squeeze(1)

    sample_yz_feat = sample_plane(yz_plane, [1, 2])
    sample_xz_feat = sample_plane(xz_plane, [0, 2])
    sample_xy_feat = sample_plane(xy_plane, [0, 1])
    return sample_yz_feat


import torch.utils.benchmark as benchmark

# Timer for the single-plane variant. torch.utils.benchmark takes care of
# CUDA synchronization, so the reported times cover the full kernel run.
t0 = benchmark.Timer(
    stmt='example(coords, yz_plane)',
    setup='from __main__ import example',
    globals={'coords': coords, 'yz_plane': yz_plane})

# Timer for the three-plane variant. The planes are passed by keyword because
# example2 declares them in the order (yz_plane, xz_plane, xy_plane); passing
# xy_plane and xz_plane positionally in the opposite order would swap them.
t1 = benchmark.Timer(
    stmt='example2(coords, yz_plane, xz_plane=xz_plane, xy_plane=xy_plane)',
    setup='from __main__ import example2',
    globals={'coords': coords, 'yz_plane': yz_plane, 'xy_plane': xy_plane, 'xz_plane': xz_plane})

# Measure each variant twice (100 runs per measurement); the first
# measurement of a variant may include warm-up effects.
for timer in (t0, t0, t1, t1):
    print(timer.timeit(100))

sorry for being unclear, N=262144

I updated it, but the result is the same.
Can you reproduce the problem with a code sample and log like the ones provided above?

(Updated)
I've also experienced a similar problem in my own complex training code.
However, the problem disappeared at some point during a large round of modifications, without my ever identifying the exact cause.
So I'm interested in your case, in the hope of pinning down the problem.

It seems that F.grid_sample works fine when I test it in isolation; however, it is quite slow when I test it inside my training code, which is strange. Let me see if I can find out why. Thanks anyway.

1 Like