Hi, when I use torch.distributed.rpc to implement pipeline parallelism for transformer-based inference, the memory consumption increases with each forward pass. The code for 2 nodes looks like this.
First, I define two classes for the transformer shards:
import os
import sys
import threading
import time
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
from transformers import ViTFeatureExtractor, ViTForImageClassification
######################################
#      Define Transformer Shard      #
######################################
class TransformerShard1(nn.Module):
    def __init__(self, device, config, num_layers):
        super().__init__()
        # self.config = config
        self.model = ViTForImageClassification.from_pretrained(config)
        # keep only the first num_layers encoder layers for this shard
        self.model.vit.encoder.layer = nn.Sequential(*[self.model.vit.encoder.layer[i] for i in range(num_layers)])
        self.model.vit = nn.Sequential(*list(self.model.vit.children())[:-1])
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self._lock = threading.Lock()
        self.device = device
    def forward_kernel(self, x):
        x = self.model(x).to_tuple()[0]
        return x
    @torch.no_grad()
    def forward(self, pixel_values=None,
                head_mask=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        x = pixel_values.to_here().to(self.device)
        with self._lock:
            x = self.forward_kernel(x)
        return x.cpu()
class TransformerShard2(nn.Module):
    def __init__(self, device, config, num_layers):
        super().__init__()
        self.model = ViTForImageClassification.from_pretrained(config)
        # keep the second half of the encoder layers for this shard
        self.model.vit.encoder.layer = nn.Sequential(*[self.model.vit.encoder.layer[i] for i in range(num_layers, 2 * num_layers)])
        # drop the patch embeddings, which are applied in shard 1
        self.model.vit = nn.Sequential(*list(self.model.vit.children())[1:-1])
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self._lock = threading.Lock()
        self.device = device

    def forward_kernel(self, x):
        x = self.model(x)[0]
        return x

    @torch.no_grad()
    def forward(self, x_rref=None,
                head_mask=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            x = self.forward_kernel(x)
        return x.cpu()
Then I stitch them together into one class for the forward pass:
class DistViT(nn.Module):
    def __init__(
        self,
        split_size,
        workers,
        config,
        num_layers,
        *args, **kwargs
    ):
        super().__init__()
        self.split_size = split_size  # micro-batch size
        self.num_layers = num_layers
        self.p1_rref = rpc.remote(
            workers[0],
            TransformerShard1,
            args=("cpu", config, num_layers // 2) + args,
            kwargs=kwargs
        )
        self.p2_rref = rpc.remote(
            workers[1],
            TransformerShard2,
            args=("cpu", config, num_layers // 2) + args,
            kwargs=kwargs
        )
    def forward(self, pixel_values=None,
                head_mask=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        out_futures = []
        for x in iter(pixel_values.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = self.p1_rref.remote().forward(x_rref)
            z_fut = self.p2_rref.rpc_async().forward(y_rref)
            out_futures.append(z_fut)
        torch.futures.wait_all(out_futures)
        return out_futures
    def parameter_rrefs(self):
        # note: this requires the shard classes to define parameter_rrefs();
        # it is not called anywhere in this inference script
        remote_params = []
        remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
        remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
        return remote_params
Then I run the RPC processes like this:
######################################
# Run RPC Processes #
######################################
config = 'google/vit-base-patch16-224'
num_layers = 12
num_batches = 1
batch_size = 256
img = torch.randn(3, 384, 384)
imgs = [img for i in range(batch_size)]
feature_extractor = ViTFeatureExtractor.from_pretrained(config)
def run_master(split_size):
    # put the two model shards on worker0 and worker1 respectively
    print("Run master \n")
    for si in range(len(split_size)):
        model = DistViT(split_size[si], ["worker0", "worker1"], config, num_layers)
        inputs = feature_extractor(images=imgs, return_tensors="pt")
        for i in range(num_batches):
            outputs = model(**inputs)
def run_worker(rank, world_size, num_split):
    # run on localhost
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    # A higher timeout is used to accommodate kernel compilation time on ROCm.
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256, rpc_timeout=3000)
    if rank == 0:
        rpc.init_rpc(
            "worker0",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            "worker1",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # worker1 just waits for RPCs from the master
        pass

    # block until all RPCs finish
    rpc.shutdown()
The main function:
if __name__ == "__main__":
    world_size = 2
    rank = int(sys.argv[1])
    num_split = [8]
    print(f"{config}, {num_layers}, {num_split}")
    tik = time.time()
    run_worker(rank, world_size, num_split)
    tok = time.time()
    print(f"Total program execution time = {tok - tik}")
It requires the transformers package, which can be installed with:
pip install transformers
The whole script is uploaded to Ubuntu Pastebin.
I run it on macOS with PyTorch 1.8.1 (CPU-only); the run commands are:
on the first terminal:
python pipeline_parallelism.py 0
on the second terminal:
python pipeline_parallelism.py 1
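For reference, both ranks could also be launched from a single script with torch.multiprocessing.spawn instead of two terminals; this is only a sketch and not how the script above is invoked (it reads the rank from sys.argv):

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    num_split = [8]
    # spawn one process per rank; each process calls run_worker(rank, world_size, num_split)
    mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)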
I use the top command to check memory usage: after each forward pass the resident memory grows by about 3 MB, until the process hits OOM or the RPC is shut down. Could anyone help me find or fix the problem? Thank you very much!
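P.S. To quantify the growth more precisely than with top, here is a minimal sketch that logs the resident set size after each forward pass, e.g. inside run_master (it assumes psutil, which is not used in the script above):

import os
import psutil

process = psutil.Process(os.getpid())
for i in range(num_batches):
    outputs = model(**inputs)
    # resident set size of this process in MB after the i-th forward pass
    rss_mb = process.memory_info().rss / (1024 * 1024)
    print(f"iteration {i}: RSS = {rss_mb:.1f} MB")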