Hi, I am training Kolmogorov-Arnold Networks (GitHub - KindXiaoming/pykan: Kolmogorov Arnold Networks) and want to use data parallelism (nn.DataParallel) to speed up training.
However, I noticed that the input data tensors are being split across the different GPUs, but the model does not seem to be replicated onto all of them. I am not sure how to proceed next!
code snippet -
import torch.nn as nn

# build the KAN-based agent and move it to the default GPU (cuda:0)
student_model = KQN_Agent(in_channels, num_actions, 0)
student_model = student_model.cuda()
# wrap it so the input batch is scattered across all available GPUs
student_model = nn.DataParallel(student_model, device_ids=list(range(num_gpus)))
# training definition
...
# forward pass inside the training loop
student_logits = student_model(obs_tensor)
...
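For reference, this is roughly how I am checking where the model's tensors live before the forward pass (a quick diagnostic sketch using my variables from above; I have not confirmed how pykan registers its internal tensors):

# list everything DataParallel should broadcast to the per-GPU replicas
for name, p in student_model.module.named_parameters():
    print("param ", name, p.device)
for name, b in student_model.module.named_buffers():
    print("buffer", name, b.device)

I would expect all of these to show cuda:0 before the forward pass, since DataParallel only creates the per-GPU replicas inside forward(). My guess is that the subnode_scale / subnode_bias tensors named in the traceback below are not among the parameters/buffers that get replicated, but I am not sure.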
error -
Epoch 1: 0%| | 0/7812 [00:02<?, ?it/s]
line 305, in train
student_logits = student_model(obs_tensor)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 193, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 212, in parallel_apply
return parallel_apply(
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 126, in parallel_apply
output.reraise()
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
output = module(*input, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/KAN-1/expt4/src/Agent.py", line 57, in forward
actions = self.network(x)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/KAN-1/expt4/src/Network.py", line 62, in forward
actions = self.network(x)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rjaditya/MSP/lib/python3.10/site-packages/kan/MultKAN.py", line 816, in forward
x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
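To check my understanding of how DataParallel replication works, I put together a tiny toy module with a plain (unregistered) tensor attribute, which I would expect to hit the same device-mismatch error. This is only my guess at the failure mode, not a confirmed diagnosis of the pykan internals:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)                 # registered parameter: gets replicated
        self.scale = torch.ones(8, device="cuda:0")   # plain attribute: stays on cuda:0

    def forward(self, x):
        # on the replica placed on cuda:1, x is on cuda:1 while self.scale is still on cuda:0
        return self.linear(x) * self.scale

toy = nn.DataParallel(Toy().cuda(), device_ids=[0, 1])
out = toy(torch.randn(16, 8).cuda())                  # should raise the same "Expected all tensors..." error

If that is indeed what is happening inside MultKAN, what would be the recommended way to handle it?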