My model (`MyModel`) has too many inputs, so I've created a class `MyInput` to bundle them. The class holds several lists of tensors as properties, and each tensor in those lists is one (batched) input.
```python
import torch as pt

class MyInput:  # batched input
    device: pt.device         # device of all the tensors below
    x: list[pt.FloatTensor]
    mask: list[pt.BoolTensor]
    idx: list[pt.IntTensor]
    # each item in the lists is one of the inputs of MyModel,
    # with batch at the first dim

    def __init__(self,
                 x: list[pt.FloatTensor],
                 mask: list[pt.BoolTensor],
                 idx: list[pt.IntTensor],
                 ) -> None:
        self.x, self.mask, self.idx = x, mask, idx
        # other things of MyInput
        ...
```
```python
class MyModel(pt.nn.Module):
    # implementation of MyModel
    ...

    def forward(self, batch: MyInput) -> pt.FloatTensor:
        # forward pass with batch as input, e.g.
        a = self.a_certain_layer(batch.x[3])
        b = a[batch.mask[2]]
        c = b[:, batch.idx[3]]
        return c
```
Obviously, `DataParallel` cannot process `MyInput` by default, so I plan to write a custom scatter function for `MyInput`, like this:
```python
def scatter_MyInput(all_batch: MyInput) -> list[MyInput]:
    # default_scatter is the default scatter function of DataParallel
    x_s = default_scatter(all_batch.x)
    m_s = default_scatter(all_batch.mask)
    i_s = default_scatter(all_batch.idx)
    return [MyInput(*props) for props in zip(x_s, m_s, i_s)]
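To show what I expect the `zip`-based recombination at the end to do, here is a self-contained sketch with plain Python lists standing in for tensors. `fake_scatter` is a hypothetical placeholder for whatever the real scatter turns out to be; I assume it chunks each tensor along the batch dim into one piece per device:

```python
# Sketch of the recombination step in scatter_MyInput. Plain lists
# stand in for tensors; fake_scatter is a hypothetical placeholder
# for DataParallel's real scatter.

def fake_scatter(tensors, n_devices=2):
    # Split each "tensor" (a list of batch rows) into contiguous
    # chunks along the batch dim, one chunk per device.
    chunk = -(-len(tensors[0]) // n_devices)  # ceil division
    return [
        [t[d * chunk:(d + 1) * chunk] for t in tensors]
        for d in range(n_devices)
    ]

x = [[1, 2, 3, 4], [10, 20, 30, 40]]   # two "tensors", batch size 4
mask = [[True, False, True, False]]    # one "tensor", batch size 4
idx = [[0, 1, 2, 3]]

x_s = fake_scatter(x)
m_s = fake_scatter(mask)
i_s = fake_scatter(idx)

# zip pairs up the per-device chunks of x, mask and idx, so each
# device receives one consistent (x, mask, idx) triple:
per_device = list(zip(x_s, m_s, i_s))
```

So device 0 would get the first half of every tensor in every list, device 1 the second half, and each triple could be fed to `MyInput(*props)`.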
So,

- How can I tell `DataParallel` to use my custom `scatter_MyInput` instead of the default scatter function?
- Where can I import `default_scatter` from?
- Does `default_scatter` work for a list of batched tensors, or do I have to scatter the items in the list with `default_scatter` one by one?
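To make the last question concrete, here are the two behaviours I can imagine, again with plain lists standing in for tensors and a hypothetical `split_tensor` standing in for scattering a single tensor along the batch dim:

```python
# Two possible behaviours of default_scatter on a LIST of batched
# "tensors". split_tensor is a hypothetical stand-in that chunks one
# tensor along the batch dim into one piece per device.

def split_tensor(t, n_devices=2):
    chunk = -(-len(t) // n_devices)  # ceil division
    return [t[d * chunk:(d + 1) * chunk] for d in range(n_devices)]

batch = [[1, 2, 3, 4], [10, 20, 30, 40]]  # a list of two batched tensors

# Behaviour A: the scatter understands lists and splits each element,
# returning one list-of-tensors per device (what I hope happens):
per_device_a = [
    [chunks[d] for chunks in map(split_tensor, batch)]
    for d in range(2)
]

# Behaviour B: I scatter each list item one by one and regroup the
# per-device chunks myself:
scattered = [split_tensor(t) for t in batch]  # chunks per tensor
per_device_b = [list(device_chunks) for device_chunks in zip(*scattered)]
```

Both routes should end up with the same grouping; my question is whether behaviour A already happens inside `default_scatter`, or whether I need to write behaviour B by hand.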