class myNormalize:
    """Rescale raw pixel intensities by dividing every value by 255.

    For 8-bit images this maps values into [0, 1]. No mean/std
    standardisation is performed.
    """

    def __call__(self, data):
        print("Normalization function")
        # NOTE(review): a uint8 numpy input comes out as a float64 array here,
        # which stays float64 after torch.tensor(); a float32 model may need a
        # cast downstream — confirm against the model's dtype.
        return data / 255.0
class myToTensor:
    """Convert a rank-3 array into a torch tensor with its last axis moved first.

    Assumes channels-last input (H, W, C), such as ``np.array(pil_image)``,
    which becomes a channels-first (C, H, W) tensor.
    """

    def __call__(self, data):
        print("myToTensor function")
        tensor = torch.tensor(data)
        # Reorder axes (0, 1, 2) -> (2, 0, 1): last axis becomes the first.
        return tensor.permute(2, 0, 1)
class myResize:
    """Resize an image to a fixed (height, width) using torchvision's resize.

    Args:
        size_h: Target height in pixels (default 256).
        size_w: Target width in pixels (default 256).

    The input must be something ``T.functional.resize`` accepts (per the
    torchvision docs: a PIL Image or a tensor laid out as ``[..., H, W]``).
    A channels-last numpy array is not accepted.
    """

    def __init__(self, size_h=256, size_w=256):
        self.size_h, self.size_w = size_h, size_w

    def __call__(self, data):
        print("Resize function")
        target_size = (self.size_h, self.size_w)
        return T.functional.resize(data, target_size)
class InferenceDataset():
    """Unlabelled image dataset yielding ``(image, file_path)`` pairs.

    Every entry that ``os.listdir(data_path)`` returns is treated as an image
    file. Entries are sorted, so the index-to-file mapping is reproducible
    across runs and platforms.

    Args:
        data_path: Directory containing the images. A trailing path
            separator is optional.
        transform: Optional callable applied to the RGB image array. When
            ``None``, the raw numpy array is returned unchanged.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        # os.listdir order is arbitrary; sort for deterministic indexing.
        self.imgs = sorted(os.listdir(data_path))
        # os.path.join supplies the separator that plain string concatenation
        # dropped whenever data_path had no trailing "/".
        self.data = [os.path.join(data_path, img_name) for img_name in self.imgs]
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        print(f"Fetching item at index {index}")
        img_path = self.data[index]
        print(f"Fetching item at path: {img_path}")
        # The context manager guarantees the underlying file handle is closed;
        # convert() returns a new, fully loaded image first.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        # transform defaults to None, so only apply it when one was given.
        if self.transform is not None:
            img = self.transform(img)
        return img, img_path
# Define the test transformer. Order matters: myToTensor must run before
# myResize, because torchvision's resize accepts a PIL Image or a
# [..., H, W] tensor, not the channels-last numpy array that
# InferenceDataset builds with np.array(image.convert('RGB')).
test_transformer = transforms.Compose([
    myNormalize(),                     # Scale pixel values to [0, 1]
    myToTensor(),                      # HWC numpy array -> CHW torch tensor
    myResize(size_h=256, size_w=256),  # Resize the CHW tensor to 256 x 256
])
path_to_test_data = "../images/"
# Create an instance of the test dataset with the specified transformer.
testing_dataset = InferenceDataset(path_to_test_data, transform=test_transformer)
# A length of 0 means no files were found under path_to_test_data.
print("Dataset Length:", len(testing_dataset))
# Create the test loader using the test dataset.
test_loader = DataLoader(
    testing_dataset,
    batch_size=1,
    shuffle=False,
    # Load in the main process. A worker-based loader (num_workers=1) was
    # reported to fail on a CPU-only run. Worker subprocesses must be able to
    # re-import the dataset and transform classes (e.g. under the "spawn"
    # start method), which is a likely cause. Only raise this once those
    # classes live in an importable module.
    num_workers=0,
)
Your code works fine for me after replacing the undefined data folder with random tensors:
class InferenceDataset():
    """Stand-in dataset of ten random samples, used to exercise a DataLoader.

    ``data_path`` and ``transform`` are stored on the instance but never read;
    each item is a random (3, 224, 224) tensor paired with a random integer
    label in [0, 10) of shape (1,).
    """

    def __init__(self, data_path, transform=None):
        self.transform = transform
        self.data_path = data_path

    def __len__(self):
        return 10

    def __getitem__(self, index):
        print(f"Fetching item at index {index}")
        image = torch.randn(3, 224, 224)
        label = torch.randint(0, 10, (1,))
        return image, label
You might want to derive your custom dataset from `torch.utils.data.Dataset` and also check its `len` to make sure valid images are found.
1 Like
I identified the error: it was not working because `num_workers` was set to 1 while I was running on the CPU. Setting it to 0 fixed it.