Hi, I noticed that nn.Conv2d’s backward pass is around 2~3 times faster than nn.Conv3d’s. Why is that?
The conditions of my experiment are as follows:
- nn.Conv2d with a (3, 3) kernel processes a tensor of shape [Batch * Time, in_channel, Height, Width].
- nn.Conv3d with a (1, 3, 3) kernel processes a tensor of shape [Batch, in_channel, Time, Height, Width].
So the number of elements in the conv2d and conv3d inputs is equal.
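To make the relation between the two layouts explicit, here is a minimal sketch (same sizes as in the script below):

import torch
x_2d = torch.randn(8 * 500, 2, 64, 64)                # [B * T, C, H, W] for conv2d
x_3d = x_2d.reshape(8, 500, 2, 64, 64).movedim(1, 2)  # [B, C, T, H, W] for conv3d
print(x_2d.numel() == x_3d.numel())                   # True: identical element count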
The forward passes take nearly the same time, but the backward passes differ noticeably:
backward time conv3d: 0.10284185409545898
backward time without reshape conv2d: 0.05658698081970215
My full script is below:
import time
import torch
import os
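# CUDA_LAUNCH_BLOCKING forces every kernel launch to run synchronously, so the
# time.time() deltas below measure actual GPU time rather than just launch overhead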
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from torch.nn import Conv2d, Conv3d
# suppose the input tensor has shape [B * Time, C, H, W]
inp_ts_2d = torch.ones(4000, 2, 64, 64, dtype=torch.float, device="cuda:0")
# suppose the input tensor has shape [B, C, Time, H, W]
inp_ts_3d = torch.ones(8, 2, 500, 64, 64, dtype=torch.float, device="cuda:0")
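# note: 8 * 500 = 4000, so both inputs contain the same number of elements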
# conv2d layer init
conv2d = Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3), padding=(1, 1), bias=False)
conv2d.to("cuda:0")
# conv3d layer init; the first kernel dim must be 1 so that no convolution happens along the temporal axis
conv3d = Conv3d(in_channels=2, out_channels=2, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
conv3d.to("cuda:0")
# make sure conv2d and conv3d have the same weights: unsqueeze(2) turns the (out, in, 3, 3) kernel into (out, in, 1, 3, 3)
conv3d.weight.data = conv2d.weight.data.unsqueeze(2)
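# sanity check: the 3d weight should now be [out, in, 1, kH, kW]
assert conv3d.weight.shape == (2, 2, 1, 3, 3)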
# conv2d forward warm-up
start_t_2d = time.time()
out_2d = conv2d(inp_ts_2d)
print(f"warm-up execution time conv2d: {time.time() - start_t_2d}")
# conv3d forward warm-up
start_t_3d = time.time()
out_3d = conv3d(inp_ts_3d)
print(f"warm-up execution time conv3d: {time.time() - start_t_3d}")
conv3d.zero_grad()
conv2d.zero_grad()
# conv3d forward
start_t_3d = time.time()
out_3d = conv3d(inp_ts_3d)
print(f"forward time conv3d: {time.time() - start_t_3d}")
# conv2d forward
start_t_2d = time.time()
out_2d = conv2d(inp_ts_2d)
print(f"forward time conv2d: {time.time() - start_t_2d}")
print("Switch execution order")
# conv2d forward
start_t_2d = time.time()
out_2d = conv2d(inp_ts_2d)
print(f"forward time conv2d: {time.time() - start_t_2d}")
# conv3d forward
start_t_3d = time.time()
out_3d = conv3d(inp_ts_3d)
print(f"forward time conv3d: {time.time() - start_t_3d}")
# check whether the outputs from conv2d and conv3d are equal
out_2d_reshape = out_2d.reshape(8, 500, 2, 64, 64).movedim(1, 2)
equal = out_2d_reshape == out_3d
# torch.unique should return only True
print(torch.unique(equal))
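# extra check: exact float equality can be brittle, so also compare with a tolerance
print(torch.allclose(out_2d_reshape, out_3d))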
# 3d backward warm-up
start_t_3d_bp = time.time()
out_3d.sum().backward()
print(f"warm-up backward time conv3d: {time.time() - start_t_3d_bp}")
# 2d backward warm-up
start_t_2d_bp = time.time()
out_2d.sum().backward()
print(f"warm-up backward time conv2d: {time.time() - start_t_2d_bp}")
# zero grad and apply forward again
conv2d.zero_grad()
out_2d = conv2d(inp_ts_2d)
conv3d.zero_grad()
out_3d = conv3d(inp_ts_3d)
# 3d backward
start_t_3d_bp = time.time()
out_3d.sum().backward()
print(f"backward time conv3d: {time.time() - start_t_3d_bp}")
# 2d backward without reshape
start_t_2d_bp = time.time()
out_2d.sum().backward()
print(f"backward time without reshape conv2d: {time.time() - start_t_2d_bp}")
# zero grad and apply forward again
conv2d.zero_grad()
out_2d = conv2d(inp_ts_2d).reshape(8, 500, 2, 64, 64)
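# .reshape on this contiguous output is only a view, so it should add little backward cost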
# 2d backward with reshaped output tensor, to see whether the reshape itself adds backward cost
start_t_2d_bp = time.time()
out_2d.sum().backward()
print(f"backward time with reshape conv2d: {time.time() - start_t_2d_bp}")
print("Switch backward execution order")
# zero grad and apply forward again
conv2d.zero_grad()
out_2d = conv2d(inp_ts_2d)
conv3d.zero_grad()
out_3d = conv3d(inp_ts_3d)
# 2d backward without reshape
start_t_2d_bp = time.time()
out_2d.sum().backward()
print(f"backward time without reshape conv2d: {time.time() - start_t_2d_bp}")
# 3d backward
start_t_3d_bp = time.time()
out_3d.sum().backward()
print(f"backward time conv3d: {time.time() - start_t_3d_bp}")