import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import datasets, models, transforms
import os
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from unetr import UNETR
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)
import time
import copy
# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
PIN_MEMORY = True
LOAD_MODEL = False
DATASET_SIZE = {'train': 6000, 'val': 750, 'test': 750}
TRAIN_IMG_DIR = ".../split_dataset_final/train/IMG"
TRAIN_MASK_DIR = ".../split_dataset_final/train/GT"
VAL_IMG_DIR = ".../split_dataset_final/val/IMG"
VAL_MASK_DIR = ".../split_dataset_final/val/GT"
TEST_IMG_DIR = ".../split_dataset_final/test/IMG"
TEST_MASK_DIR = ".../split_dataset_final/test/GT"
PATCH_SIZE = 16
def patchify(batch, patch_size):
    """
    Patchify a batch of images.
    Shape:
        batch: (b, c, h, w)
        output: (b, nh * nw, ph * pw * c)
    """
    b, c, h, w = batch.shape
    ph = patch_size
    pw = patch_size
    nh, nw = h // ph, w // pw
    # Split each spatial dimension into (num_patches, patch_size)
    batch_patches = torch.reshape(batch, (b, c, nh, ph, nw, pw))
    batch_patches = torch.permute(batch_patches, (0, 1, 2, 4, 3, 5))
    b, c, nh, nw, ph, pw = batch_patches.shape
    # Flatten the patches: channels last, then one vector per patch
    batch = torch.permute(batch_patches, (0, 2, 3, 4, 5, 1))
    batch = torch.reshape(batch, (b, nh * nw, ph * pw * c))
    batch = batch.to(DEVICE)
    return batch
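# With the settings above (BATCH_SIZE = 16, 3-channel 256x256 images,
# PATCH_SIZE = 16), patchify maps a (16, 3, 256, 256) batch to (16, 256, 768):
# 256 // 16 = 16 patches per side, 16 * 16 = 256 patches in total, each
# flattened to 16 * 16 * 3 = 768 values.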
def train_model(model, dataloaders, criterion, optimizer, scheduler, scaler, num_epochs):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
    best_dice_score = 0
    best_epoch = -1
    trigger = 0
    tolerance = 3
    stop_early = False
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                loop = tqdm(dataloaders[phase])
            else:
                model.eval()  # Set model to evaluation mode
                loop = dataloaders[phase]
            running_loss = 0.0
            # Iterate over data.
            for inputs, labels in loop:
                inputs = patchify(inputs, PATCH_SIZE)  # already moved to DEVICE
                # print(inputs.shape)  # debug
                labels = labels.float().unsqueeze(1).to(DEVICE)
                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    with torch.cuda.amp.autocast():
                        predictions = model(inputs)
                        loss = criterion(predictions, labels)
                    # backward + optimize only in the training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                # statistics
                running_loss += loss.item() * inputs.size(0)
                if phase == 'train':
                    loop.set_postfix(loss=loss.item())
            # StepLR is epoch-based (step_size=7 means 7 epochs), so step it
            # once per epoch rather than once per batch
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / DATASET_SIZE[phase]
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            # deep copy the model whenever the validation dice improves
            if phase == 'val':
                dice_score = check_accuracy(dataloaders[phase], model, device=DEVICE)
                if dice_score > best_dice_score:
                    best_dice_score = dice_score
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
                    best_epoch = epoch
                    trigger = 0  # reset patience on improvement
                else:
                    trigger += 1
                    print('Trigger added:', trigger)
                    if trigger == tolerance:
                        print(f'Early stopped at epoch: {epoch}')
                        # a bare break here would only exit the phase loop,
                        # so set a flag to exit the epoch loop as well
                        stop_early = True
        print()
        if stop_early:
            break
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val dice score: {:.4f} at epoch: {}'.format(best_dice_score, best_epoch))
    # save the best checkpoint
    PATH = "unet_model_checkpoint.pt"
    torch.save({
        'epoch': best_epoch,
        'model_state_dict': best_model_wts,
        'optimizer_state_dict': best_optimizer_wts,
        'dice': best_dice_score,
    }, PATH)
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
def main():
    print('Using', DEVICE)
    train_transform = A.Compose(
        [
            # A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            # A.Rotate(limit=35, p=1.0),
            # A.HorizontalFlip(p=0.5),
            # A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
    config = {}
    config["image_size"] = 256
    config["num_layers"] = 12
    config["hidden_dim"] = 768
    config["mlp_dim"] = 3072
    config["num_heads"] = 12
    config["dropout_rate"] = 0.1
    config["num_patches"] = 256
    config["patch_size"] = PATCH_SIZE
    config["num_channels"] = 3
    model = UNETR(config).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    train_loader, val_loader, test_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        TEST_IMG_DIR,
        TEST_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )
    dataloaders = {'train': train_loader, 'val': val_loader}
    if LOAD_MODEL:
        print('Loading Model')
        checkpoint = torch.load("unet_model_checkpoint.pt")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print('Training Model')
        scaler = torch.cuda.amp.GradScaler()
        train_model(model, dataloaders, loss_fn, optimizer, step_lr_scheduler, scaler, NUM_EPOCHS)
        # save_predictions_as_imgs(
        #     dataloaders['val'], model, folder="saved_images/", device=DEVICE
        # )

if __name__ == "__main__":
    main()
This is my code for implementing ViT from scratch. I added the patchify function above, but training fails with this error:

return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

Does anyone know why this happens?
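
In case it helps with debugging, here is a minimal sketch that reproduces the same error message. It assumes the failure occurs inside an nn.Embedding lookup (e.g. a positional embedding in UNETR) whose index tensor was created on the CPU while the weights live on the GPU; the names below are made up for illustration, not taken from my unetr module:

import torch
import torch.nn as nn

device = "cuda"  # assumes a CUDA machine, as in the script above

emb = nn.Embedding(256, 768).to(device)  # weights on cuda:0
positions = torch.arange(256)            # index tensor left on the CPU

out = emb(positions)                     # raises the same RuntimeError

# Moving the indices to the same device as the weights avoids it:
out = emb(positions.to(device))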