Conv layers take much longer than linear layers

Conv layers take much longer, tens of times longer, than linear layers (on the GPU). What is the reason for this? Are there ways to speed up convolutional networks? Maybe it is possible to perform the convolution in one go rather than step by step? (I don't mean batching.)

The operations in a conv layer might include additional work, e.g. im2col transformations or FFT-based algorithms, on top of the raw multiply-accumulates.
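Just as an illustration (the shapes here are made up), a 3x3 convolution can be rewritten as an im2col transform (F.unfold) followed by a plain matrix multiply, which is roughly what an im2col-based backend does under the hood:

import torch
import torch.nn.functional as F

x = torch.randn(1, 32, 80, 16)        # (N, C_in, H, W)
weight = torch.randn(64, 32, 3, 3)    # (C_out, C_in, kH, kW)

cols = F.unfold(x, kernel_size=3, padding=1)   # im2col: (N, C_in*3*3, H*W)
out = weight.view(64, -1) @ cols               # matmul: (N, C_out, H*W)
out = out.view(1, 64, 80, 16)                  # back to (N, C_out, H, W)

ref = F.conv2d(x, weight, padding=1)
print(torch.allclose(out, ref, atol=1e-3))     # True, up to float rounding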
How, and which layers, did you profile on the GPU?
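Note that CUDA kernels are launched asynchronously, so timing without synchronization mostly measures the kernel launch overhead. A minimal sketch of how one could time a single layer correctly (the batch size and the 1x160x32 input shape are assumptions, inferred from the comments in your model):

import time
import torch
import torch.nn as nn

x = torch.randn(64, 1, 160, 32, device="cuda")       # assumed input shape
conv = nn.Conv2d(1, 32, kernel_size=3, padding=1).to("cuda")

torch.cuda.synchronize()                              # wait for pending work before starting the timer
start = time.perf_counter()
for _ in range(100):
    y = conv(x)
torch.cuda.synchronize()                              # wait for the kernels to finish before stopping it
print("avg time per forward call:", (time.perf_counter() - start) / 100)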

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv11 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv11.weight)
        self.relu11c = nn.PReLU()
        self.bn11c = nn.BatchNorm2d(32, affine=True)
        
        self.conv12 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv12.weight)
        self.relu12c = nn.PReLU()
        self.bn12c = nn.BatchNorm2d(32, affine=True)
        
        self.conv13 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 
        nn.init.xavier_uniform_(self.conv13.weight)
        self.relu13c = nn.PReLU()
        self.bn13c = nn.BatchNorm2d(32, affine=True)
        self.mpool1 = nn.MaxPool2d(kernel_size=2, stride=2) #80x16
        
        self.conv21 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) #80x16
        nn.init.xavier_uniform_(self.conv21.weight)
        self.relu21c = nn.PReLU()
        self.bn21c = nn.BatchNorm2d(64, affine=True)
        
        self.conv22 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv22.weight)
        self.relu22c = nn.PReLU()
        self.bn22c = nn.BatchNorm2d(64, affine=True)
        
        self.conv23 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 
        nn.init.xavier_uniform_(self.conv23.weight)
        self.relu23c = nn.PReLU()
        self.bn23c = nn.BatchNorm2d(64, affine=True)
        self.mpool2 = nn.MaxPool2d(kernel_size=2, stride=2) #40x8
        
        self.conv31 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv31.weight)
        self.relu31c = nn.PReLU()
        self.bn31c = nn.BatchNorm2d(128, affine=True)
        
        self.conv32 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv32.weight)
        self.relu32c = nn.PReLU()
        self.bn32c = nn.BatchNorm2d(128, affine=True)
        
        self.conv33 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(self.conv33.weight)
        self.relu33c = nn.PReLU()
        self.bn33c = nn.BatchNorm2d(128, affine=True)
        self.mpool3 = nn.MaxPool2d(kernel_size=2, stride=2) #20x4
        
        self.fc7 = nn.Linear(10240, 700)
        nn.init.xavier_uniform_(self.fc7.weight)
        self.relu7 = nn.PReLU()
        self.bn7 = nn.BatchNorm2d(1, affine=True)
        
        self.fc8 = nn.Linear(700, 1)
        nn.init.xavier_uniform_(self.fc8.weight)
        self.tan = nn.Hardtanh()
        
    def forward(self, x):
        out = self.conv11(x)
        out = self.relu11c(out)
        out = self.bn11c(out)
        res = out
        
        out = self.conv12(out)
        #out = self.relu12c(out)
        out = self.bn12c(out)
        out += res
        
        out = self.conv13(out)
        out = self.relu13c(out)
        out = self.bn13c(out)
        out = self.mpool1(out)
        
        out = self.conv21(out)
        out = self.relu21c(out)
        out = self.bn21c(out)
        res = out
        
        out = self.conv22(out)
        #out = self.relu22c(out)
        out = self.bn22c(out)
        out += res
        
        out = self.conv23(out)
        out = self.relu23c(out)
        out = self.bn23c(out)
        out = self.mpool2(out)
        
        
        out = self.conv31(out)
        out = self.relu31c(out)
        out = self.bn31c(out)
        res = out
        
        out = self.conv32(out)
        #out = self.relu32c(out)
        out = self.bn32c(out)
        out += res
        
        out = self.conv33(out)
        out = self.relu33c(out)
        out = self.bn33c(out)
        out = self.mpool3(out)
        
        
        out = out.contiguous().view(out.size(0), 1, 1, 10240)  # batch size taken from the tensor, not from a global
        
        out = self.fc7(out)
        out = self.relu7(out)
        out = self.bn7(out)
        
        out = self.fc8(out)
        
        out = out.contiguous().view(out.size(0), 1)
        #print('out ',out.size())
        
        return out

device = torch.device("cuda:0")
net = Net()
net.to(device)
...
w = torch.from_numpy(w).unsqueeze(1).to(device)
outputs = net(w)
...

Thanks for the code.
How did you profile the code and which assumptions did you make?
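If you haven't done so yet, torch.autograd.profiler gives a per-operator breakdown, which makes it easier to see where the time is actually spent. A minimal sketch, assuming net and w are already on the GPU as in your snippet:

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    outputs = net(w)
print(prof.key_averages().table(sort_by="cuda_time_total"))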

You could try to use the JIT to fuse some operations.
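A minimal sketch of what that could look like (the example input shape is an assumption based on the comments in your model; scripting a model built from standard nn modules like this one usually works out of the box):

device = torch.device("cuda:0")
net = Net().to(device)
scripted = torch.jit.script(net)              # or torch.jit.trace(net, example)

example = torch.randn(64, 1, 160, 32, device=device)
with torch.no_grad():
    out = scripted(example)                   # the first calls run optimization passes, so warm up before timing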