Hello from the Spanish community.
I am working with this repository:
deep-text-recognition-benchmark
I am using the train.py file to train, and I would like to add a little data augmentation (a random rotation between -45 and 45 degrees) on the fly.
On line 31 you can see this:
train_dataset = Batch_Balanced_Dataset(opt)
And at line 145:
while(True):
    # train part
    image_tensors, labels = train_dataset.get_batch()
    image = image_tensors.to(device)
In dataset.py, the get_batch method of Batch_Balanced_Dataset looks like this:
def get_batch(self):
    balanced_batch_images = []
    balanced_batch_texts = []

    for i, data_loader_iter in enumerate(self.dataloader_iter_list):
        try:
            image, text = data_loader_iter.next()
            balanced_batch_images.append(image)
            balanced_batch_texts += text
        except StopIteration:
            self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
            image, text = self.dataloader_iter_list[i].next()
            balanced_batch_images.append(image)
            balanced_batch_texts += text
        except ValueError:
            pass

    balanced_batch_images = torch.cat(balanced_batch_images, 0)
    return balanced_batch_images, balanced_batch_texts
These are links to the code:
train.py
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while(True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
dataset.py
Total_batch_size_log = f'{dashed_line}\n'
batch_size_sum = '+'.join(batch_size_list)
Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
Total_batch_size_log += f'{dashed_line}'
opt.batch_size = Total_batch_size
print(Total_batch_size_log)
log.write(Total_batch_size_log + '\n')
log.close()
    def get_batch(self):
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                image, text = data_loader_iter.next()
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except StopIteration:
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
Could you help me apply that random rotation on the fly? I am not familiar with DataLoaders in PyTorch yet.
The usual way would be to pass a transformation to your Dataset and apply it in the __getitem__.
Based on the linked code, it seems that you are using different abstractions, where get_batch is calling into a DataLoader. You could either try to apply the transformation directly in the get_batch method or add it to the LmdbDataset and apply it in the __getitem__. The data loading tutorial explains the common use case in more detail.