I created a Conv2d layer that uses unfolding followed by an MVM. I then combine it with a BatchNorm operation in a Sequential model. I do the same but this time with a normal Conv2d layer. I then profile both and compare the outputs.
I see that the batch norm call aten::batch_norm
takes 3.5x longer with the unfolded convolution. I put everything on Cuda. Here is a small snippet to reproduce:
Why am I seeing this slow-down in the batch norm? Is there some fusing going on internally?
import torch
class ConvUnfold(torch.nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
bias,
device,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
bias=bias,
device=device,
)
self.linear_weight = self.weight.reshape(shape=(out_channels, in_channels*kernel_size**2))
self.linear_weight = self.linear_weight.to(device)
def _mvm(self, input):
return input @ self.linear_weight.T
def _forward_unfold(self, x_input):
im_shape = x_input.shape
x_input_ = torch.nn.functional.unfold(x_input, kernel_size=self.kernel_size, dilation=self.dilation,
padding=self.padding, stride=self.stride).transpose(1, 2)
out = self._mvm(x_input_).transpose(1, 2)
out_size = (im_shape[2] + 2 * self.padding[0]
- self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
return out.view(im_shape[0], self.out_channels, out_size, -1)
def forward(self, input):
out = self._forward_unfold(input)
return out
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input = torch.randn(size=(128,3,32,32)).to(device)
unf = torch.nn.Sequential(ConvUnfold(3, 16, 3, bias=False, device=device), torch.nn.BatchNorm2d(16))
conv = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3, bias=False), torch.nn.BatchNorm2d(16))
unf.eval(); conv.eval()
unf.to(device); conv.to(device)
conv.load_state_dict(unf.state_dict())
from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof_unf:
with record_function("model_inference"):
unf(input)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof_conv:
with record_function("model_inference"):
conv(input)
print(prof_unf.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(prof_conv.key_averages().table(sort_by="cuda_time_total", row_limit=10))