I am attempting to follow the TensorParallel tutorial, but my model, in addition to the typical nn.Linear and nn.Embedding layers, also contains nn.Conv2d layers. Reading through the TensorParallel documentation, I didn't see any mention of support for that layer type, so I am curious: will TensorParallel work if I don't wrap all modules in my model via parallelize_module?
For example, my model is something like:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # First 2D convolutional layer, taking in 1 input channel (image),
        # outputting 32 convolutional features, with a square kernel size of 3
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        # Second 2D convolutional layer, taking in the 32 input channels,
        # outputting 64 convolutional features, with a square kernel size of 3
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Designed to ensure that adjacent pixels are either all 0s or all active
        # with an input probability
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # First fully connected layer
        self.fc1 = nn.Linear(9216, 128)
        # Second fully connected layer that outputs our 10 labels
        self.fc2 = nn.Linear(128, 10)

    # x represents our data
    def forward(self, x):
        # Pass data through conv1
        x = self.conv1(x)
        # Use the rectified-linear activation function over x
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        # Run max pooling over x
        x = F.max_pool2d(x, 2)
        # Pass data through dropout1
        x = self.dropout1(x)
        # Flatten x with start_dim=1
        x = torch.flatten(x, 1)
        # Pass data through ``fc1``
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # Apply softmax to x
        output = F.log_softmax(x, dim=1)
        return output
Is it possible to only call parallelize_module on fc1/fc2?
You should be able to do that.
Internally, the parameters of a parallelized module are represented by DTensor instead of torch.Tensor. Each ParallelStyle (e.g. ColwiseParallel) takes a use_local_output argument. If it is set to True (the default for most ParallelStyles), the output of the parallelized module is a plain torch.Tensor and works seamlessly with any non-parallelized module.
On a side note, for intermediate results to flow into non-parallelized modules, you need to make sure the sharding of the activations coming out of a parallelized module is acceptable to the non-parallelized module that follows it.
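To make that concrete, here is a minimal sketch (not from the original thread) that parallelizes only fc1/fc2 of the Net module above and leaves the conv/dropout layers untouched, assuming a 1-D mesh of 2 GPUs launched with torchrun. With the default layouts, fc1's output would be sharded on the last dimension, which happens to be fine for elementwise ops like relu/dropout and matches RowwiseParallel's expected input; the sketch instead shows the more conservative option of replicating at the TP boundary so that any non-parallelized code in between always sees a full, ordinary torch.Tensor. (On older PyTorch releases Replicate is exposed under torch.distributed._tensor rather than torch.distributed.tensor.)

# Minimal sketch: TP applied only to fc1/fc2, conv layers left as-is.
# Assumes the Net module defined above and a 2-GPU launch via torchrun.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate  # torch.distributed._tensor on older releases
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

device_mesh = init_device_mesh("cuda", (2,))
model = Net().to("cuda")

model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        # use_local_output=True (the default) hands back a plain torch.Tensor.
        # Replicating at the boundary keeps that tensor full-sized, so the
        # non-parallelized relu/dropout between fc1 and fc2 needs no changes.
        "fc1": ColwiseParallel(output_layouts=Replicate()),
        "fc2": RowwiseParallel(input_layouts=Replicate()),
    },
)

# conv1/conv2/dropout keep ordinary torch.Tensor parameters; only fc1 and fc2
# now hold DTensor parameters.
x = torch.randn(8, 1, 28, 28, device="cuda")  # MNIST-sized dummy batch
out = model(x)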
@tianyu Thanks for the help! Any idea whether that also works with FSDP and DDP? I tried DDP as shown below, but ran into an error:
DDP Code (using the model defined above):
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)
from torch.nn.parallel import DistributedDataParallel as DDP

# create a device mesh based on the given world_size.
_world_size = 2  # TODO
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
print(f"Starting PyTorch TP example on rank {_rank}.")
assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus"

# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
tp_model = Net().to("cuda")

# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)

# Custom parallelization plan for the model
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "fc1": ColwiseParallel(),
        "fc2": RowwiseParallel(),
    },
)
tp_model = DDP(tp_model, device_mesh=device_mesh)
Which throws the following error:
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/real-csd/test_tp.py", line 94, in <module>
[rank0]: tp_model = DDP(tp_model, device_mesh=device_mesh)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 784, in __init__
[rank0]: self._log_and_throw(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1127, in _log_and_throw
[rank0]: raise err_type(err_msg)
[rank0]: RuntimeError: Modules with uninitialized parameters can't be used with `DistributedDataParallel`. Run a dummy forward pass to correctly initialize the modules
[rank1]: Traceback (most recent call last):
[rank1]: File "/scratch/real-csd/test_tp.py", line 94, in <module>
[rank1]: tp_model = DDP(tp_model, device_mesh=device_mesh)
[rank1]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 784, in __init__
[rank1]: self._log_and_throw(
[rank1]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1127, in _log_and_throw
[rank1]: raise err_type(err_msg)
[rank1]: RuntimeError: Modules with uninitialized parameters can't be used with `DistributedDataParallel`. Run a dummy forward pass to correctly initialize the modules