Hi there,
I just started using PyTorch and want to build a patch classifier for breast mammography. The thing is, my image patches have pixel values in the range [0, 65535], and I just found out that the ToTensor() operation treats my images as if they were 8-bit. Here is the code I am currently using to load my dataset:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
    ]),
}
image_datasets = {x: datasets.ImageFolder(os.path.join(raw_images_root_dir, x), data_transforms[x])
                  for x in ['train', 'val', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True,
                                              num_workers=multiprocessing.cpu_count())
               for x in ['train', 'val', 'test']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
When I try to train my model I am getting REALLY BAD RESULTS: basically 50% accuracy throughout the entire training run. Here is my training setup:
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
model = train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device, num_epochs=11)
Here is my training info:
[INFO]: Epoch 0/10
[INFO]: Epoch [0/10], Step [0/134], Loss: 0.6922
[INFO]: Epoch [0/10], Step [20/134], Loss: 0.7699
[INFO]: Epoch [0/10], Step [40/134], Loss: 0.6767
[INFO]: Epoch [0/10], Step [60/134], Loss: 0.7273
[INFO]: Epoch [0/10], Step [80/134], Loss: 0.7482
[INFO]: Epoch [0/10], Step [100/134], Loss: 0.8035
[INFO]: Epoch [0/10], Step [120/134], Loss: 0.7248
[INFO]: train accuracy: 0.5035
[INFO]: train loss: 0.8703
[INFO]: Epoch [0/10], Step [0/34], Loss: 30.9444
[INFO]: Epoch [0/10], Step [20/34], Loss: 38.3757
[INFO]: val accuracy: 0.5000
[INFO]: val loss: 28.2381
[INFO]: Epoch 1/10
[INFO]: Epoch [1/10], Step [0/134], Loss: 0.6472
[INFO]: Epoch [1/10], Step [20/134], Loss: 0.6883
[INFO]: Epoch [1/10], Step [40/134], Loss: 0.6616
[INFO]: Epoch [1/10], Step [60/134], Loss: 0.6833
[INFO]: Epoch [1/10], Step [80/134], Loss: 0.6401
[INFO]: Epoch [1/10], Step [100/134], Loss: 0.6897
[INFO]: Epoch [1/10], Step [120/134], Loss: 0.6746
[INFO]: train accuracy: 0.5091
[INFO]: train loss: 0.7010
[INFO]: Epoch [1/10], Step [0/34], Loss: 0.9954
[INFO]: Epoch [1/10], Step [20/34], Loss: 0.7296
[INFO]: val accuracy: 0.5000
[INFO]: val loss: 0.7585
[INFO]: Epoch 2/10
[INFO]: Epoch [2/10], Step [0/134], Loss: 0.7247
[INFO]: Epoch [2/10], Step [20/134], Loss: 0.7023
[INFO]: Epoch [2/10], Step [40/134], Loss: 0.6870
[INFO]: Epoch [2/10], Step [60/134], Loss: 0.6869
[INFO]: Epoch [2/10], Step [80/134], Loss: 0.6935
[INFO]: Epoch [2/10], Step [100/134], Loss: 0.7037
[INFO]: Epoch [2/10], Step [120/134], Loss: 0.6893
[INFO]: train accuracy: 0.5119
[INFO]: train loss: 0.6927
[INFO]: Epoch [2/10], Step [0/34], Loss: 0.6981
[INFO]: Epoch [2/10], Step [20/34], Loss: 0.7100
[INFO]: val accuracy: 0.5000
[INFO]: val loss: 0.6977
[INFO]: Epoch 3/10
[INFO]: Epoch [3/10], Step [0/134], Loss: 0.7029
[INFO]: Epoch [3/10], Step [20/134], Loss: 0.6941
[INFO]: Epoch [3/10], Step [40/134], Loss: 0.7016
[INFO]: Epoch [3/10], Step [60/134], Loss: 0.6886
[INFO]: Epoch [3/10], Step [80/134], Loss: 0.7015
[INFO]: Epoch [3/10], Step [100/134], Loss: 0.6877
[INFO]: Epoch [3/10], Step [120/134], Loss: 0.6901
[INFO]: train accuracy: 0.5021
[INFO]: train loss: 0.6919
[INFO]: Epoch [3/10], Step [0/34], Loss: 0.6654
[INFO]: Epoch [3/10], Step [20/34], Loss: 0.7031
[INFO]: val accuracy: 0.5000
[INFO]: val loss: 0.6976
[INFO]: Epoch 4/10
[INFO]: Epoch [4/10], Step [0/134], Loss: 0.6859
I guess this is because my images are NOT loaded correctly. When I inspect one of the images with the following code:
dataiter = iter(dataloaders['train'])
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch versions
print(images[0].min())
print(images[0].max())
I get the following output:
tensor(1.)
tensor(1.)
How do I load my images correctly, taking into account that the pixel values are in the range [0, 65535] (16-bit)?
Thank you!