For a model that takes a custom type as input, can I customize the scatter function of DataParallel?

MyModel has too many inputs, so I've created a class MyInput to hold them. The class has several lists of tensors as attributes, and each tensor in these lists is one (batched) input.

import torch as pt

class MyInput: # batched input
    device: pt.device # device of all the Tensors below
    x: list[pt.FloatTensor]
    mask: list[pt.BoolTensor]
    idx: list[pt.IntTensor]
    # each item in the list is one of the inputs of MyModel, 
    # with batch at first dim
    
    def __init__(
        self,
        x: list[pt.FloatTensor],
        mask: list[pt.BoolTensor],
        idx: list[pt.IntTensor],
    ) -> None:
        self.x, self.mask, self.idx = x, mask, idx
    # other things of MyInput
    ...

class MyModel(pt.nn.Module): 
    # Implementation of MyModel
    ...

    def forward(self, batch: MyInput) -> pt.FloatTensor: 
        # forward function with batch as input, e.g.
        a = self.a_certain_layer(batch.x[3])
        b = a[batch.mask[2]]
        c = b[:, batch.idx[3]]
        return c
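For comparison, here is a minimal sketch of an alternative that may avoid a custom scatter altogether, assuming MyInput can be turned into a `typing.NamedTuple`. As far as I know, DataParallel's default scatter (`torch.nn.parallel.scatter_gather.scatter`) recurses into namedtuples and lists, so each tensor in each list should be split along dim 0. The toy MyModel (a single `nn.Linear` as `a_certain_layer`) is hypothetical, and the `device` field is dropped because non-tensor fields are copied unchanged to every replica.

```python
from typing import List, NamedTuple

import torch
from torch import nn


class MyInput(NamedTuple):
    """Hypothetical NamedTuple variant of MyInput; every tensor has batch at dim 0."""
    x: List[torch.Tensor]
    mask: List[torch.Tensor]
    idx: List[torch.Tensor]


class MyModel(nn.Module):
    """Toy stand-in for MyModel; a_certain_layer is assumed to be a Linear layer."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.a_certain_layer = nn.Linear(in_features, out_features)

    def forward(self, batch: MyInput) -> torch.Tensor:
        a = self.a_certain_layer(batch.x[0])  # (B, out_features)
        return a[batch.mask[0]]               # keep only rows where the mask is True
```

On a CPU-only machine `nn.DataParallel(MyModel(3, 2))` simply calls the wrapped module, so the sketch runs anywhere; the actual per-GPU split only happens when several GPUs are visible.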

As a plain class, MyInput is not split by DataParallel's default scatter (the same object is handed to every replica unchanged), so I plan to write a custom scatter function for MyInput, like this:

def scatter_MyInput(all_batch: MyInput) -> list[MyInput]: 
    x_s = default_scatter(all_batch.x)
    # default_scatter is the default scatter function of DataParallel
    m_s = default_scatter(all_batch.mask)
    i_s = default_scatter(all_batch.idx)
    return [MyInput(*props) for props in zip(x_s, m_s, i_s)]
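The pseudocode above calls `default_scatter` without saying which devices to split onto. A minimal, GPU-free sketch of the same idea swaps in `torch.tensor_split` plus `Tensor.to` instead of DataParallel's own scatter; `split_list` and `scatter_my_input` are hypothetical names, and every tensor is assumed to have batch at dim 0.

```python
from typing import List, Sequence

import torch


class MyInput:
    """Minimal stand-in for the question's MyInput (device attribute omitted)."""

    def __init__(self, x, mask, idx):
        self.x, self.mask, self.idx = x, mask, idx


def split_list(tensors: Sequence[torch.Tensor],
               devices: Sequence[torch.device]) -> List[list]:
    """Split every tensor along dim 0 into len(devices) pieces, regrouped per device."""
    if len(tensors) == 0:
        return [[] for _ in devices]
    # tensor_split always returns exactly len(devices) pieces (some may be empty)
    per_tensor = [
        [piece.to(dev)
         for piece, dev in zip(torch.tensor_split(t, len(devices), dim=0), devices)]
        for t in tensors
    ]
    # transpose: per_tensor[i][k] (tensor i, device k) -> result[k][i]
    return [list(pieces) for pieces in zip(*per_tensor)]


def scatter_my_input(batch: MyInput,
                     devices: Sequence[torch.device]) -> List[MyInput]:
    """Hypothetical scatter: one MyInput per device, same list layout as the input."""
    x_s = split_list(batch.x, devices)
    m_s = split_list(batch.mask, devices)
    i_s = split_list(batch.idx, devices)
    return [MyInput(*props) for props in zip(x_s, m_s, i_s)]
```

Because all tensors share the same batch size, `tensor_split` cuts them at the same row boundaries, so the i-th piece of every list ends up in the same replica.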

So,

  1. How can I tell DataParallel to use my custom scatter_MyInput instead of the default scatter function?
  2. Where can I import default_scatter from?
  3. Does default_scatter work on a list of batched tensors, or do I have to scatter each item in the list with default_scatter one by one?
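A sketch touching all three questions, hedged: as far as I know, `nn.DataParallel.forward` calls its own `scatter(inputs, kwargs, device_ids)` method, so overriding that method in a subclass is one way to plug in custom splitting; the default function is `scatter` in `torch.nn.parallel.scatter_gather`; and since it recurses into lists, one call per list should be enough. `MyInputDataParallel` is a hypothetical name, not part of PyTorch.

```python
import torch
from torch import nn
# DataParallel's default scatter function (imported under the question's name)
from torch.nn.parallel.scatter_gather import scatter as default_scatter


class MyInput:
    """Minimal stand-in for the question's MyInput (device attribute omitted)."""

    def __init__(self, x, mask, idx):
        self.x, self.mask, self.idx = x, mask, idx


class MyInputDataParallel(nn.DataParallel):
    """Hypothetical DataParallel that splits a single positional MyInput per GPU."""

    def scatter(self, inputs, kwargs, device_ids):
        # Only handle the call pattern model(batch); fall back to the stock behavior otherwise.
        if len(inputs) == 1 and not kwargs and isinstance(inputs[0], MyInput):
            batch = inputs[0]
            # default_scatter recurses into lists: it returns one list of
            # chunks (split along self.dim) per device, so no manual loop is needed.
            x_s = default_scatter(batch.x, device_ids, self.dim)
            m_s = default_scatter(batch.mask, device_ids, self.dim)
            i_s = default_scatter(batch.idx, device_ids, self.dim)
            parts = [MyInput(*props) for props in zip(x_s, m_s, i_s)]
            # forward expects one positional-args tuple and one kwargs dict per replica
            return tuple((p,) for p in parts), tuple({} for _ in parts)
        return super().scatter(inputs, kwargs, device_ids)
```

Usage would be `MyInputDataParallel(MyModel().cuda())(batch)`. If the batch is smaller than the number of GPUs, the scatter may return fewer chunks; `zip` then builds fewer MyInput objects and DataParallel replicates the module onto only that many devices.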