I have two models, one with 20K parameters and one with 1900K. They take the same inputs and produce the same outputs, and both follow a similar pattern of downscaling four times followed by upscaling twice (with conv2d's in between), so I would expect the 20K model to run much faster than the 1900K one. However, this is not the case: the 20K model is 3-8x slower to evaluate/train than the 1900K model (single GPU, single thread, batch size 1). How can this be? Am I doing something wrong?
Using the Tensorwatch library I calculated the following:
Model 1 (CropUNet): 1900K params, 19.4B MAdd, 9.4B FLOPs
Model 2 (EncDecMil18): 20K params, 3.95B MAdd, 2.0B FLOPs
Model 2 has far more convolutions than Model 1, but they are significantly smaller and do much less total work, as the lower MAdd/FLOP counts indicate.
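As a rough cross-check on those numbers, the multiply-adds of a single convolution can be estimated by hand as chin * chout * kh * kw * hout * wout. For instance, a back-of-the-envelope estimate (not Tensorwatch's exact accounting) for one 3x3 convolution from CropUNet's input block, applied at the full 480x640 input resolution:

# Rough MAdd estimate for a single 3x3 conv with 16 input and 16 output
# channels at 480x640 resolution (CropUNet's input block, D = 16)
madds = 16 * 16 * 3 * 3 * 480 * 640
print(f"{madds:,}")  # 707,788,800 -- about 0.7B MAdds for one layer

A handful of such layers at high resolution already accounts for billions of MAdds, which matches the scale of the Tensorwatch figures.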
The timings produced by the code at the end of this question are:
Model 1
Get data 0.002ms
Data to device 10.888ms
Zero grad 0.254ms
Model eval 2.326ms
Criterion 0.082ms
Backward 2.280ms
Optim step 3.855ms
Model 2
Get data 0.003ms
Data to device 0.462ms
Zero grad 1.074ms
Model eval 9.370ms
Criterion 0.128ms
Backward 9.215ms
Optim step 48.186ms
I am running PyTorch 1.4.0a0+7f73f1d with cuDNN 7.6.5 and CUDA 10.1.
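(A caveat on how these were measured: CUDA kernels execute asynchronously, so each printed duration can also include GPU work queued by earlier statements. A synchronized variant of a timing line, assuming the extra synchronization overhead is acceptable, would look like this:

torch.cuda.synchronize()  # wait for all queued GPU work before reading the clock
print(f"Model eval {1000*(default_timer() - t):.3f}ms"); t = default_timer()

The code below is the unsynchronized version that produced the timings above.)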
Code:
#!/usr/bin/env python3
# Sandbox
# Imports
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from timeit import default_timer
#
# UNet
#
class CropUNet(nn.Module):
def __init__(self):
super().__init__()
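        # Base channel width; the encoder doubles it at each downscaling stage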
D = 16
self.input_block = nn.Sequential(
self.conv_block(chin=3, chout=D),
self.conv_block(chin=D, chout=D),
)
self.down1 = self.down_block(chin=D, chout=2*D)
self.down2 = self.down_block(chin=2*D, chout=4*D)
self.down3 = self.down_block(chin=4*D, chout=8*D)
self.down4 = self.down_block(chin=8*D, chout=16*D)
self.up4 = self.trans_conv(chin=16*D, chout=8*D)
self.up3 = self.up_block(chin=16*D, chout=4*D)
self.output_block = nn.Sequential(
self.conv_block(chin=8*D, chout=4*D),
self.conv_block(chin=4*D, chout=4*D),
nn.Conv2d(in_channels=4*D, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True),
)
def forward(self, batch_in):
tmpfeat = self.input_block(batch_in)
feat1 = self.down1(tmpfeat)
feat2 = self.down2(feat1)
feat3 = self.down3(feat2)
feat4 = self.down4(feat3)
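        # Decoder: concatenate each upsampled feature map with the matching
        # encoder feature map along the channel dimension (U-Net skip connections)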
tmpfeat = torch.cat((feat3, self.up4(feat4)), dim=1)
tmpfeat = torch.cat((feat2, self.up3(tmpfeat)), dim=1)
tmpfeat = self.output_block(tmpfeat)
return tmpfeat
def conv_block(self, chin, chout):
return nn.Sequential(
nn.Conv2d(in_channels=chin, out_channels=chout, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(chout),
nn.ReLU(inplace=True),
)
def trans_conv(self, chin, chout):
return nn.ConvTranspose2d(in_channels=chin, out_channels=chout, kernel_size=2, stride=2, bias=True)
def down_block(self, chin, chout):
return nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
self.conv_block(chin=chin, chout=chout),
self.conv_block(chin=chout, chout=chout),
)
def up_block(self, chin, chout):
dchout = 2*chout
return nn.Sequential(
self.conv_block(chin=chin, chout=dchout),
self.conv_block(chin=dchout, chout=dchout),
self.trans_conv(chin=dchout, chout=chout),
)
#
# EncDecMil18
#
class ResConvBlock(nn.Module):
def __init__(self, ch):
super().__init__()
chneck = ch // 2
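        # Bottleneck residual path: a 1x1 channel reduction, a 5x5 convolution
        # factored into 5x1 and 1x5 passes, then a 1x1 expansion back to ch channels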
self.forward_path = nn.Sequential(
nn.Conv2d(in_channels=ch, out_channels=chneck, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(chneck),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=chneck, out_channels=chneck, kernel_size=(5, 1), stride=1, padding=(2, 0), bias=False),
nn.Conv2d(in_channels=chneck, out_channels=chneck, kernel_size=(1, 5), stride=1, padding=(0, 2), bias=False),
nn.BatchNorm2d(chneck),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=chneck, out_channels=ch, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(ch),
)
self.output = nn.ReLU(inplace=True)
def forward(self, batch_in):
return self.output(batch_in + self.forward_path(batch_in))
class DownBlock(nn.Module):
def __init__(self, ch):
super().__init__()
self.conv_chain = nn.Sequential(
ResConvBlock(ch=ch),
ResConvBlock(ch=ch),
ResConvBlock(ch=ch),
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
def forward(self, batch_in):
return self.pool(self.conv_chain(batch_in))
class UpBlock(nn.Module):
def __init__(self, ch):
super().__init__()
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.conv_chain = nn.Sequential(
ResConvBlock(ch=ch),
ResConvBlock(ch=ch),
ResConvBlock(ch=ch),
)
def forward(self, batch_in, indices):
return self.conv_chain(self.unpool(batch_in, indices))
class EncDecMil18(nn.Module):
def __init__(self):
super().__init__()
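        # Channel width stays constant through the whole network, keeping the parameter count small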
D = 16
self.input_block = self.conv_block(chin=3, chout=D, kernel_size=7)
self.down1 = DownBlock(ch=D)
self.down2 = DownBlock(ch=D)
self.down3 = DownBlock(ch=D)
self.down4 = DownBlock(ch=D)
self.up4 = UpBlock(ch=D)
self.up3 = UpBlock(ch=D)
self.output_block = nn.Conv2d(in_channels=D, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, batch_in):
tmpfeat = self.input_block(batch_in)
tmpfeat, indices1 = self.down1(tmpfeat)
tmpfeat, indices2 = self.down2(tmpfeat)
tmpfeat, indices3 = self.down3(tmpfeat)
tmpfeat, indices4 = self.down4(tmpfeat)
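        # Decode, reusing the max-pooling indices saved by the matching DownBlocks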
tmpfeat = self.up4(tmpfeat, indices4)
tmpfeat = self.up3(tmpfeat, indices3)
tmpfeat = self.output_block(tmpfeat)
return tmpfeat
def conv_block(self, chin, chout, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels=chin, out_channels=chout, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1)//2, bias=False),
nn.BatchNorm2d(chout),
nn.ReLU(inplace=True),
)
#
# Criterion/loss
#
class CropUNetLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, output, target):
        # The total loss is the average over the per-sample losses in the batch.
        # Each sample's loss is its MSE divided by the MSE it would have if
        # every pixel had an error of exactly 0.02 (i.e. MSE / 0.02^2).
if output.ndim > 3:
output = output.squeeze(dim=1)
return F.mse_loss(output, target, reduction='mean') / (0.02 * 0.02)
#
# Train
#
def train(model):
print(f"TRAINING {model.__class__.__name__}")
print()
print(model)
print()
print(f"Model parameters:")
num_total = sum(p.numel() for p in model.parameters())
num_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_untrain = num_total - num_train
print(f" Num untrainable = {num_untrain:,}")
print(f" Num trainable = {num_train:,}")
print(f" Num total = {num_total:,}")
print()
device = torch.device('cuda')
model.to(device)
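    # 20 synthetic batches: 3x480x640 inputs and 120x160 targets, batch size 1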
data_loader = [(torch.rand(1, 3, 480, 640) - 0.5, torch.rand(1, 120, 160) - 0.5) for i in range(20)]
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, betas=(0.9, 0.999), amsgrad=False)
criterion = CropUNetLoss()
t = default_timer()
for data, target in data_loader:
print()
print(f"Get data {1000*(default_timer() - t):.3f}ms"); t = default_timer()
data = data.to(device)
target = target.to(device)
print(f"Data to device {1000*(default_timer() - t):.3f}ms"); t = default_timer()
optimizer.zero_grad()
print(f"Zero grad {1000*(default_timer() - t):.3f}ms"); t = default_timer()
output = model(data)
print(f"Model eval {1000*(default_timer() - t):.3f}ms"); t = default_timer()
loss = criterion(output, target)
print(f"Criterion {1000*(default_timer() - t):.3f}ms"); t = default_timer()
loss.backward()
print(f"Backward {1000*(default_timer() - t):.3f}ms"); t = default_timer()
optimizer.step()
print(f"Optim step {1000*(default_timer() - t):.3f}ms"); t = default_timer()
print()
#
# Main
#
# Main function
def main(argv):
train(CropUNet())
train(EncDecMil18())
# Run main function
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
# EOF