Unable to parallelize the network

I am using existing code that defines the neural network as below.

import types

import timm
import torch
import torch.nn as nn


class VisualNet(nn.Module):
    def __init__(self, use_bn=False):
        super(VisualNet, self).__init__()

        # Instantiate backbone and reassemble blocks
        self.model = timm.create_model("vit_large_patch16_384", pretrained=True)
        self.model.patch_size = [16, 16]
        self.model.start_index = 1

        # Capture intermediate block outputs via forward hooks
        # (get_activation and the activations dict are defined elsewhere)
        self.model.blocks[5].register_forward_hook(get_activation("1"))
        self.model.blocks[11].register_forward_hook(get_activation("2"))
        self.model.blocks[17].register_forward_hook(get_activation("3"))
        self.model.blocks[23].register_forward_hook(get_activation("4"))

        self.activations = activations
        
        # We inject this function into the VisionTransformer instance so that
        # we can use it with interpolated position embeddings without modifying the library source.
        self.model.forward_flex = types.MethodType(forward_flex, self.model)
        self.model._resize_pos_embed = types.MethodType(_resize_pos_embed, self.model)
     
    def forward(self, x):
        b, c, h, w = x.shape
        # contiguous() returns a new tensor, so the result must be assigned back
        x = x.contiguous(memory_format=torch.channels_last)

        x = self.model.forward_flex(x)
        
        layer_1 = self.activations["1"]
        layer_2 = self.activations["2"]
        layer_3 = self.activations["3"]
        layer_4 = self.activations["4"]
        
        return x, (layer_1, layer_2, layer_3, layer_4)
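
For context, get_activation and the activations dict are not shown above; they are defined elsewhere in my code and follow the usual forward-hook pattern, roughly like this sketch (my actual helper may differ slightly):

# Sketch of the typical hook helper (the real definition lives elsewhere in my code)
activations = {}

def get_activation(name):
    # Store the hooked block's output in the shared dict under the given key
    def hook(module, input, output):
        activations[name] = output
    return hook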

The code works fine with a single GPU but fails with multiple GPUs with an error saying that the weights and the input are not on the same device. I am attaching the screenshot below.

Can someone tell me what the reason might be? Is it because I am using types.MethodType in the model initialization?

Based on the error message, it seems that the method assignment self.model.forward_flex = types.MethodType(forward_flex, self.model) is causing the issue.
Could you explain what this method does and whether any new tensors are created internally, etc.?

Thanks for the response. The forward_flex function is shown below.

def forward_flex(self, x):
    b, c, h, w = x.shape

    # Interpolate the position embedding to the current input resolution
    pos_embed = self._resize_pos_embed(self.pos_embed,
                                       h // self.patch_size[1],
                                       w // self.patch_size[0])

    B = x.shape[0]
    # Patchify the input and prepend the class token
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    # Run the transformer blocks (the registered forward hooks fire here)
    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x

I also observed that similar issues with DataParallel and types.MethodType have been reported here.
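
If I understand those reports correctly, the suspected mechanism is that DataParallel builds its per-GPU replicas with a shallow copy of each module, so a method bound with types.MethodType keeps pointing at the original module on the first GPU, and the parameters it touches (pos_embed, cls_token, ...) stay on cuda:0 while the input has been scattered to another device. Below is a minimal sketch of that behaviour, using a toy Widget class and copy.copy to stand in for the replication step:

# Toy example (not my actual model) showing how a bound method survives a shallow copy
import copy
import types

class Widget:
    def __init__(self, name):
        self.name = name

def report(self):
    return self.name

a = Widget("original")
a.report = types.MethodType(report, a)  # bound to this specific instance

b = copy.copy(a)   # shallow copy, similar to what DataParallel does when replicating
b.name = "replica"

# The bound method copied into b still has __self__ == a,
# so it reads a.name rather than b.name
print(b.report())  # -> "original"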