Traceback (most recent call last):
File "EuroSAT_train.py", line 119, in <module>
main()
File "EuroSAT_train.py", line 76, in main
for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/opt/conda/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/voyager-volume/MAML-Pytorch/EuroSAT.py", line 156, in __getitem__
for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
File "/voyager-volume/MAML-Pytorch/EuroSAT.py", line 156, in <listcomp>
for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
KeyError: '5594'
The code that produces this error:
# Meta-training split: the EuroSAT wrapper pre-generates 10000 episodes,
# each an n_way classification task with k_shot support / k_query query images.
euro = EuroSAT(
    '/voyager-volume/MAML-Pytorch/EuroSAT/',
    mode='train',
    n_way=args.n_way,
    k_shot=args.k_spt,
    k_query=args.k_qry,
    batchsz=10000,
    resize=args.imgsz,
)
def get_subset(indices, start, end):
    """Return the slice of *indices* from position *start* up to, but
    excluding, position *end*.

    Bug fix: the original body sliced ``indices[start:start+end]``,
    silently treating ``end`` as a *length* even though the parameter
    name promises an end position.  The only visible caller passes
    ``start=0``, for which both forms agree, so fixing the slice to
    match the name is backward compatible.
    """
    return indices[start:end]
# Shuffle the episode indices once.  SubsetRandomSampler re-shuffles its
# subset on every pass anyway, so the randperm only decides WHICH episodes
# participate (here: all of them).
#
# Robustness fix: convert the index tensor to plain Python ints.  When a
# raw tensor is handed to SubsetRandomSampler, the DataLoader workers pass
# 0-dim tensors into the dataset's __getitem__, which breaks datasets that
# use the index for dict lookups.  NOTE(review): the pasted KeyError '5594'
# is raised with a *string* key inside EuroSAT.py's __getitem__, which is
# outside this file — confirm the dataset's key types there as well.
indices = torch.randperm(len(euro)).tolist()
train_indices = get_subset(indices, 0, len(euro))
train_sampler = SubsetRandomSampler(train_indices)
# Held-out meta-test split; 100 episodes suffice for periodic evaluation.
euro_test = EuroSAT(
    '/voyager-volume/MAML-Pytorch/EuroSAT/',
    mode='test',
    n_way=args.n_way,
    k_shot=args.k_spt,
    k_query=args.k_qry,
    batchsz=100,
    resize=args.imgsz,
)
# Outer meta-training loop.  The dataset holds 10000 pre-generated episodes,
# so args.epoch is divided by that count to get the number of full passes.
# (Fix: the pasted fragment had lost its indentation and was not valid
# Python; the loop structure is restored here.)
for epoch in range(args.epoch // 10000):
    # fetch meta_batchsz num of episodes each time; the sampler already
    # randomizes order, so shuffle must remain False (the two options are
    # mutually exclusive in DataLoader).
    db = DataLoader(euro, args.task_num, sampler=train_sampler,
                    shuffle=False, num_workers=1, pin_memory=True)

    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
        x_spt, y_spt = x_spt.to(device), y_spt.to(device)
        x_qry, y_qry = x_qry.to(device), y_qry.to(device)

        # One meta-training step over a batch of args.task_num tasks;
        # maml(...) presumably returns per-update-step accuracies — TODO
        # confirm against the Meta class in this repo.
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:  # evaluation
            db_test = DataLoader(euro_test, 1, shuffle=True,
                                 num_workers=1, pin_memory=True)
            accs_all_test = []
            for x_spt, y_spt, x_qry, y_qry in db_test:
                # Batch dimension is 1 during fine-tuning; squeeze it so
                # the model sees a single task's support/query sets.
                x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
                x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1] -> mean accuracy per inner-update step
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)