Tensor Parallel Single Layer

I would like to train a model with a very large number of classes, which makes the final linear layer too large to fit on a single GPU. Using tensor parallelism, how can I parallelize just that linear layer while keeping the rest of the network replicated on each GPU, as in distributed data parallel?

The model structure shown below gives an idea of what I want to achieve.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist


class DummyModel(nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()

        # some conv2d layer and relu (input assumed to be 3x32x32 images)
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

        # Large linear layer: 16 * 32 * 32 flattened features -> 1,000,000 classes
        self.linear = nn.Linear(16 * 32 * 32, 1000000)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        
        # Want to parallelize
        x = self.linear(x)
        return x

# Example train function
def train(model, device, rank, world_size):
    # Initialize distributed training
    dist.init_process_group(backend='nccl')

    # Set device for the model
    model.to(device)

    # Wrap the model with DistributedDataParallel
    model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Dummy input and target
    input_data = torch.randn(64, 3, 32, 32).to(device)
    target = torch.randint(0, 1000000, (64,)).to(device)

    for epoch in range(10):
        optimizer.zero_grad()

        # How do I tensor-parallelize the linear layer while making sure the output is gathered and synced on all GPUs?
        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss={loss.item()}")

    # Clean up distributed training
    dist.destroy_process_group()

Welcome to the forums! I saw your question a few days ago but was too preoccupied to answer at the time, so I made a mental note to check back and see whether anyone had replied.

Let’s start by saying there is no clean, one-line solution to your issue. However, it is doable with a bit of understanding of both PyTorch and how the underlying linear algebra works.

First off, a vanilla Linear layer cannot be split across devices directly, since each of its parameter tensors must live on a single device.

So we will need to use something a little more basic than the Linear layer and effectively reconstruct it, split as many ways as desired.

Please see the code below with comments, including a demonstration that the two mathematical operations are identical, up to floating-point rounding error.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Demonstrate an equivalent linear algebra operation with split weights
big_n = 1000000

# Create a control layer
big_layer = nn.Linear(big_n, 10)

input_data = torch.rand((12, big_n))

expected_output = big_layer(input_data)
print("Control layer output size", expected_output.size())

# Test layers: assign the weights to separate learnable parameters; these can be
# placed on different devices, and you can also keep a bias vector.
# This part would go in your __init__
layer_a_weight = nn.Parameter(big_layer.weight.data[:, :big_n // 2])
layer_b_weight = nn.Parameter(big_layer.weight.data[:, big_n // 2:])
layer_bias = nn.Parameter(big_layer.bias.data)
print("Split weight sizes", layer_a_weight.size(), layer_b_weight.size())

# Split the data along the feature dimension.
# The rest would go in your forward pass; you'd need to move each data chunk to
# the device holding the corresponding learnable parameters.
data_a = input_data[:, :big_n // 2]
data_b = input_data[:, big_n // 2:]
print("Split data sizes", data_a.size(), data_b.size())

# Calculate the separate matmuls
data_a = data_a @ layer_a_weight.T
data_b = data_b @ layer_b_weight.T
print("Intermediate split sizes", data_a.size(), data_b.size())

# Add both elementwise; you can also add the bias at this point, if you are using it
data_test = data_a + data_b + layer_bias
print("Final output size", data_test.size())

# Check if equivalent to expected_output
with torch.no_grad():
    print("Mean squared distance between both", F.mse_loss(data_test, expected_output))

Control layer output size torch.Size([12, 10])
Split weight sizes torch.Size([10, 500000]) torch.Size([10, 500000])
Split data sizes torch.Size([12, 500000]) torch.Size([12, 500000])
Intermediate split sizes torch.Size([12, 10]) torch.Size([12, 10])
Final output size torch.Size([12, 10])
Mean squared distance between both tensor(3.5189e-13)

As can be seen, the mean squared distance between the two outputs is tiny and due only to floating-point rounding error.
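To map that demonstration onto multiple GPUs, each rank would hold only its own shard of the weight, compute its partial matmul locally, and then sum the partial results across ranks; in torch.distributed terms, the elementwise addition above becomes an all_reduce. Below is a minimal forward-pass sketch of that idea. The module name RowShardedLinear, the even split, and the assumption that every rank sees the same full input are mine, not part of the code above, and a real training setup would also need the gradient flow through the collective handled (tensor-parallel libraries do this with custom autograd functions).

import torch
import torch.nn as nn
import torch.distributed as dist


class RowShardedLinear(nn.Module):
    # Hypothetical sketch: the full (out_features, in_features) weight is split
    # along the input-feature dimension, one shard per rank, and the partial
    # matmuls are summed with all_reduce so every rank ends up with the full output.
    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()
        assert in_features % world_size == 0  # assume an even split for simplicity
        self.shard = in_features // world_size
        self.rank = rank
        # only this rank's slice of the weight lives on this device
        self.weight = nn.Parameter(torch.randn(out_features, self.shard) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # x is assumed to be the same full-width input on every rank
        x_local = x[:, self.rank * self.shard:(self.rank + 1) * self.shard]
        partial = x_local @ self.weight.T  # (batch, out_features) partial result
        # Sum the partial results from all ranks. Note dist.all_reduce is not
        # tracked by autograd, so the backward through this collective needs to
        # be handled separately (e.g. with a custom autograd function).
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        return partial + self.bias  # add the bias once, after the reduce

In your case, where it is the number of classes (the output dimension) that is huge, the same recipe applies with the weight split along the output dimension instead: each rank then produces a slice of the logits, and the slices are concatenated across ranks with all_gather.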

Hope this helps you accomplish what you’re attempting.

Thank you! I thought the tensor parallel module would do exactly this under the hood. My understanding of that module seems to be incorrect. Can you explain what that module does exactly?

You can read it here:

https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

Basically, DataParallel makes a copy of the entire model on each device, splits each batch across the devices, runs the forward passes in parallel, gathers the outputs on the default device, and accumulates the gradients from all replicas during the backward pass.
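For reference, here is a minimal usage sketch; the toy model and tensor shapes are placeholders of my own, not anything specific to your setup.

import torch
import torch.nn as nn

# toy model, just to illustrate the API (the shapes are arbitrary)
model = nn.Linear(128, 10)

if torch.cuda.device_count() > 1:
    # replicates the model on every visible GPU and splits each batch across them
    model = nn.DataParallel(model)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

x = torch.randn(64, 128, device=device)
out = model(x)    # outputs are gathered back onto the default device
print(out.shape)  # torch.Size([64, 10])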

Yes, but I assume TensorParallel is different from DataParallel?

I see. That must be new. I’m not clear on how that works yet, but I will definitely check it out when I get a chance.

@ptrblck can you shed some light on this or tag someone who can?