I found that `F.grid_sample` in my code is extremely slow. For example, the following block takes about 0.9 s on GPU with PyTorch 1.6.0:
```python
B, H, W, D, C = coords.size()
coords = coords.view(B, -1, C).unsqueeze(1)  # (B, 1, N, 3), N = 262,144
# yz_plane has shape (1, 32, 256, 256)
sample_yz_feat = F.grid_sample(
    yz_plane.repeat(B, 1, 1, 1), coords[..., [1, 2]].flip(-1), align_corners=True
).permute(0, 2, 3, 1).squeeze(1)  # (B, N, 32)
```
However, if I add two more `F.grid_sample` calls, as in the following block, it STILL takes about 0.9 s on GPU. Why is `F.grid_sample` so slow in my case? Am I missing something, given that the two additional `F.grid_sample` calls barely add any time?
```python
B, H, W, D, C = coords.size()
coords = coords.view(B, -1, C).unsqueeze(1)  # (B, 1, N, 3), N = 262,144
# yz_plane, xz_plane and xy_plane each have shape (1, 32, 256, 256)
sample_yz_feat = F.grid_sample(
    yz_plane.repeat(B, 1, 1, 1), coords[..., [1, 2]].flip(-1), align_corners=True
).permute(0, 2, 3, 1).squeeze(1)  # (B, N, 32)
sample_xz_feat = F.grid_sample(
    xz_plane.repeat(B, 1, 1, 1), coords[..., [0, 2]].flip(-1), align_corners=True
).permute(0, 2, 3, 1).squeeze(1)
sample_xy_feat = F.grid_sample(
    xy_plane.repeat(B, 1, 1, 1), coords[..., [0, 1]].flip(-1), align_corners=True
).permute(0, 2, 3, 1).squeeze(1)
```
Update: this is how I measure the time:
```python
import time

import torch
import torch.nn.functional as F

coords = torch.randn((1, 64, 64, 64, 3)).cuda()
yz_plane = torch.randn((1, 32, 256, 256)).cuda()
xy_plane = torch.randn((1, 32, 256, 256)).cuda()
xz_plane = torch.randn((1, 32, 256, 256)).cuda()
start_time = time.time()
# above code block
print(time.time() - start_time)
```
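One thing worth ruling out with this measurement: CUDA kernels in PyTorch are launched asynchronously, so `time.time()` without `torch.cuda.synchronize()` does not wait for the GPU to finish, and the first CUDA call in a process also pays one-time start-up cost. A plausible (unconfirmed) reading of the numbers above is that the 0.9 s is dominated by that start-up cost rather than by `grid_sample` itself, which would explain why two extra calls add almost nothing. The sketch below times the same tri-plane lookup with warm-up iterations and explicit synchronization. The helper names `sample_planes` and `benchmark` are hypothetical, and it falls back to CPU when no GPU is available.

```python
import time

import torch
import torch.nn.functional as F


def sample_planes(coords, yz_plane, xz_plane, xy_plane):
    """Hypothetical wrapper around the tri-plane lookup from the question.

    coords: (B, H, W, D, 3); each plane: (1, C, 256, 256). Returns three (B, N, C) tensors.
    """
    B = coords.size(0)
    grid = coords.view(B, -1, coords.size(-1)).unsqueeze(1)  # (B, 1, N, 3)

    def sample(plane, idx):
        g = grid[..., idx].flip(-1)  # (B, 1, N, 2), reversed pair as in the original code
        out = F.grid_sample(plane.repeat(B, 1, 1, 1), g, align_corners=True)  # (B, C, 1, N)
        return out.permute(0, 2, 3, 1).squeeze(1)  # (B, N, C)

    return sample(yz_plane, [1, 2]), sample(xz_plane, [0, 2]), sample(xy_plane, [0, 1])


def benchmark(fn, *args, warmup=3, iters=10):
    """Hypothetical helper: average seconds per call, excluding warm-up calls."""
    use_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    for _ in range(warmup):  # absorbs one-time start-up / first-call overhead
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # make sure warm-up kernels have finished
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # kernels run asynchronously; wait before reading the clock
    return (time.perf_counter() - start) / iters, out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
coords = torch.randn((1, 64, 64, 64, 3), device=device)
planes = [torch.randn((1, 32, 256, 256), device=device) for _ in range(3)]
avg, (yz, xz, xy) = benchmark(sample_planes, coords, *planes)
print(f"{device}: {avg * 1000:.2f} ms per call, output shape {tuple(yz.shape)}")
```

Running `benchmark` once with a single-plane variant and once with `sample_planes` gives a per-call comparison of one versus three `grid_sample` calls that is not skewed by start-up cost. The `torch.utils.benchmark` module, available in newer PyTorch releases than 1.6.0, handles warm-up and synchronization automatically.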