I have read the autocast documentation and believe I am using it correctly. I am sharing my model below so that you can check whether any layers are not respecting autocast, because the training speed is essentially the same with and without it.
import torch
class DCUNet(torch.nn.Module):
    """U-shaped encoder/decoder network assembled from DCBlock and ResPath.

    Layout, as wired in ``forward``:

    * four encoder stages, each a ``DCBlock`` followed by a 2x2 max-pool;
      the un-pooled stage output also goes through a ``ResPath`` and is
      kept as the skip connection,
    * one bottleneck ``DCBlock`` (256 -> 512 channels),
    * four decoder stages, each a stride-2 transposed convolution whose
      output is concatenated with the matching skip tensor on the channel
      axis and fed to a ``DCBlock``,
    * a final 1x1 convolution down to a single output channel.

    The first block takes a 1-channel input.  No activation is applied
    after the last convolution.  Because of the four pool / up-sample
    pairs, the skip concatenations only line up when the spatial size is
    divisible by 16.
    """

    def __init__(self):
        super().__init__()
        # DCBlock / ResPath are created with their default settings
        # (3x3 kernels, stride 1, batch norm enabled).

        # Encoder stages with their skip-connection paths.
        self.dcBlock1 = DCBlock(in_channels=1, out_channels=32)
        self.resPath1 = ResPath(in_channels=32, out_channels=32, residuals=4)
        self.maxPool1 = torch.nn.MaxPool2d(kernel_size=(2,2))

        self.dcBlock2 = DCBlock(in_channels=32, out_channels=64)
        self.resPath2 = ResPath(in_channels=64, out_channels=64, residuals=3)
        self.maxPool2 = torch.nn.MaxPool2d(kernel_size=(2,2))

        self.dcBlock3 = DCBlock(in_channels=64, out_channels=128)
        self.resPath3 = ResPath(in_channels=128, out_channels=128, residuals=2)
        self.maxPool3 = torch.nn.MaxPool2d(kernel_size=(2,2))

        self.dcBlock4 = DCBlock(in_channels=128, out_channels=256)
        self.resPath4 = ResPath(in_channels=256, out_channels=256, residuals=1)
        self.maxPool4 = torch.nn.MaxPool2d(kernel_size=(2,2))

        # Bottleneck.
        self.dcBlock5 = DCBlock(in_channels=256, out_channels=512)

        # Decoder: every DCBlock sees the up-sampled tensor concatenated
        # with a skip tensor, hence twice the transposed-conv output width.
        self.cTranBlock1 = torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.dcBlock6 = DCBlock(in_channels=512, out_channels=256)

        self.cTranBlock2 = torch.nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.dcBlock7 = DCBlock(in_channels=256, out_channels=128)

        self.cTranBlock3 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.dcBlock8 = DCBlock(in_channels=128, out_channels=64)

        self.cTranBlock4 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.dcBlock9 = DCBlock(in_channels=64, out_channels=32)

        # Output head.
        self.lastBlock = torch.nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(1,1))

    def forward(self, x):
        """Run the network on ``x`` and return the 1-channel output map."""
        # ---- Encoder: keep a ResPath-refined copy of every stage ----
        enc1 = self.dcBlock1(x)
        skip1 = self.resPath1(enc1)

        enc2 = self.dcBlock2(self.maxPool1(enc1))
        skip2 = self.resPath2(enc2)

        enc3 = self.dcBlock3(self.maxPool2(enc2))
        skip3 = self.resPath3(enc3)

        enc4 = self.dcBlock4(self.maxPool3(enc3))
        skip4 = self.resPath4(enc4)

        # ---- Bottleneck ----
        feats = self.dcBlock5(self.maxPool4(enc4))

        # ---- Decoder: upsample, concatenate with the skip, refine ----
        feats = self.dcBlock6(torch.cat([self.cTranBlock1(feats), skip4], dim=1))
        feats = self.dcBlock7(torch.cat([self.cTranBlock2(feats), skip3], dim=1))
        feats = self.dcBlock8(torch.cat([self.cTranBlock3(feats), skip2], dim=1))
        feats = self.dcBlock9(torch.cat([self.cTranBlock4(feats), skip1], dim=1))

        return self.lastBlock(feats)
