The weird thing is that it works at the start and produces some normal results, but then after a few batches it stops working.
Outputs are like:
Epoch 1
loss: 2.291789 [ 128/40000]
loss: 2.298886 [12928/40000]
loss: 2.304425 [25728/40000]
loss: 2.292490 [38528/40000]
Traceback (most recent call last):
  File "D:\Projects\project1_model0\project1\main.py", line 194, in <module>
    main()
  File "D:\Projects\project1_model0\project1\main.py", line 190, in main
    trainer.Train_session(name)
  File "D:\Projects\project1_model0\project1\main.py", line 157, in Train_session
    self.test("train")
  File "D:\Projects\project1_model0\project1\main.py", line 104, in test
    for batch, (img, label) in enumerate(self.test_loader, 0):
  File "C:\Users\yxf\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\dataloader.py", line 633, in __next__
    data = self._next_data()
  File "C:\Users\yxf\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\yxf\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\yxf\AppData\Roaming\Python\Python310\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Projects\project1_model0\project1\dataset.py", line 26, in __getitem__
    img = self.data[index]
KeyError: 1
I used a customized dataset:
import torch
from torch.utils.data import Dataset, DataLoader
class my_dataset(Dataset):
    """Map-style dataset backed by a single ``.pt`` file per split.

    The file for a split is read from ``<root>/<name>.pt`` and must be a
    mapping with a ``"data"`` entry and a ``"labels"`` entry; sample ``i`` is
    ``(data[i], labels[i])``.

    Args:
        name: Split to load; one of ``"train"``, ``"valid"`` or ``"test"``.
        root: Directory containing the split files. Defaults to
            ``"../data_set"``, the location the splits were originally
            hard-coded to.

    Raises:
        NotImplementedError: If ``name`` is not a supported split.
    """

    _SPLITS = ("train", "valid", "test")

    def __init__(self, name, root="../data_set"):
        super().__init__()
        if name not in self._SPLITS:
            raise NotImplementedError(f"unknown split: {name!r}")

        self.path = f"{root}/{name}.pt"

        # Load the file once and pull both entries out of the same object.
        # Previously the "valid" and "test" splits stored the whole loaded
        # mapping in ``self.data``, so integer indexing raised ``KeyError``
        # and ``len()`` reported the number of keys instead of samples.
        payload = torch.load(self.path, map_location="cpu")
        self.data = payload["data"]
        self.labels = payload["labels"]

    def __getitem__(self, index):
        """Return the ``(img, label)`` pair stored at ``index``."""
        img = self.data[index]
        label = self.labels[index]  # original author's note: (3, 28, 28), (1)
        return img, label

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.data)