On Parallelizing Acyclical Linear Blocks w/ Grouped Conv1d

Hi gang, been a while :). So I have this simple model:

import torch
import torch.nn as nn

class ModelSerial(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        blocks = [
            nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(inplace=True),

                nn.Linear(hidden, hidden),
                nn.ReLU(inplace=True),

                nn.Linear(hidden, 1),
            )
            for _ in range(out_features)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        out = torch.cat([b(x) for b in self.blocks], dim=-1)
        return out

Notice how the sequential blocks are not chained in the forward pass. Rather, each one is executed on the input separately, and the results are simply concatenated. As one would imagine, it runs extremely slowly, especially as the number of output features or the depth of the network increases. This is due to the Python for-loop. I had considered trying TorchScript JIT, but that wasn’t an option for this particular engagement. So the next best thing? Try to parallelize it with grouped convolutions. Here’s my attempt:

class ModelParallel(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        
        self.out_features = out_features
        
        self.block = nn.Sequential(
            nn.Conv1d(
                in_features * self.out_features,
                hidden * self.out_features,
                kernel_size=1,
                groups=self.out_features
            ),
            nn.ReLU(inplace=True),

            nn.Conv1d(
                hidden * self.out_features,
                hidden * self.out_features,
                kernel_size=1,
                groups=self.out_features
            ),
            nn.ReLU(inplace=True),

            nn.Conv1d(
                hidden * self.out_features,
                out_features,
                kernel_size=1,
                groups=self.out_features
            )
        )

    def forward(self, x):
        # torch.Size([B, 150])
        bs = x.shape[0]
        
        # torch.Size([B, 150, 1])
        x = x[:, :, None]

        # torch.Size([B, 150, 50])
        x = x.expand(bs, x.shape[1], self.out_features)

        # torch.Size([B, 7500, 1]) because conv1d works on [B C L]
        x = x.reshape(bs, -1, 1)

        # torch.Size([B, 50, 1])
        x = self.block(x)

        # torch.Size([B, 50])
        x = x.squeeze(dim=-1)
        
        return x

The model compiles. Input and output sizes are what I expect. But it doesn’t train at all. Not even close. Validation and training loss are completely out of sync; it’s a mess.

I believe my lack of understanding of how groups are computed is what is messing things up. So I even tried rearranging the setup as follows, in case the groups are ‘interlaced’ as opposed to ‘chunked’:

########
# Change:

# torch.Size([B, 150, 1])
x = x[:, :, None]

# torch.Size([B, 150, 50])
x = x.expand(bs, x.shape[1], self.out_features)

########
# Into:

# torch.Size([B, 1, 150])
x = x[:, None, :]

# torch.Size([B, 50, 150])
x = x.expand(bs, self.out_features, x.shape[2])

But unfortunately, that still didn’t train. Epoch time was a hell of a lot faster though, as desired. My example tests:

ms = ModelSerial(
    in_features=150,
    hidden=10,
    out_features=50
)

mp = ModelParallel(
    in_features=150,
    hidden=10,
    out_features=50
)

I wouldn’t expect the results to be exact, I mean, linear and conv are using different weight initialization schemes even if deterministic is set. But I would at least expect the parallel model to train. PyTorch senpais, can you please provide me with some guidance on how I can architect this problem to execute in parallel? Thank you!!
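On the ‘interlaced’ vs ‘chunked’ question: per the nn.Conv1d docs, a grouped convolution splits the channels into contiguous chunks, so group g sees channels [g * C_in / groups, (g + 1) * C_in / groups). Here is a minimal sketch, using tiny made-up sizes rather than the 150/50 from the post, that checks which of the two expand/reshape variants puts a full copy of the input into each contiguous chunk:

```python
import torch

# Tiny illustrative sizes (assumptions, not the sizes from the post).
B, in_features, out_features = 2, 3, 4
x = torch.arange(B * in_features, dtype=torch.float32).reshape(B, in_features)

# Variant 1 from the post: [B, in, G] then flatten (feature index varies slowest).
v1 = x[:, :, None].expand(B, in_features, out_features).reshape(B, -1)
# Variant 2 from the post: [B, G, in] then flatten (group index varies slowest).
v2 = x[:, None, :].expand(B, out_features, in_features).reshape(B, -1)

def group_slices(flat):
    # Row g of the result is the contiguous channel chunk that group g would receive.
    return flat.reshape(B, out_features, in_features)

# Does every group receive exactly the original input vector?
chunked_ok_v1 = all(torch.equal(group_slices(v1)[:, g], x) for g in range(out_features))
chunked_ok_v2 = all(torch.equal(group_slices(v2)[:, g], x) for g in range(out_features))
print(chunked_ok_v1, chunked_ok_v2)  # prints: False True
```

Under that layout, the first variant mixes repeated features across groups, while the second gives each group a complete copy of the input.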

Your input-to-hidden layer should use 1 group (but bigger output width), otherwise you split inputs.
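A minimal sketch of that suggestion, assuming the same three-layer layout as ModelSerial from the question. The names ModelParallelFixed and copy_serial_weights are made up for illustration. The first conv uses the default groups=1 on the unexpanded input, and the weight-copy helper is only there to check that the result matches the serial loop:

```python
import torch
import torch.nn as nn

class ModelParallelFixed(nn.Module):  # hypothetical name
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        g = out_features
        self.block = nn.Sequential(
            # groups=1: every block reads the full input, no expand needed
            nn.Conv1d(in_features, hidden * g, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden * g, hidden * g, kernel_size=1, groups=g),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden * g, g, kernel_size=1, groups=g),
        )

    def forward(self, x):
        # [B, in] -> [B, in, 1] -> [B, out_features, 1] -> [B, out_features]
        return self.block(x[:, :, None]).squeeze(-1)

def copy_serial_weights(serial_blocks, model):
    # Hypothetical helper: block g owns output channels [g*n, (g+1)*n) of each conv.
    convs = [model.block[0], model.block[2], model.block[4]]
    with torch.no_grad():
        for g, blk in enumerate(serial_blocks):
            for conv, lin in zip(convs, (blk[0], blk[2], blk[4])):
                n = lin.out_features
                conv.weight[g * n:(g + 1) * n, :, 0] = lin.weight
                conv.bias[g * n:(g + 1) * n] = lin.bias

torch.manual_seed(0)
in_features, hidden, out_features = 150, 10, 50
serial_blocks = nn.ModuleList([
    nn.Sequential(
        nn.Linear(in_features, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, 1),
    )
    for _ in range(out_features)
])
model = ModelParallelFixed(in_features, hidden, out_features)
copy_serial_weights(serial_blocks, model)

x = torch.randn(8, in_features)
expected = torch.cat([b(x) for b in serial_blocks], dim=-1)
out = model(x)
max_diff = (out - expected).abs().max().item()
print(out.shape, max_diff)
```

With the weights copied over, the two forward passes should agree up to floating-point noise, so any remaining training difference would come from initialization rather than the architecture.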

As a side note, I just switched from this approach to batched matrix multiplications, as there are some performance issues with grouped convolutions in cuDNN.

I also found that creating block-diagonal weight matrices can be better than bmm, as the latter creates an expanded copy of the weight tensor. I believe the best approach is shape (or sparsity) dependent.
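To make the two alternatives concrete, here is a small sketch with made-up sizes showing that a batched matmul over per-group weights and a single block-diagonal matmul compute the same thing. It only illustrates the equivalence and makes no claim about which is faster:

```python
import torch

torch.manual_seed(0)
G, d_in, d_out, B = 4, 3, 2, 5       # illustrative sizes (assumptions)
w = torch.randn(G, d_in, d_out)      # one weight matrix per group
x = torch.randn(B, G, d_in)          # per-group inputs

# Batched matmul: use the group axis as the bmm batch axis.
y_bmm = torch.bmm(x.transpose(0, 1), w).transpose(0, 1)      # [B, G, d_out]

# Block-diagonal: one dense [G*d_in, G*d_out] matrix, zeros off the diagonal blocks.
w_bd = torch.block_diag(*w)
y_bd = (x.reshape(B, G * d_in) @ w_bd).reshape(B, G, d_out)  # [B, G, d_out]

print(w_bd.shape, torch.allclose(y_bmm, y_bd, atol=1e-5))
```

The dense block-diagonal matrix holds G times as many entries as the stacked per-group weights, most of them zero, which is one reason the better choice depends on the shapes involved.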

Thank you for the hints. I tried to do the BMM implementation because the inputs are dense continuous values. I think I got close, but…

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelParallel2(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        
        self.out_features = out_features
        
        
        self.block1_bias = nn.Parameter(torch.ones(hidden))
        self.block1 = nn.Parameter(torch.ones(
            1, in_features, hidden
        ))
        
        self.block2_bias = nn.Parameter(torch.ones(hidden))
        self.block2 = nn.Parameter(torch.ones(
            1, hidden, hidden
        ))
        
        self.block3_bias = nn.Parameter(torch.ones(1))
        self.block3 = nn.Parameter(torch.ones(
            1, hidden, 1
        ))
        
        
        # nn.init.xavier_uniform_(self.block1, gain=nn.init.calculate_gain('relu'))
        # nn.init.xavier_uniform_(self.block2, gain=nn.init.calculate_gain('relu'))
        # nn.init.xavier_uniform_(self.block3, gain=nn.init.calculate_gain('relu'))

        nn.init.kaiming_uniform_(self.block1, a=np.sqrt(6))
        nn.init.kaiming_uniform_(self.block2, a=np.sqrt(6))
        nn.init.kaiming_uniform_(self.block3, a=np.sqrt(6))
        
        bound_a = 1 / np.sqrt(in_features)
        bound_b = 1 / np.sqrt(hidden)
        nn.init.uniform_(self.block1_bias, -bound_a, bound_a)
        nn.init.uniform_(self.block2_bias, -bound_b, bound_b)
        nn.init.uniform_(self.block3_bias, -bound_b, bound_b)
        
    def forward(self, x):
        # torch.Size([32, 130])
        bs = x.shape[0]
        
        # torch.Size([32, 1, 130])
        x = x[:, None, :]

        # torch.Size([32, 45, 130])
        x = x.expand(bs, self.out_features, x.shape[2])

        # torch.Size([32, 45, 10])
        x = x @ self.block1
        x = F.relu(x, inplace=True)
        
        # torch.Size([32, 45, 10])
        x = x @ self.block2
        x = F.relu(x, inplace=True)
        
        # torch.Size([32, 45, 1])
        x = x @ self.block3

        # torch.Size([32, 45])
        x = x.squeeze(dim=-1)
        
        return x

Inputting a random tensor and initializing the model with:

mp = ModelParallel2(
    in_features=130,
    hidden=10,
    out_features=45
)

I get a torch.Size([32, 45]) sized output, which is correct and desirable. However, the output features are all duplicated:

tensor([[-0.0014, -0.0014, -0.0014,  ..., -0.0014, -0.0014, -0.0014],
        [ 0.0002,  0.0002,  0.0002,  ...,  0.0002,  0.0002,  0.0002],
        [ 0.0001,  0.0001,  0.0001,  ...,  0.0001,  0.0001,  0.0001],
        ...,
        [-0.0023, -0.0023, -0.0023,  ..., -0.0023, -0.0023, -0.0023],
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        [-0.0006, -0.0006, -0.0006,  ..., -0.0006, -0.0006, -0.0006]],
       grad_fn=<SqueezeBackward1>)

From the docs I found out that torch.bmm does not broadcast: “This function does not broadcast. For broadcasting matrix products, see torch.matmul().” But even replacing x = x @ self.block1 with x = torch.matmul(x, self.block1) doesn’t change the behavior; the outputs are still duplicated.
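A likely cause of the duplicates, sketched with made-up sizes: a weight with a leading dimension of 1 is a single matrix applied to every group row, and since every group row is the same expanded input, every group produces the same output. Giving the weight a leading dimension of G (one matrix per group) breaks the tie:

```python
import torch

torch.manual_seed(0)
B, G, d_in, d_out = 4, 3, 5, 2           # illustrative sizes (assumptions)
x = torch.randn(B, d_in)[:, None, :].expand(B, G, d_in)

shared = torch.randn(1, d_in, d_out)     # leading dim 1, as in ModelParallel2
per_group = torch.randn(G, d_in, d_out)  # one matrix per group

# One matrix for all groups: identical rows in, identical rows out.
y_shared = x @ shared                                   # [B, G, d_out]
# [B, G, 1, d_in] @ [G, d_in, d_out] -> [B, G, 1, d_out]
y_group = (x[:, :, None, :] @ per_group).squeeze(-2)    # [B, G, d_out]

print(torch.allclose(y_shared[:, 0], y_shared[:, 1]))   # groups identical
print(torch.allclose(y_group[:, 0], y_group[:, 1]))     # groups differ
```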

It is a bit tricky and requires some shape manipulation:

x shape should be like [batch_dims, ngroups, 1, group_size_in]
weight shape at init: [ngroups, group_size_out, group_size_in]
weight shape in forward(): [ngroups, group_size_in, group_size_out] (i.e. the last two dims transposed)

Storing the weight this way allows you to do the init like:
nn.init.kaiming_uniform_(weight.view(-1, group_size_in))  # or kaiming_normal_

matmul then internally reshapes this into bmm format, does the bmm, and finally outputs [batch_dims, ngroups, 1, group_size_out].

PS: group conv may actually be OK for a small group count, without this mess; not sure.
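The shape recipe above could be sketched roughly as follows. GroupedLinear and ModelParallelBmm are hypothetical names (there is no such class in torch.nn), and the init constants simply mirror what nn.Linear uses by default; treat this as one possible reading of the recipe rather than the exact code the reply had in mind:

```python
import math
import torch
import torch.nn as nn

class GroupedLinear(nn.Module):  # hypothetical helper, not part of torch.nn
    def __init__(self, ngroups, group_size_in, group_size_out):
        super().__init__()
        # Stored as [ngroups, out, in] so the 2-D view below has fan_in == group_size_in.
        self.weight = nn.Parameter(torch.empty(ngroups, group_size_out, group_size_in))
        self.bias = nn.Parameter(torch.empty(ngroups, group_size_out))
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.weight.view(-1, group_size_in), a=math.sqrt(5))
            bound = 1 / math.sqrt(group_size_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # x: [..., ngroups, group_size_in]
        w = self.weight.transpose(-1, -2)     # [ngroups, in, out]
        y = x.unsqueeze(-2) @ w               # [..., ngroups, 1, out]
        return y.squeeze(-2) + self.bias      # [..., ngroups, out]

class ModelParallelBmm(nn.Module):  # hypothetical name
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.out_features = out_features
        self.fc1 = GroupedLinear(out_features, in_features, hidden)
        self.fc2 = GroupedLinear(out_features, hidden, hidden)
        self.fc3 = GroupedLinear(out_features, hidden, 1)

    def forward(self, x):
        # [B, in] -> [B, G, in]; expand returns a view, so x is not copied here
        x = x[:, None, :].expand(-1, self.out_features, -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x).squeeze(-1)        # [B, G]

torch.manual_seed(0)
model = ModelParallelBmm(in_features=130, hidden=10, out_features=45)
out = model(torch.randn(32, 130))
print(out.shape)
```

Because each group now has its own weights and bias, the 45 output columns are no longer copies of one another.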

Got it working with the group conv, using the 1-group adjustment for the first layer! It’s blazing fast. Thank you so much, my iteration time just dropped to 4 secs / epoch from 4 min / epoch!!