import torch

class TrainDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, index):
        ...
        # suppose we return this image
        img = torch.randn(3, 240, 240)
        # here other_info has a fixed number of channels
        channel_tmp = 10
        other_info = torch.randn(channel_tmp, 240, 240)
        return img, other_info
However, in my case, I need something like this instead:
import numpy as np
import torch

class TrainDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, index):
        ...
        # suppose we return this image
        img = torch.randn(3, 240, 240)
        # the number of channels of other_info now varies from sample to sample
        channel_tmp = np.random.randint(5, 10)
        other_info = torch.randn(channel_tmp, 240, 240)
        return img, other_info
Of course, this snippet does not work with the default DataLoader, because the default collate function calls return torch.stack(batch, 0, out=out) in torch/utils/data/_utils/collate.py, and torch.stack requires every tensor in the batch to have the same shape.
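For reference, the stack call itself is what fails once the channel counts differ between samples; a minimal illustration (the two channel counts below are arbitrary examples):

a = torch.randn(6, 240, 240)
b = torch.randn(9, 240, 240)
# RuntimeError: stack expects each tensor to be equal size
batched = torch.stack([a, b], 0)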
Edit:

- The channel_tmp varies a lot (from a very small to a very large number), so if I always set channel_tmp to the largest possible value, it goes OOM as soon as the batch size is 4 or 8.
- In fact, for this specific variable other_info, what I want is that the data returned for each sample (from each worker thread) gets stacked along the channel_tmp dimension, as in the sketch below.
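To make the second point concrete, here is a rough, untested sketch of the behaviour I have in mind as a custom collate_fn (the name my_collate, the list comprehensions, and the choice of dim=0 are just my placeholders):

def my_collate(batch):
    # batch is a list of (img, other_info) tuples returned by __getitem__
    imgs = torch.stack([img for img, _ in batch], 0)         # (B, 3, 240, 240)
    # concatenate instead of stacking, since channel_tmp differs per sample
    other_infos = torch.cat([info for _, info in batch], 0)  # (sum of channel_tmp, 240, 240)
    return imgs, other_infos

loader = torch.utils.data.DataLoader(TrainDataset(...), batch_size=4, collate_fn=my_collate)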
Is there any workaround to this?