I am trying to get a batch from a TensorDataset with a __getitem__ function defined as follows:
    def __getitem__(self, index):
        # ... some code here ...
        return {"img": img, "mask": mask}
If I call train_set[0], I get the first sample as a dict, which is what I would expect.
However, if I run

    ind = [1, 2, 3]
    batch = train_set[ind]
I get the error:

    *** TypeError: list indices must be integers or slices, not list
even though ind is a list of integers.
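For context, train_set[ind] simply calls train_set.__getitem__(ind), so index receives the whole list, and that exact message is what Python raises when a plain list is indexed with a list. A minimal sketch, assuming (hypothetically) that the elided code inside __getitem__ looks samples up in an ordinary Python list called samples:

```python
# Hypothetical stand-in for whatever "... some code here ..." indexes:
# an ordinary Python list of samples.
samples = ["a", "b", "c", "d"]


def lookup(index):
    # A Python list accepts an int or a slice, but not a list of ints.
    return samples[index]


single = lookup(0)  # an int index works

try:
    lookup([1, 2, 3])  # a list index is passed through unchanged and fails
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(single)         # a
print(error_message)  # list indices must be integers or slices, not list
```

So the list of integers is fine in itself; it is the container inside __getitem__ that only understands one integer (or slice) at a time.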
I was hoping that something like train_set[ind] would return a dict with keys “img” and “mask”, whose values are the 3 samples corresponding to ind.
Is it possible to achieve this?
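One way to get the dict-of-batches layout described above, without changing __getitem__, is to index the dataset one integer at a time and then group the values by key. This is a sketch only: ToyDataset and get_batch_dict are hypothetical names, and strings stand in for the real img/mask data. With real tensors of equal shape, passing collate=torch.stack instead of the default list should give stacked tensors.

```python
# Hypothetical stand-in for train_set: each sample is a dict, like the
# __getitem__ above, with strings standing in for the real img/mask data.
class ToyDataset:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return {"img": f"img{index}", "mask": f"mask{index}"}


def get_batch_dict(dataset, indices, collate=list):
    """Hypothetical helper: fetch samples one int index at a time, then group by key."""
    samples = [dataset[i] for i in indices]  # each lookup uses a single int
    if not samples:
        return {}
    # Build {"img": collate([...]), "mask": collate([...])} in index order.
    return {key: collate([s[key] for s in samples]) for key in samples[0]}


train_set = ToyDataset(10)
batch = get_batch_dict(train_set, [1, 2, 3])
print(batch)
# {'img': ['img1', 'img2', 'img3'], 'mask': ['mask1', 'mask2', 'mask3']}
```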
If this can’t be done currently, would it be useful to add a function get_batch(ind, n_workers=2) to the Dataset class?
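A rough sketch of what the proposed get_batch(ind, n_workers=2) could look like, written as a standalone function rather than a Dataset method (an assumption for illustration; ToyDataset is also hypothetical). It uses a standard-library thread pool so that up to n_workers samples load concurrently. PyTorch's own DataLoader uses worker processes for num_workers instead, and a DataLoader over torch.utils.data.Subset(train_set, ind) with batch_size=len(ind) is another way to get a single batch.

```python
from concurrent.futures import ThreadPoolExecutor


# Hypothetical stand-in for train_set (strings instead of real tensors).
class ToyDataset:
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return {"img": f"img{index}", "mask": f"mask{index}"}


def get_batch(dataset, ind, n_workers=2):
    """Hypothetical helper: load dataset[i] for each i in ind on a thread pool."""
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # Executor.map yields results in the same order as ind.
        samples = list(pool.map(dataset.__getitem__, ind))
    if not samples:
        return {}
    # Group the per-sample dicts into one dict of lists, keyed like each sample.
    return {key: [s[key] for s in samples] for key in samples[0]}


batch = get_batch(ToyDataset(), [1, 2, 3], n_workers=2)
print(batch["img"])   # ['img1', 'img2', 'img3']
print(batch["mask"])  # ['mask1', 'mask2', 'mask3']
```

Threads help mainly when loading is I/O-bound; CPU-heavy preprocessing is usually better served by process-based workers.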
Thanks!