I am trying to use PyTorch DataLoaders in my code, which runs on AWS SageMaker. For some reason, whenever I set num_workers to anything greater than 0, loading an image either fails with an [SSL] record layer failure (_ssl.c:2578) error (printed by the except block below as "Error loading image ...") or the loaded image is all zeros. I have already tried a RandomSampler to distribute the data randomly and a lazily initialized boto3 client, but neither helped. The data is stored in an S3 bucket, and the dataloaders need to access and read the images from there. Here is the implementation of the dataset class:
import io
import os

import boto3
from PIL import Image
from torch.utils.data import Dataset


class TrainQueryDataset(Dataset):
    def __init__(self, bucket, target_prefix, dataset_type='train', crop_percentage=0, transform=None):
        self.bucket = bucket
        self.transform = transform
        self.dataset_type = dataset_type
        self.crop_percentage = crop_percentage
        self.s3_client = None
        # Get the list of files for the current dataset type (train/query)
        self.files = self._load_data(bucket, os.path.join(target_prefix, dataset_type + '/'))

    def _initialize_s3_client(self):
        # Lazily initialize the boto3 client
        if self.s3_client is None:
            self.s3_client = boto3.client('s3')

    def _load_data(self, bucket, prefix):
        self._initialize_s3_client()
        # List all .png keys under the given prefix
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
        list_img_paths = []
        for page in pages:
            for obj in page['Contents']:
                if obj['Key'].endswith('.png'):
                    list_img_paths.append(obj['Key'])
        return list_img_paths

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_key = self.files[idx]
        try:
            self._initialize_s3_client()
            # Fetch the image from S3
            response = self.s3_client.get_object(Bucket=self.bucket, Key=file_key)
            image_data = response['Body'].read()
            image = Image.open(io.BytesIO(image_data))
            if image is None:
                print("image not read properly")
            # Crop the images:
            # train images are cropped to the bottom section (top removed),
            # query images are cropped to the top section (bottom removed).
            if self.crop_percentage:
                width, height = image.size
                crop_section = height * self.crop_percentage // 100
                if self.dataset_type == 'train':
                    image = image.crop((0, crop_section, width, height))  # keep bottom section
                else:
                    image = image.crop((0, 0, width, height - crop_section))  # keep top section
            # Transform the data
            if self.transform:
                image = self.transform(image)
            # Extract the label from the filename (assuming format 'originalname_label.png')
            label_str = os.path.basename(file_key).split('_')[-1].split('.')[0]
            label = int(label_str)
            return image, label
        except Exception as e:
            print(f"Error loading image {file_key}: {e}")
            return None, None
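For what it's worth, reading a single item directly in the main process (i.e. without any DataLoader workers) does not show the problem. A quick check like the one below works fine; the bucket and prefix names here are placeholders for my actual values:

# Single-process sanity check (placeholder bucket/prefix names);
# the failures only appear once num_workers > 0.
dataset = TrainQueryDataset(bucket='my-bucket', target_prefix='dataset/images', dataset_type='train')
img, label = dataset[0]
print(img, label)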
And here is the LightningDataModule that builds the dataloaders:
import copy

import pytorch_lightning as pl
from torch.utils.data import DataLoader, RandomSampler


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, config, bucket, target_prefix):
        super().__init__()
        self.config = copy.deepcopy(config)
        self.bucket = bucket
        self.target_prefix = target_prefix
        self.train_dataset = None
        self.val_dataset = None
        self.query_dataset = None
        self.sampler = None

    def setup(self, stage=None):
        # Initialize datasets based on the stage
        if stage == 'fit' or stage is None:
            self.train_dataset = TrainQueryDataset(
                bucket=self.bucket,
                target_prefix=self.target_prefix,
                dataset_type='train',
                crop_percentage=self.config['dataset']['crop_percentage'],
                transform=augment_data(train=True, **self.config["transform"])
            )
            self.val_dataset = TrainQueryDataset(
                bucket=self.bucket,
                target_prefix=self.target_prefix,
                dataset_type='train',
                crop_percentage=self.config['dataset']['crop_percentage'],
                transform=augment_data(train=False, **self.config["transform"])
            )
            self.query_dataset = TrainQueryDataset(
                bucket=self.bucket,
                target_prefix=self.target_prefix,
                dataset_type='query',
                crop_percentage=self.config['dataset']['crop_percentage'],
                transform=augment_data(train=False, **self.config["transform"])
            )
            if self.config['dataloader']['sampler']:
                self.Sampler_train = RandomSampler(self.train_dataset)
                self.Sampler_val = RandomSampler(self.val_dataset)
                self.Sampler_query = RandomSampler(self.query_dataset)
            # Drop the flag so the remaining entries can be passed straight to DataLoader
            del self.config['dataloader']['sampler']

    def train_dataloader(self):
        return DataLoader(self.train_dataset, sampler=self.Sampler_train, **self.config['dataloader'])

    def val_dataloader(self):
        return DataLoader(self.val_dataset, sampler=self.Sampler_val, **self.config['dataloader'])

    def test_dataloader(self):
        return DataLoader(self.query_dataset, sampler=self.Sampler_query, **self.config['dataloader'])
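The exact config values vary between runs, so the ones below are only placeholders, but this is roughly the shape of the config I pass in and how the module is used; the errors start as soon as num_workers is above 0:

# Placeholder config values; the SSL error / all-zero images appear as soon as num_workers > 0.
config = {
    'dataset': {'crop_percentage': 10},   # placeholder
    'transform': {},                      # placeholder kwargs for augment_data
    'dataloader': {
        'batch_size': 32,                 # placeholder
        'num_workers': 4,                 # anything > 0 triggers the problem
        'pin_memory': True,               # placeholder
        'sampler': True,                  # swapped for RandomSampler in setup()
    },
}

datamodule = CustomDataModule(config, bucket='my-bucket', target_prefix='dataset')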
Any help would be appreciated.