I just defined a WeightedRandomSampler like that:
classesid = [0,1,2,3,4,5,6,7]
for i in range(9):
print(file_seg[i])
y = np.loadtxt(file_seg[i])
class_sample_count = np.array(
[len(np.where(y == t)[0]) for t in classesid]
)
sum_class_sample_count = sum_class_sample_count + class_sample_count
if np.min(sum_class_sample_count) == 0 :
sum_class_sample_count = 1 +sum_class_sample_count
print(sum_class_sample_count)
weight = (1. / sum_class_sample_count)
samples_weight = np.array([weight[t] for t in classesid])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
trainsampler = WeightedRandomSampler(samples_weight, 1)
def merge(tbl):
xl_=[]
xf_=[]
y_=[]
nPoints_=[]
np_random=np.random.RandomState([x[-1] for x in tbl])
for _, xl, y, idx in tbl:
m=np.eye(3,dtype='float32')
m[0,0]*=np_random.randint(0,2)*2-1
m=np.dot(m,np.linalg.qr(np_random.randn(3,3))[0])
xl=np.dot(xl,m)
xl+=np_random.uniform(-1,1,(1,3)).astype('float32')
xl=np.floor(resolution*(4+xl)).astype('int64')
xf=np.ones((xl.shape[0],1)).astype('float32')
xl_.append(xl)
xf_.append(xf)
y_.append(y)
nPoints_.append(y.shape[0])
xl_=[np.hstack([x,idx*np.ones((x.shape[0],1),dtype='int64')]) for idx,x in enumerate(xl_)]
return {'x': [torch.from_numpy(np.vstack(xl_)),torch.from_numpy(np.vstack(xf_))],
'y': torch.from_numpy(np.hstack(y_)),
'xf': [x[0] for x in tbl],
'nPoints': nPoints_}
return torch.utils.data.DataLoader(d,batch_size=batchSize, collate_fn=merge, num_workers=10, sampler=trainsampler)
Even if I set the size of WeightedRandomSampler num_sample 1, I still met the error~ It’s so strange.
File "fully_convolutional.py", line 88, in <module>
for batch in trainIterator:
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 336, in __next__
return self._process_next_batch(batch)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 357, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
IndexError: list index out of range
Maybe in dataloader collate_fn conflict with sampler? Thanks for any advices!