RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

```python
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import datasets, models, transforms
import os
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import unetr
from unetr import UNETR
from utils import (
load_checkpoint,
save_checkpoint,
get_loaders,
check_accuracy,
save_predictions_as_imgs,
)

import time
import copy

# Hyperparameters etc.

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
PIN_MEMORY = True
LOAD_MODEL = False
DATASET_SIZE = {'train': 6000, 'val': 750, 'test': 750}
TRAIN_IMG_DIR = ".../split_dataset_final/train/IMG"
TRAIN_MASK_DIR = ".../split_dataset_final/train/GT"
VAL_IMG_DIR = ".../split_dataset_final/val/IMG"
VAL_MASK_DIR = ".../split_dataset_final/val/GT"
TEST_IMG_DIR = ".../split_dataset_final/test/IMG"
TEST_MASK_DIR = ".../split_dataset_final/test/GT"
PATCH_SIZE = 16

def patchify(batch, patch_size):
    """
    Patchify a batch of images.

    Shape:
        batch: (b, c, h, w)
        output: (b, nh * nw, ph * pw * c)
    """
    b, c, h, w = batch.shape
    ph = patch_size
    pw = patch_size
    nh, nw = h // ph, w // pw

    # Split each spatial dim into (num_patches, patch_size) pairs
    batch_patches = torch.reshape(batch, (b, c, nh, ph, nw, pw))
    batch_patches = torch.permute(batch_patches, (0, 1, 2, 4, 3, 5))
    b, c, nh, nw, ph, pw = batch_patches.shape

    # Flatten each patch into a single vector of length ph * pw * c
    batch = torch.permute(batch_patches, (0, 2, 3, 4, 5, 1))
    batch = torch.reshape(batch, (b, nh * nw, ph * pw * c))
    batch = batch.to(DEVICE)

    return batch

def train_model(model, dataloaders, criterion, optimizer, scheduler, scaler, num_epochs):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
    best_dice_score = 0
    best_epoch = -1
    trigger = 0
    tolerance = 3
    stop_early = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                loop = tqdm(dataloaders[phase])
            else:
                model.eval()   # Set model to evaluation mode
                loop = dataloaders[phase]

            running_loss = 0.0

            # Iterate over data
            for inputs, labels in loop:
                inputs = patchify(inputs, PATCH_SIZE)
                inputs = inputs.to(DEVICE)
                labels = labels.float().unsqueeze(1).to(DEVICE)

                # Forward pass; track gradient history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    with torch.cuda.amp.autocast():
                        predictions = model(inputs)
                        loss = criterion(predictions, labels)

                    # Backward + optimize only in the training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                        loop.set_postfix(loss=loss.item())

                # Statistics
                running_loss += loss.item() * inputs.size(0)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / DATASET_SIZE[phase]

            # Deep copy the model whenever the validation dice score improves
            if phase == 'val':
                dice_score = check_accuracy(dataloaders[phase], model, device=DEVICE)
                if dice_score > best_dice_score:
                    best_dice_score = dice_score
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
                    best_epoch = epoch
                    trigger = 0  # reset the patience counter on improvement
                else:
                    trigger += 1
                    print('Trigger Added!: ', trigger)
                    if trigger == tolerance:
                        print(f'Early Stopped at epoch: {epoch}')
                        stop_early = True

        print()
        if stop_early:
            break  # leave the epoch loop, not just the phase loop

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val dice score: {:.4f} at epoch: {}'.format(best_dice_score, best_epoch))

    # Save the best checkpoint
    PATH = "unet_model_checkpoint.pt"
    torch.save({
        'epoch': best_epoch,
        'model_state_dict': best_model_wts,
        'optimizer_state_dict': best_optimizer_wts,
        'dice': best_dice_score,
    }, PATH)
    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

def main():
    print('Using ', DEVICE)
    train_transform = A.Compose(
        [
            # A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            # A.Rotate(limit=35, p=1.0),
            # A.HorizontalFlip(p=0.5),
            # A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    config = {}
    config["image_size"] = 256
    config["num_layers"] = 12
    config["hidden_dim"] = 768
    config["mlp_dim"] = 3072
    config["num_heads"] = 12
    config["dropout_rate"] = 0.1
    config["num_patches"] = 256
    config["patch_size"] = PATCH_SIZE
    config["num_channels"] = 3

    model = UNETR(config).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    train_loader, val_loader, test_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        TEST_IMG_DIR,
        TEST_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    dataloaders = {'train': train_loader, 'val': val_loader}

    if LOAD_MODEL:
        print('Loading Model')
        checkpoint = torch.load("unet_model_checkpoint.pt")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print('Training Model')
        scaler = torch.cuda.amp.GradScaler()
        train_model(model, dataloaders, loss_fn, optimizer, step_lr_scheduler, scaler, NUM_EPOCHS)

    # save_predictions_as_imgs(
    #     dataloaders['val'], model, folder="saved_images/", device=DEVICE
    # )

if __name__ == "__main__":
    main()
```

This is my code for implementing a ViT from scratch. I added the patchify function, but I get the following error:

```
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
```

Does anyone know why this happens?
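
For reference, this is the shape behaviour I expect from patchify: a standalone sanity check on random input, with sizes matching the config above (the expected output is batch × num_patches × hidden_dim):

```python
import torch

x = torch.randn(2, 3, 256, 256)   # (b, c, h, w), matching IMAGE_HEIGHT/IMAGE_WIDTH
patches = patchify(x, 16)          # PATCH_SIZE = 16
print(patches.shape)               # should be torch.Size([2, 256, 768])
assert patches.shape == (2, (256 // 16) ** 2, 16 * 16 * 3)
```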

Could you post the model definition by wrapping the code in three backticks for proper formatting, please?
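
In the meantime, one common cause of this particular traceback: the failing call is an `nn.Embedding` lookup, so an index tensor (typically the position ids for a learned positional embedding) is still on the CPU while the embedding weights are on `cuda:0`. Moving the inputs in `patchify` does not move tensors the model creates internally. Here is a minimal sketch of the pattern and the usual fix, assuming your UNETR builds position indices with `torch.arange` (the class and attribute names here are hypothetical):

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):  # hypothetical stand-in for the UNETR embedding block
    def __init__(self, num_patches, hidden_dim):
        super().__init__()
        self.pos_embed = nn.Embedding(num_patches, hidden_dim)

    def forward(self, x):
        # BUG: torch.arange allocates on the CPU by default, even when x is on cuda:0
        # positions = torch.arange(x.shape[1])

        # FIX: create the index tensor on the same device as the input
        positions = torch.arange(x.shape[1], device=x.device)
        return x + self.pos_embed(positions)
```

Note that if `unetr.py` stores the position ids as a plain attribute (e.g. `self.positions = torch.arange(...)` in `__init__`), `model.to(DEVICE)` will not move it either; either register it with `self.register_buffer(...)` or create it per forward pass as above.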