I want to use them as inputs to my CycleGAN, which takes data from domain A and domain B and learns a mapping from A to B. So both inputs are passed to the CycleGAN at the same time.
This error points to different lengths of the two input tensors. TensorDataset returns samples from both inputs using the same index internally, so all tensors need the same size in dim0, which is used for indexing.
If one tensor contains more samples than the other, this error is thrown.
You could e.g. crop the larger tensor to the same length as the smaller one, or repeat samples from the smaller one until both lengths match.
If you want to apply a more complicated sampling strategy, I would recommend writing a custom Dataset and creating the pairs in __getitem__.
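As a rough sketch of the custom Dataset idea: the class below is hypothetical (the name UnpairedDataset and the sampling rule are my assumptions, not something from your code). It walks through domain A by index, wrapping around if A is the shorter tensor, and draws a random partner from domain B, which is one common choice for unpaired data. Adapt the pairing rule in __getitem__ to whatever your use case needs.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class UnpairedDataset(Dataset):
    """Hypothetical dataset pairing samples from two tensors of different lengths."""

    def __init__(self, data_a, data_b):
        self.data_a = data_a
        self.data_b = data_b

    def __len__(self):
        # One epoch covers the larger tensor once
        return max(self.data_a.size(0), self.data_b.size(0))

    def __getitem__(self, index):
        # Wrap around in A in case A is the shorter tensor
        sample_a = self.data_a[index % self.data_a.size(0)]
        # Assumption: draw the partner from B at random (unpaired sampling)
        idx_b = torch.randint(self.data_b.size(0), (1,)).item()
        sample_b = self.data_b[idx_b]
        return sample_a, sample_b


a = torch.arange(10).view(-1, 1)
b = torch.arange(20).view(-1, 1)
dataset = UnpairedDataset(a, b)
loader = DataLoader(dataset, batch_size=5)

for idx, (data1, data2) in enumerate(loader):
    print('Idx ', idx, data1.shape, data2.shape)
```

Since the pairs are created on the fly, no samples are dropped and no tensor has to be copied in memory.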
Here is a small example that reproduces the size mismatch error and shows how to slice the larger tensor:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Works, since a and b have the same length
a = torch.arange(10).view(-1, 1)
b = torch.arange(10).view(-1, 1)
dataset = TensorDataset(a, b)
loader = DataLoader(
    dataset,
    batch_size=5
)

for idx, (data1, data2) in enumerate(loader):
    print('Idx ', idx)
    print(data1)
    print(data2)

# Use different lengths
a = torch.arange(10).view(-1, 1)
b = torch.arange(20).view(-1, 1)
try:
    dataset = TensorDataset(a, b)  # fails
except AssertionError as e:
    print('TensorDataset failed:', e)

# Slice b to have the same length
dataset = TensorDataset(a, b[:a.size(0)])
loader = DataLoader(
    dataset,
    batch_size=5
)

for idx, (data1, data2) in enumerate(loader):
    print('Idx ', idx)
    print(data1)
    print(data2)
Note that slicing might not be the best approach for your use case, since it simply drops the extra samples of the larger tensor, so you should apply your own method to create corresponding pairs from both input tensors.
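If dropping samples is not acceptable, the other option mentioned above, duplicating the smaller tensor, could look roughly like this. It is only a sketch and assumes 2D tensors of shape [N, 1] as in the example; for other shapes the arguments to repeat would need one entry per dimension.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

a = torch.arange(10).view(-1, 1)
b = torch.arange(20).view(-1, 1)

# Repeat a along dim0 often enough to cover b, then cut off any surplus
num_repeats = math.ceil(b.size(0) / a.size(0))
a_repeated = a.repeat(num_repeats, 1)[:b.size(0)]

# Both tensors now have the same length, so TensorDataset accepts them
dataset = TensorDataset(a_repeated, b)
loader = DataLoader(dataset, batch_size=5)

for idx, (data1, data2) in enumerate(loader):
    print('Idx ', idx)
    print(data1)
    print(data2)
```

Keep in mind that each sample of a is now tied to fixed samples of b, so the pairing is arbitrary unless the order of your data carries meaning.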