TypeError: can't pickle torch.dtype objects while using multiprocessing

I have a requirement that I need to compute the loss of a model with same inputs and targets but with different values of model parameters. As I will be iterating over many parameter values, I need to evaluate the model in parallel with CPU cores and hence I chose to use torch.multiprocessing. But the code throws TypeError: can’t pickle torch.dtype objects.

Here is the minimal code which I used:

from torch.multiprocessing import Pool
import numpy as np
import torch
import torch.nn as nn

# set up the network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(1,2)

    def forward(self, x):
        with torch.no_grad():
            return x


class SoftmaxLoss:
    def __init__(self, model):
        self.model = model

        # loss criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def __call__(self, parameter_name, parameter_value, parameter_type, inputs, targets):

        # build a model with this parameter
        self.model.set_param_value(parameter_name, parameter_type, parameter_value)

        # do a forward pass with the built model
        outputs = self.model.forward(inputs)

        # compute loss and return
        loss = self.criterion(outputs, targets)

        return loss.data.numpy()


class Workers:
    def __init__(self, function, param_name, param_type, inputs, targets, n_workers=4):
        self.n_workers = n_workers
        self.inputs = inputs
        self.targets = targets
        self.param_name = param_name
        self.param_type = param_type

        # function to parallelize
        self.fn = function

    def fn_wrapper(self, x):
        return self.fn(self.param_name, x, self.param_type, self.inputs, self.targets)

    def parallelize(self, array):
        p = Pool(self.n_workers)
        results = p.map(self.fn_wrapper, array)
        print(results)


class PyTorchModel:
    def __init__(self, net):
        # net is the current optimized network
        self.net = net

        # parameters
        self.parameters = list(torch.nn.Module.named_parameters(self.net))

    def forward(self, input):
        output = self.net(input)
        return output

    def set_param_value(self, name, type, value):
        obj = self.net
        parts = name.split(".")
        for attr in parts[:-1]:
            obj = getattr(obj, attr)
        setattr(obj, parts[-1], torch.from_numpy(value).type(type))

    def get_param_value(self, name):
        refs = name.split(".")
        last = self.net
        for ref in refs:
            last = getattr(last, ref)

        return np.copy(last.data.numpy()), last.data.dtype


if __name__ == '__main__':

    # net
    net = Net()

    # set up the model
    model = PyTorchModel(net=net)

    # the objective function
    objective = SoftmaxLoss(model=model)

    # sample inputs and targets
    inputs = np.array([1,2]).reshape((2,1))
    targets = np.array([1,0]).reshape((2,1))

    # set up the worker
    _, type = model.get_param_value("fc.weight")
    work = Workers(objective, "fc.weight", type, inputs, targets)

    # list of parameters to evaluate
    params = list(np.arange(1e+1))

    work.parallelize(params)

This is a known bug and we are working on it: https://github.com/pytorch/pytorch/issues/7481.
I’m not sure if there are any workaround at this time.