class DCBlock(torch.nn.Module):
    """Dual-path block: two parallel stacks of three convolutions each.

    Every path ("lx" and "rx") chains three conv + ReLU stages and then
    concatenates the three intermediate activations along the channel
    axis, so each path yields ``out_channels`` channels.  The two paths
    are (optionally) batch-normalised, summed, passed through a ReLU and
    (optionally) batch-normalised once more.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.  The three stages of
            a path get roughly 1/6, 1/3 and 1/2 of this width.
        kernel_size: Kernel size shared by all six convolutions.
        stride: Stride shared by all six convolutions.  The convolutions
            use ``padding='same'``, which PyTorch only accepts together
            with stride 1.
        apply_batchnorm: When true, apply BatchNorm2d to each path and to
            the final sum.

    NOTE(review): the per-stage widths are generally not multiples of 8
    (e.g. out_channels=32 gives 5 / 11 / 16).  NVIDIA's mixed-precision
    guidance recommends channel counts that are multiples of 8 for FP16
    Tensor Core kernels, so this may be why autocast brings little
    speed-up -- confirm with a profiler before changing the widths.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, apply_batchnorm=True):
        super(DCBlock, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Split the output width over the three stages.  The last stage
        # takes the remainder so the concatenation is always exactly
        # ``out_channels`` wide; independently rounding all three shares
        # can miss the total (e.g. 11 -> 2 + 4 + 6 = 12), which would
        # break the batch norms below.  For every width where the rounded
        # shares already add up, the remainder equals round(out_channels/2).
        total = round(out_channels)
        width1 = round(out_channels / 6)
        width2 = round(out_channels / 3)
        width3 = total - width1 - width2

        self.lx_conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=width1, kernel_size=kernel_size, stride=stride, padding='same')
        self.act1_lx = torch.nn.ReLU()
        self.lx_conv2 = torch.nn.Conv2d(in_channels=width1, out_channels=width2, kernel_size=kernel_size, stride=stride, padding='same')
        self.act2_lx = torch.nn.ReLU()
        self.lx_conv3 = torch.nn.Conv2d(in_channels=width2, out_channels=width3, kernel_size=kernel_size, stride=stride, padding='same')
        self.act3_lx = torch.nn.ReLU()

        self.rx_conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=width1, kernel_size=kernel_size, stride=stride, padding='same')
        self.act1_rx = torch.nn.ReLU()
        self.rx_conv2 = torch.nn.Conv2d(in_channels=width1, out_channels=width2, kernel_size=kernel_size, stride=stride, padding='same')
        self.act2_rx = torch.nn.ReLU()
        self.rx_conv3 = torch.nn.Conv2d(in_channels=width2, out_channels=width3, kernel_size=kernel_size, stride=stride, padding='same')
        self.act3_rx = torch.nn.ReLU()

        if self.apply_batchnorm:
            self.bn1 = torch.nn.BatchNorm2d(num_features=total)
            self.bn2 = torch.nn.BatchNorm2d(num_features=total)
            self.bn3 = torch.nn.BatchNorm2d(num_features=total)
        self.actFinal = torch.nn.ReLU()

    def forward(self, inp):
        """Apply both paths to ``inp`` and return their fused output."""
        # left path
        lx_1 = self.act1_lx(self.lx_conv1(inp))
        lx_2 = self.act2_lx(self.lx_conv2(lx_1))
        # Third stage uses its own activation module (act3_*), which was
        # registered in __init__ for exactly this purpose.
        lx_3 = self.act3_lx(self.lx_conv3(lx_2))

        # right path
        rx_1 = self.act1_rx(self.rx_conv1(inp))
        rx_2 = self.act2_rx(self.rx_conv2(rx_1))
        rx_3 = self.act3_rx(self.rx_conv3(rx_2))

        # concatenate the three stage outputs of each path channel-wise
        conc_lx = torch.cat([lx_1, lx_2, lx_3], dim=1)
        if self.apply_batchnorm:
            conc_lx = self.bn1(conc_lx)
        conc_rx = torch.cat([rx_1, rx_2, rx_3], dim=1)
        if self.apply_batchnorm:
            conc_rx = self.bn2(conc_rx)

        # fuse the two paths
        x = self.actFinal(conc_lx.add(conc_rx))
        if self.apply_batchnorm:
            x = self.bn3(x)
        return x
class ResPath(torch.nn.Module):
    """Chain of residual stages applied to a skip-connection tensor.

    Each stage runs a 3x3 conv + ReLU and a 1x1 conv + ReLU on the same
    input, adds the two results, applies a ReLU and, optionally, a
    BatchNorm2d.  The first stage maps ``in_channels`` to
    ``out_channels``; the remaining stages keep ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every stage.
        residuals: Number of stages executed in ``forward``.
        apply_batchnorm: When true, finish every stage with BatchNorm2d.

    Storage layout: ``conv3_list`` and ``conv1_list`` interleave modules
    as ``[Conv, ReLU, Conv, ReLU, ...]``, so stage ``i`` owns entries
    ``2*i`` (convolution) and ``2*i + 1`` (activation).  ``bn_list`` has
    one entry per stage.  ``forward`` must index accordingly: indexing
    the conv lists with the bare stage number would run a ReLU instead of
    a convolution on odd stages and leave later convolutions unused.
    Parameter names are unaffected by this indexing, but checkpoints
    trained with a forward pass that indexed by stage number will not
    reproduce their previous outputs.
    """

    def __init__(self, in_channels, out_channels, residuals, apply_batchnorm=True):
        super(ResPath, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        self.residuals = residuals
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv3_list = torch.nn.ModuleList()
        self.conv1_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()

        # The first stage adapts the channel count and is always created;
        # the remaining residuals - 1 stages keep the width.
        self._add_stage(in_channels)
        for _ in range(self.residuals - 1):
            self._add_stage(out_channels)
        self.act = torch.nn.ReLU()

    def _add_stage(self, stage_in_channels):
        """Append the conv/ReLU pairs (and batch norm) of one stage."""
        self.conv3_list.append(torch.nn.Conv2d(in_channels=stage_in_channels, out_channels=self.out_channels, kernel_size=3, padding='same'))
        self.conv3_list.append(torch.nn.ReLU())
        self.conv1_list.append(torch.nn.Conv2d(in_channels=stage_in_channels, out_channels=self.out_channels, kernel_size=1, padding='same'))
        self.conv1_list.append(torch.nn.ReLU())
        if self.apply_batchnorm:
            self.bn_list.append(torch.nn.BatchNorm2d(num_features=self.out_channels))

    def forward(self, x):
        """Apply ``residuals`` stages to ``x`` and return the result."""
        for i in range(self.residuals):
            conv_idx = 2 * i      # convolution slot of stage i
            act_idx = conv_idx + 1  # its activation sits right after it
            conv = self.conv3_list[act_idx](self.conv3_list[conv_idx](x))
            res = self.conv1_list[act_idx](self.conv1_list[conv_idx](x))
            x = self.act(conv.add(res))
            if self.apply_batchnorm:
                x = self.bn_list[i](x)
        return x
And this is the relevant part of the training loop:
# One optimisation step of the training loop (fragment of the loop body).
# NOTE(review): indentation was lost in this paste; presumably only the
# forward pass belongs to the `with` block below -- confirm against the
# real script.
# Move the current batch to the training device.
inputs = inputs.to(device)
masks = masks.to(device)
# Gradients are not cleared here; they are cleared at the end of the step
# with optimizer.zero_grad(set_to_none=True).
# Forward pass under CUDA autocast (mixed precision).
with torch.cuda.amp.autocast():
outputs = model(inputs)
# Compute the loss on the outputs cast back to float32, then backpropagate
# the scaled loss.  `scaler` is presumably a torch.cuda.amp.GradScaler --
# it is created outside this fragment.
loss = criterion(outputs.float(), masks)
scaler.scale(loss).backward()
# Step the optimizer through the scaler, then let the scaler update its
# scale factor for the next iteration.
scaler.step(optimizer)
scaler.update()
# Clear gradients for the next batch; set_to_none=True resets them to None
# rather than filling them with zeros.
optimizer.zero_grad(set_to_none=True)
Do you have any suggestions as to which part of the code I should inspect in order to understand the problem? Thank you!
Timing with AMP (automatic mixed precision): 210 seconds
Timing with SP (single precision, FP32): 221 seconds