I want to do it but my situation is that I only have cpus (I have up to 112). I tried it but I also get the same bug. How do I get around it? @albanD I made a totally self contained example:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torch.multiprocessing import Pool
class SimpleDataSet(Dataset):
def __init__(self, Din, num_examples=23):
self.x_dataset = [torch.randn(Din) for _ in range(num_examples)]
# target function is x*x
self.y_dataset = [x**2 for x in self.x_dataset]
def __len__(self):
return len(self.x_dataset)
def __getitem__(self, idx):
return self.x_dataset[idx], self.y_dataset[idx]
def get_loss(args):
x, y, model = args
y_pred = model(x)
criterion = nn.MSELoss()
loss = criterion(y_pred, y)
return loss
def get_dataloader(D, num_workers, batch_size):
ds = SimpleDataSet(D)
dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
return dl
def train_fake_data():
num_workers = 2
Din, Dout = 3, 1
model = nn.Linear(Din, Dout).share_memory()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
batch_size = 2
num_epochs = 10
# num_batches = 5
num_procs = 5
dataloader = get_dataloader(Din, num_workers, batch_size)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(num_epochs):
for _, batch in enumerate(dataloader):
batch = [(torch.randn(Din), torch.randn(Dout), model) for _ in batch]
with Pool(num_procs) as pool:
optimizer.zero_grad()
losses = pool.map(get_loss, batch)
loss = torch.mean(losses)
loss.backward()
optimizer.step()
# scheduler
scheduler.step()
if __name__ == '__main__':
# start = time.time()
# train()
train_fake_data()
# print(f'execution time: {time.time() - start}')
Error:
Traceback (most recent call last):
File "/Users/brando/anaconda3/envs/coq_gym/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-ea57e03ba088>", line 1, in <module>
runfile('/Users/brando/ML4Coq/playground/multiprocessing_playground/multiprocessing_cpu_pytorch.py', wdir='/Users/brando/ML4Coq/playground/multiprocessing_playground')
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/brando/ML4Coq/playground/multiprocessing_playground/multiprocessing_cpu_pytorch.py", line 95, in <module>
train_fake_data()
File "/Users/brando/ML4Coq/playground/multiprocessing_playground/multiprocessing_cpu_pytorch.py", line 83, in train_fake_data
losses = pool.map(get_loss, batch)
File "/Users/brando/anaconda3/envs/coq_gym/lib/python3.7/multiprocessing/pool.py", line 290, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "/Users/brando/anaconda3/envs/coq_gym/lib/python3.7/multiprocessing/pool.py", line 683, in get
raise self._value
multiprocessing.pool.MaybeEncodingError: Error sending result: '[tensor(0.5237, grad_fn=<MseLossBackward>)]'. Reason: 'RuntimeError('Cowardly refusing to serialize non-leaf tensor which requires_grad, since autograd does not support crossing process boundaries. If you just want to transfer the data, call detach() on the tensor before serializing (e.g., putting it on the queue).')'
I am sure I want to do this. How should I be doing this?
related links from research:
- Multiprocessing for-loop on CPU
- How to use multiprocessing in PyTorch? - Stack Overflow
- How to parallelize a loop over the samples of a batch - #7 by Brando_Miranda
- machine learning - How to parallelize a training loop ever samples of a batch when CPU is only available in pytorch? - Stack Overflow
- RuntimeError: Cowardly refusing to serialize non-leaf tensor which requires_grad, since autograd does not support crossing process boundaries · Issue #36457 · pytorch/pytorch · GitHub