I would like to train a model with a very large number of classes, which makes the final linear layer too large to fit on a single GPU. Using tensor parallelism, how can I parallelize just that linear layer while keeping the rest of the network replicated on each GPU, as in distributed data parallel?
The model structure shown below gives an idea of what I want to achieve.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist

class DummyModel(nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()
        # Some conv2d layer and ReLU
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # Large linear layer: 16 * 32 * 32 flattened features -> 1,000,000 classes
        self.linear = nn.Linear(16 * 32 * 32, 1000000)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        # Want to parallelize this layer across GPUs
        x = self.linear(x)
        return x
# Example train function
def train(model, device, rank, world_size):
    # Initialize distributed training
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    # Move the model to this rank's device
    model.to(device)
    # Wrap the model with DistributedDataParallel
    model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # Dummy input and target
    input_data = torch.randn(64, 3, 32, 32).to(device)
    target = torch.randint(0, 1000000, (64,)).to(device)
    for epoch in range(10):
        optimizer.zero_grad()
        # How do I tensor-parallelize the linear layer while making sure the
        # output is gathered and synced on all GPUs? (rough sketch of what
        # I'm imagining is below)
        output = model(input_data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss={loss.item()}")
    # Clean up distributed training
    dist.destroy_process_group()
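
To make the question more concrete, here is a rough sketch of the kind of column-parallel layer I imagine replacing self.linear with: each rank would hold only a slice of the output classes and gather the per-rank logits so every GPU sees the full output. The ColumnParallelLinear name and the gradient handling around all_gather are my own guesses, not an existing PyTorch API:

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        assert out_features % world_size == 0
        self.world_size = world_size
        self.rank = rank
        # Each rank stores only its shard of the weight matrix
        self.local_linear = nn.Linear(in_features, out_features // world_size)

    def forward(self, x):
        # Local shard of the logits: (batch, out_features // world_size)
        local_out = self.local_linear(x)
        # Gather the shards so every rank has the full logits for the loss
        gathered = [torch.empty_like(local_out) for _ in range(self.world_size)]
        dist.all_gather(gathered, local_out)
        # dist.all_gather does not track gradients, so put the autograd-tracked
        # local shard back into its slot; gradients then only flow to this rank's weights
        gathered[self.rank] = local_out
        return torch.cat(gathered, dim=-1)

Is hand-rolling something like this (and somehow stopping DDP from all-reducing the sharded weights) the right approach, or is there a built-in way, e.g. via torch.distributed.tensor.parallel, to shard just this one layer while keeping plain DDP for the conv part?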