KAN - Model not replicated to all GPUs with nn.DataParallel()

Hi, I am training Kolmogorov Arnold Networks (GitHub - KindXiaoming/pykan: Kolmogorov Arnold Networks) and want to use the data parallel method to speed up training.
However, I noticed that the data tensors are getting assigned to different GPUs, but the model is not replicated to all GPUs. I am not sure how to proceed next!

code snippet -

# Build the agent, move it to the default CUDA device, and wrap it in
# nn.DataParallel over GPUs 0 .. num_gpus - 1.
student_model = nn.DataParallel(
    KQN_Agent(in_channels, num_actions, 0).cuda(),
    device_ids=list(range(num_gpus)),
)

# training definition
    ...
    student_logits = student_model(obs_tensor)
    ....

error -

Epoch 1:   0%|                                                                                                                                                         | 0/7812 [00:02<?, ?it/s]
 line 305, in train
    student_logits = student_model(obs_tensor)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 193, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 212, in parallel_apply
    return parallel_apply(
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 126, in parallel_apply
    output.reraise()
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/KAN-1/expt4/src/Agent.py", line 57, in forward
    actions = self.network(x)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/KAN-1/expt4/src/Network.py", line 62, in forward
    actions = self.network(x)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rjaditya/MSP/lib/python3.10/site-packages/kan/MultKAN.py", line 816, in forward
    x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Could you post the model definition here, please?

Model -

class KQN_Agent(nn.Module):
    """Agent module that wraps a ``KQN`` network.

    The constructor arguments are stored as plain attributes
    (``in_channels``, ``num_actions``, ``eps``); the forward pass is
    delegated to the wrapped network.
    """

    def __init__(self, in_channels, num_actions, epsilon):
        super().__init__()

        self.in_channels, self.num_actions = in_channels, num_actions
        self.eps = epsilon

        # The wrapped network is registered as the ``network`` submodule.
        self.network = KQN(in_channels, num_actions)

    def forward(self, x):
        """Return the output of the wrapped network for ``x``."""
        return self.network(x)

class KQN(nn.Module):
    """Convolutional network with a KAN head.

    Layer order and nesting are kept exactly as-is because the architecture
    has to match the saved Student model.
    """

    def __init__(self, in_channels, num_actions):
        super().__init__()

        # Convolutional stack followed by a flatten.
        feature_layers = [
            torch.nn.Conv2d(in_channels, 10, kernel_size=5, stride=3, padding=0),
            nn.ReLU(),
            torch.nn.Conv2d(10, 15, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(15, 20, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # The KAN head stays wrapped in its own Sequential; the earlier
        # nn.Linear(20*4*4, 20) head is no longer used.
        kan_head = nn.Sequential(
            KAN(width=[20 * 4 * 4, 8, 8, num_actions], grid=15, k=3),
        )

        self.network = nn.Sequential(*feature_layers, kan_head)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.network(x)

Thank you!
I cannot reproduce any issue using:

import torch 
import torch.nn as nn


class KQN_Agent(nn.Module):
    """Thin agent wrapper that owns a ``KQN`` network.

    Keeps ``in_channels``, ``num_actions`` and ``eps`` as plain attributes
    and forwards inputs straight to the wrapped network.
    """

    def __init__(self, in_channels, num_actions, epsilon):
        super().__init__()

        self.in_channels, self.num_actions = in_channels, num_actions
        self.eps = epsilon

        # Registered as the ``network`` submodule of this agent.
        self.network = KQN(in_channels, num_actions)

    def forward(self, x):
        """Return whatever the wrapped network produces for ``x``."""
        return self.network(x)

class KQN(nn.Module):
    """Convolutional part of the Student architecture, without the KAN head.

    ``num_actions`` is accepted but not used here, since the head that
    consumed it is commented out.
    """

    def __init__(self, in_channels, num_actions):
        super().__init__()

        self.network = nn.Sequential(
            nn.Conv2d(in_channels, 10, kernel_size=5, stride=3, padding=0),
            nn.ReLU(),
            nn.Conv2d(10, 15, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(15, 20, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # KAN head disabled in this version:
            #   nn.Sequential(
            #       # nn.Linear(20*4*4, 20),
            #       KAN(width=[20*4*4, 8, 8, num_actions], grid = 15, k = 3)
            #   )
        )

    def forward(self, x):
        """Return the flattened convolutional features for ``x``."""
        return self.network(x)


# Minimal reproduction: wrap the agent in DataParallel over 8 GPUs and run a
# single forward pass on a random batch.
in_channels, num_actions = 1, 1
student_model = nn.DataParallel(
    KQN_Agent(in_channels, num_actions, 0).cuda(),
    device_ids=list(range(8)),
)
x = torch.randn(8, 1, 224, 224).cuda()

out = student_model(x)
print(out.sum())
# tensor(2466.0723, device='cuda:0', grad_fn=<SumBackward0>)

However, I also had to remove the KAN module, since it is undefined in the posted code.

But that is needed!

to use KAN -

!pip install pykan
from kan import KAN 

pykan documentation - pykan · PyPI