Hi everyone,
I have created a distributed model using two machines (rank = 0 and rank = 1, respectively). The model appears to train correctly, since it converges to the expected result. I save it with torch.save(model.state_dict(), 'RPC1.pth').
The problem comes when I load the model with model.load_state_dict(torch.load('RPC1.pth')) to make a prediction: the loaded model does not appear to contain what was saved, and the prediction is wrong.
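Roughly, the prediction side looks like this (a simplified, hypothetical reconstruction, not my exact script; the class names match the file below):

# Simplified sketch of my prediction code.
model = my_model(ps='ps', n_in=2, n_hidden=2, n_out=2, dim=1)
model.load_state_dict(torch.load('RPC1.pth'))
with torch.no_grad():
    pred = model(torch.Tensor([[0., 0.], [1., 1.]]))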
How can I save the trained model using torch.distributed RPC and RRef?
The python file is:
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_sync(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )

def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
# We create the dataset and an iterable.
class my_points(data.Dataset):
    def __init__(self, n_samples):
        self.n_tuples = int(n_samples / 4)
        self.n_samples = self.n_tuples * 4
        pd_data = np.tile(np.array([[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]), (self.n_tuples, 1))  # data
        self.data = pd_data[:, 0:2]    # 1st and 2nd columns -> x, y
        self.target = pd_data[:, 2:]   # 3rd column -> label

    def __len__(self):  # Length of the dataset.
        return self.n_samples

    def __getitem__(self, index):  # Return one point and one label.
        return torch.Tensor(self.data[index]), torch.Tensor(self.target[index])
class rref_in(nn.Module):
    def __init__(self, n_in=2, n_hidden=4, n_out=2):
        super(rref_in, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.n_hidden = n_hidden
        self.h = nn.Linear(self.n_in, self.n_hidden, bias=True)
        self.fc1 = nn.Linear(self.n_hidden, self.n_out, bias=True)

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(self.h(x))
        x = torch.sigmoid(self.fc1(x))
        return x
# We build a model with two inputs and one output (the class label, chosen
# from two softmax units).
class my_model(nn.Module):
    def __init__(self, ps, n_in=2, n_hidden=4, n_out=2, dim=1):
        super(my_model, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.n_hidden = n_hidden
        # The hidden network is set up remotely, on the parameter server.
        self.rref_in = rpc.remote(ps, rref_in, args=(n_in, n_hidden, n_out))
        self.out = nn.Softmax(dim=dim)

    def forward(self, x):
        x = _remote_method(rref_in.forward, self.rref_in, x)
        x = self.out(x)
        return x

    def parameter_rrefs(self):
        remote_params = []
        # Create RRefs for the remote parameters owned by the parameter server.
        remote_params.extend(_remote_method(_parameter_rrefs, self.rref_in))
        # Create RRefs for the local parameters (Softmax has none).
        remote_params.extend(_parameter_rrefs(self.out))
        return remote_params
def _run_trainer():
    n_classes = 2
    # We create the dataloader.
    # 100 iterations per epoch -> n_points = 2000 / batch_size = 20
    my_data = my_points(2000)
    batch_size = 20
    my_loader = data.DataLoader(my_data, batch_size=batch_size, num_workers=1)
    # Model.
    # Now we create the model, the loss function (criterium) and the optimizer.
    model = my_model(ps='ps', n_in=2, n_hidden=2, n_out=2, dim=1)
    # print(model)
    criterium = nn.CrossEntropyLoss()
    # Set up the distributed optimizer over all parameter RRefs.
    optimizer = DistributedOptimizer(torch.optim.SGD, model.parameter_rrefs(), lr=0.06, momentum=0.9)
    # Supervised training.
    epochs = 10
    max_iter = len(my_loader) * epochs
    cost = np.zeros((max_iter, 1))
    ucost = np.zeros((max_iter, 1))
    i, c_ant, beta = 0, 0, 0.99
    for ep in range(epochs):
        for k, (inputs, target) in enumerate(my_loader):
            with dist_autograd.context() as context_id:
                # Tensors coming from the DataLoader already have
                # requires_grad=False, so no Variable wrapping is needed.
                target = target.long()
                # Feed forward.
                pred = model(inputs)
                # Loss calculation.
                loss = criterium(pred, target.view(-1))
                # Run the distributed backward pass.
                dist_autograd.backward(context_id, [loss])
                # Run the distributed optimizer.
                optimizer.step(context_id)
                cost[i] = loss.item()
                # Bias-corrected exponential smoothing of the cost.
                c_act = (1 - beta) * cost[i] + beta * c_ant
                ucost[i] = c_act / (1 - beta ** (i + 1))
                c_ant = c_act
                i += 1
        print('Loss {:.4f} at epoch {:d}'.format(loss.item(), ep + 1))
    # Now we plot the results.
    # Plot the loss C.
    plt.plot(range(max_iter), cost, color='steelblue', marker='o')
    plt.plot(range(max_iter), ucost, 'c-', linewidth=3)
    plt.xlabel("Iterations")
    plt.ylabel("Cost (loss)")
    plt.show(block=True)

    colors = ['r', 'b', 'g', 'y']
    points = inputs.numpy()
    # Ground truth of the last batch.
    target = target.numpy()
    for k in range(n_classes):
        select = target[:, 0] == k
        p = points[select, :]
        plt.scatter(p[:, 0], p[:, 1], facecolors=colors[k])
    # Predictions for the last batch.
    pred = pred.detach()           # Softmax output is already a probability.
    _, index = torch.max(pred, 1)  # Index of the class with maximum probability.
    pred = pred.numpy()
    index = index.numpy()
    for k in range(n_classes):
        select = index == k
        p = points[select, :]
        plt.scatter(p[:, 0], p[:, 1], s=60, marker='s', edgecolors=colors[k], facecolors='none')
    plt.show()

    torch.save(model.state_dict(), 'RPC1.pth')
if __name__ == '__main__':
    rank = 0  # rank = 0 (trainer), rank = 1 (parameter server)
    world_size = 2
    os.environ["MASTER_ADDR"] = '192.168.0.14'  # IP of the rank 0 machine
    os.environ["MASTER_PORT"] = str(26500)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    # Network interface for 192.168.0.14 (check with ifconfig on Ubuntu, ipconfig on Windows).
    os.environ['TP_SOCKET_IFNAME'] = 'wlp2s0'
    os.environ["GLOO_SOCKET_IFNAME"] = "wlp2s0"
    if rank == 0:
        print('rank: ', rank, world_size)
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        print('rank: ', rank, world_size)
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # The parameter server does nothing; it just waits for RPCs.
    rpc.shutdown()
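My current guess is that the problem is where the parameters live: self.rref_in is created with rpc.remote(), so its weights sit on the "ps" worker and are never registered on the trainer; model.state_dict() there only sees the parameter-free Softmax, so 'RPC1.pth' contains no trained weights at all. Would fetching the remote state over RPC before saving be the right approach? A minimal, untested sketch of what I have in mind, reusing the _remote_method helper from the file above:

# Hypothetical replacement for the torch.save(...) call at the end of
# _run_trainer: fetch the state_dict of the remote module from its owner.
remote_state = _remote_method(nn.Module.state_dict, model.rref_in)
torch.save(remote_state, 'RPC1.pth')

# Prediction could then use a plain local copy of the network, without RPC:
net = rref_in(n_in=2, n_hidden=2, n_out=2)
net.load_state_dict(torch.load('RPC1.pth'))

Is something like this the recommended way, or is there a built-in mechanism for checkpointing RRef-hosted modules?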