I need to compute the loss of a model for the same inputs and targets but with different values of the model parameters. Because I will be iterating over many parameter values, I want to evaluate the model in parallel across CPU cores, so I chose to use torch.multiprocessing. However, the code throws `TypeError: can't pickle torch.dtype objects`.
Here is the minimal code that I used to reproduce the error:
from torch.multiprocessing import Pool
import numpy as np
import torch
import torch.nn as nn
# set up the network architecture
class Net(nn.Module):
    """Minimal network consisting of a single ``nn.Linear(1, 2)`` layer.

    The layer is stored as ``self.fc``, so its parameters are reachable under
    the dotted names ``"fc.weight"`` and ``"fc.bias"``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        """Apply the linear layer to ``x`` without recording gradients.

        The previous implementation returned ``x`` unchanged, so ``self.fc``
        was never used and changing its parameters could not influence the
        output (or any loss computed from it).
        """
        # Evaluation only: no autograd graph is built for the result.
        with torch.no_grad():
            return self.fc(x)
class SoftmaxLoss:
    """Callable objective: cross-entropy loss of a model for one parameter value.

    ``model`` must provide ``set_param_value(name, type, value)`` and
    ``forward(inputs)``.
    """

    def __init__(self, model):
        self.model = model
        # loss criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def __call__(self, parameter_name, parameter_value, parameter_type, inputs, targets):
        """Set one parameter on the model, run a forward pass and return the loss.

        Args:
            parameter_name: Name of the parameter to overwrite; passed through
                to ``model.set_param_value``.
            parameter_value: New value for that parameter.
            parameter_type: Type/dtype forwarded to ``model.set_param_value``.
            inputs: Model inputs; NumPy arrays or tensors.
            targets: Targets for ``CrossEntropyLoss``; NumPy arrays or tensors.

        Returns:
            The loss as a zero-dimensional ``numpy.ndarray``.
        """
        # build a model with this parameter
        self.model.set_param_value(parameter_name, parameter_type, parameter_value)

        # The criterion only accepts tensors; as_tensor converts NumPy arrays
        # and leaves existing tensors untouched.
        input_tensor = torch.as_tensor(inputs)
        target_tensor = torch.as_tensor(targets)

        # Pure evaluation, so no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model.forward(input_tensor)
            loss = self.criterion(outputs, target_tensor)

        return loss.detach().cpu().numpy()
class Workers:
    """Evaluate ``function`` over many parameter values using a process pool.

    ``function`` is called as
    ``function(param_name, value, param_type, inputs, targets)`` for every
    value handed to :meth:`parallelize`.
    """

    def __init__(self, function, param_name, param_type, inputs, targets, n_workers=4):
        self.n_workers = n_workers
        self.inputs = inputs
        self.targets = targets
        self.param_name = param_name
        self.param_type = param_type
        # function to parallelize
        self.fn = function

    def fn_wrapper(self, x):
        """Call the wrapped function for the single parameter value ``x``."""
        return self.fn(self.param_name, x, self.param_type, self.inputs, self.targets)

    def parallelize(self, array):
        """Map :meth:`fn_wrapper` over ``array`` in a pool of ``n_workers`` processes.

        The results are printed (as before) and additionally returned, in the
        same order as ``array``.
        """
        # NOTE: ``pool.map`` pickles the bound method, and with it this whole
        # instance, so every attribute stored above must be picklable.
        # The ``with`` block shuts the workers down even if ``map`` raises;
        # previously the pool was never closed.
        with Pool(self.n_workers) as pool:
            results = pool.map(self.fn_wrapper, array)
        print(results)
        return results
class PyTorchModel:
    """Wrapper around a network that reads and overwrites attributes by dotted name."""

    def __init__(self, net):
        # net is the current optimized network
        self.net = net
        # Snapshot of (name, parameter) pairs; refreshed by set_param_value.
        self.parameters = list(self.net.named_parameters())

    def forward(self, input):
        """Run ``input`` through the wrapped network and return its output."""
        return self.net(input)

    def _resolve_owner(self, name):
        """Return ``(owner, leaf)`` for a dotted ``name`` such as ``"fc.weight"``."""
        *path, leaf = name.split(".")
        owner = self.net
        for attr in path:
            owner = getattr(owner, attr)
        return owner, leaf

    def set_param_value(self, name, type, value):
        """Overwrite the attribute ``name`` of the network with ``value``.

        Args:
            name: Dotted attribute path, e.g. ``"fc.weight"``.
            type: Target dtype (or tensor-type string) accepted by
                ``Tensor.type``. The name shadows the builtin but is kept so
                existing keyword callers continue to work.
            value: NumPy array or scalar holding the new value. A scalar
                assigned to a non-scalar parameter is broadcast to the
                parameter's shape.
        """
        owner, leaf = self._resolve_owner(name)
        # np.asarray lets NumPy scalars and Python numbers through;
        # torch.from_numpy itself only accepts ndarrays.
        tensor = torch.from_numpy(np.asarray(value)).type(type)

        current = getattr(owner, leaf, None)
        if isinstance(current, torch.nn.Parameter):
            if tensor.dim() == 0 and current.dim() > 0:
                tensor = tensor.expand(current.shape)
            # nn.Module refuses to assign a plain tensor to a registered
            # parameter, so wrap the new value. Only floating-point/complex
            # tensors may require gradients.
            trainable = current.requires_grad and (
                tensor.is_floating_point() or tensor.is_complex()
            )
            # clone() detaches the storage from the caller's NumPy array and
            # materialises any broadcast view.
            tensor = torch.nn.Parameter(tensor.clone(), requires_grad=trainable)

        setattr(owner, leaf, tensor)
        # The replaced parameter is a new object; keep the snapshot current.
        self.parameters = list(self.net.named_parameters())

    def get_param_value(self, name):
        """Return ``(copy of the value as a NumPy array, torch dtype)`` for ``name``."""
        owner, leaf = self._resolve_owner(name)
        tensor = getattr(owner, leaf).detach()
        return tensor.cpu().numpy().copy(), tensor.dtype
def main():
    """Evaluate the loss for ten candidate values of ``fc.weight`` in parallel."""
    # net
    net = Net()
    # set up the model
    model = PyTorchModel(net=net)
    # the objective function
    objective = SoftmaxLoss(model=model)

    # sample inputs and targets: inputs are float32 to match the layer's
    # weights, and CrossEntropyLoss expects class indices of shape (N,).
    inputs = torch.tensor([[1.0], [2.0]])
    targets = torch.tensor([1, 0])

    # set up the worker (avoid shadowing the builtin ``type``)
    weight, param_dtype = model.get_param_value("fc.weight")
    work = Workers(objective, "fc.weight", param_dtype, inputs, targets)

    # list of parameters to evaluate: one weight-shaped array per candidate,
    # each filled with a constant 0.0, 1.0, ..., 9.0
    params = [np.full_like(weight, candidate) for candidate in np.arange(1e+1)]
    work.parallelize(params)


# The guard is required so worker processes that re-import this module do not
# start pools of their own.
if __name__ == '__main__':
    main()