Hi!
I have a dataset wrapper where the patterns are stored in two contiguous torch.Tensors.
The problem is that I have to preserve the order of the indices.
Suppose dataset A covers indices 0 to 3 and dataset B covers indices 4 to 7, and I am given an index list such as
idx_list = [3,6,2,4,7]
So the samples are not just interleaved between the two datasets: the indices arrive in random order, and that order must be preserved. As return value I need
tensor = [A(3), B(6), A(2), B(4), B(7)]
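For concreteness, if both datasets were plain tensors, the result I am after could be produced with a single fancy-indexing call (toy values below; the concatenation is just for illustration, since in my real case B is generated on the fly):

```python
import torch

# Toy stand-ins for the two datasets (values chosen to be recognizable).
A = torch.tensor([0, 1, 2, 3])
B = torch.tensor([40, 50, 60, 70])   # global indices 4..7 map to B[0..3]

idx_list = [3, 6, 2, 4, 7]
full = torch.cat([A, B])             # indices 0..7 address A first, then B
result = full[torch.tensor(idx_list)]
# result keeps the order of idx_list: [A[3], B[2], A[2], B[0], B[3]]
print(result)                        # tensor([ 3, 60,  2, 40, 70])
```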
Is there an efficient way to do this indexing? My current implementation is quite simple but, as shown below, involves a for loop:
```python
# 2.1) Split the index list: discriminate indices that address dataset A
# (the real patterns) from those that address dataset B (the fake ones).
i_pattern_A = [value for value in idx_list if value < len_dataset_A]
i_pattern_B = [value for value in idx_list if value >= len_dataset_A]

batch_pattern = torch.zeros(len(i_pattern_A) + len(i_pattern_B))
for i, curr_pattern in enumerate(idx_list):
    if curr_pattern < len_dataset_A:
        batch_pattern[i] = self.dataset[self.indices[curr_pattern]]
    else:
        batch_pattern[i] = self.get_generate_data(curr_pattern)
return batch_pattern
```
To me this is a schoolbook solution that does not take advantage of tensor indexing.
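One vectorized alternative I have been considering uses a boolean mask and advanced indexing instead of the loop. This is only a sketch: it assumes dataset A is a plain tensor and that the generation step for B can be batched (`generate_batch` below is a made-up stand-in for a batched `get_generate_data`):

```python
import torch

len_dataset_A = 4
dataset_A = torch.arange(4, dtype=torch.float32) * 10  # stand-in for A: [0, 10, 20, 30]

def generate_batch(indices: torch.Tensor) -> torch.Tensor:
    # Stand-in for a batched version of get_generate_data.
    return indices.to(torch.float32) * 100

idx = torch.tensor([3, 6, 2, 4, 7])

mask_A = idx < len_dataset_A                 # True where the index belongs to A
batch = torch.empty(len(idx), dtype=torch.float32)

# Boolean-mask assignment fills both groups without a Python loop,
# writing each value back to its original position in idx.
batch[mask_A] = dataset_A[idx[mask_A]]
batch[~mask_A] = generate_batch(idx[~mask_A])
print(batch)                                 # tensor([ 30., 600.,  20., 400., 700.])
```

The key point is that `batch[mask_A] = ...` scatters the A-values back into their original slots, so the order of `idx` is preserved even though each dataset is gathered in one shot.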
Many thanks,
Eugenio