Hello, I am trying to train an object-detection model using the SSDLite MobileNetV3 model (`ssdlite320_mobilenet_v3_large`) available in the torchvision library. I am using a dataset of 30,000 images with a batch size of 16.
The problem is that when I look at GPU performance in Task Manager on Windows, it shows very low GPU memory usage (3 GB). I have an NVIDIA RTX A6000 with 48 GB.
I also tried increasing the batch size to 256 images; memory usage rose to 14 GB, which in my opinion is still too low for this case.
I also trained another model (a UNet) that I used for a different project, and with the same batch size (16) it used 20 GB of memory. So I concluded there is an error in my code, which I have not been able to find myself.
Can you help me with this issue?
Below is my code:
import torch
from torch.utils.data import DataLoader
from torchvision.models import MobileNet_V3_Large_Weights
import time, datetime
from tqdm import tqdm
from torchvision.utils import draw_bounding_boxes
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
import training.transforms as T
from training.utils import SSD_Dataset
def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Nothing is stacked, so entries of differing size (e.g. per-image target
    dicts) can share a batch. An empty batch yields ``()``.
    """
    return (*zip(*batch),)
def get_transform(train, img_size, num_channels=1, hflip_prob=0.5):
    """Build the preprocessing pipeline applied to each (image, target) sample.

    The pipeline is, in order: clip boxes to the image, [random horizontal
    flip — training only], PIL image to tensor, cast to float, normalize.

    Args:
        train: when true, include the random horizontal flip; evaluation uses
            the same pipeline without it.
        img_size: forwarded to ``T.ClipBoxesToImage``.
        num_channels: forwarded to ``T.PILToTensor``.
        hflip_prob: flip probability; only used when ``train`` is true.

    Returns:
        A ``T.Compose`` wrapping the steps above.

    NOTE(review): torchvision detection models (SSDlite included) normalize
    their inputs internally; confirm ``T.Normalize`` is not applying a second,
    conflicting normalization.
    """
    pipeline = [T.ClipBoxesToImage(img_size)]
    if train:
        pipeline.append(T.RandomHorizontalFlip(hflip_prob))
    pipeline.extend((
        T.PILToTensor(num_channels),
        T.ConvertImageDtype(torch.float),
        T.Normalize(),
    ))
    return T.Compose(pipeline)
def train_one_epoch(model, opt, dataset, dataloader, device, epoch, epochs):
    """Run one training epoch and return the summed per-batch losses.

    Args:
        model: detection model; in train mode it is called as
            ``model(images, targets)`` and must return a dict of loss tensors.
        opt: optimizer over the model's trainable parameters.
        dataset: used only for ``len(dataset)`` to size the progress bar.
        dataloader: yields ``(images, targets)`` pairs, where ``images`` is a
            sequence of tensors and ``targets`` a sequence of dicts of tensors.
        device: device the model lives on; every input tensor is moved there.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        epochs: total number of epochs (display only).

    Returns:
        float: sum over batches of the total (summed) loss. Divide by the
        number of batches to obtain the mean per-batch loss.
    """
    model.train()
    train_losses = 0.0
    with tqdm(total=len(dataset), desc=f'Epoch {epoch+1}/{epochs}', unit='img') as pbar:
        for images, targets in dataloader:
            # With pin_memory=True on the DataLoader, non_blocking=True lets
            # the host->device copy run asynchronously instead of stalling.
            images = [image.to(device, non_blocking=True) for image in images]
            targets = [
                {k: v.to(device, non_blocking=True) for k, v in t.items()}
                for t in targets
            ]

            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            opt.zero_grad(set_to_none=True)
            losses.backward()
            opt.step()

            # .item() synchronizes with the device; once per batch is enough.
            loss_value = losses.item()
            train_losses += loss_value
            pbar.update(len(images))
            pbar.set_postfix(**{'loss (batch)': loss_value})
    return train_losses
def evaluate(model, dataloader, dataset, device, epoch, epochs, vis_every=50, score_thresh=0.5):
    """Run inference on the validation set and collect per-image mean scores.

    The "accuracy" collected here is not a detection accuracy: for each image
    it is the mean confidence score of the boxes the model returned, or 0 when
    it returned no boxes (instead of the NaN that the mean of an empty tensor
    would produce).

    When ``(epoch + 1) % vis_every == 0``, the first image of the last batch is
    saved as ``inf_img_<epoch+1>.png`` with boxes scoring above
    ``score_thresh`` drawn on it.

    Args:
        model: detection model; in eval mode ``model(images)`` must return one
            dict per image with at least ``'boxes'`` and ``'scores'``.
        dataloader: yields ``(images, targets)``; targets are ignored.
        dataset: used only for ``len(dataset)`` to size the progress bar.
        device: device the model lives on; images are moved there.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        epochs: total number of epochs (display only).
        vis_every: positive interval (in epochs) between saved visualizations.
        score_thresh: minimum score for a box to be drawn in the visualization.

    Returns:
        list of 0-d tensors, one mean score per image, on the same device as
        the model's outputs.
    """
    model.eval()
    accuracy = []
    score_total = 0.0
    images, detections = None, None
    with tqdm(total=len(dataset), desc=f'Validation Epoch {epoch+1}/{epochs}', unit='img', leave=False) as pbar:
        with torch.no_grad():
            for images, _targets in dataloader:
                images = [image.to(device, non_blocking=True) for image in images]
                detections = model(images)
                batch_scores = [
                    det['scores'].mean() if det['scores'].numel() else det['scores'].new_zeros(())
                    for det in detections
                ]
                accuracy.extend(batch_scores)
                if batch_scores:
                    # Keep a running total (one host sync per batch) instead of
                    # re-stacking the whole history every batch, which is O(n^2).
                    score_total += torch.stack(batch_scores).sum().item()
                pbar.update(len(images))
                if accuracy:
                    pbar.set_postfix(**{'avg accuracy': score_total / len(accuracy)})

    # `detections` stays None when the dataloader produced no batches.
    if (epoch + 1) % vis_every == 0 and detections:
        boxes = detections[0]['boxes'].detach().cpu()
        scores = detections[0]['scores'].detach().cpu()
        # draw_bounding_boxes needs a uint8 image. The input went through
        # T.Normalize and may lie outside [0, 1]; clamp before casting so
        # out-of-range pixels saturate instead of wrapping around.
        img_u8 = (images[0].detach().cpu() * 255).clamp(0, 255).to(torch.uint8)
        inf_image = draw_bounding_boxes(img_u8, boxes[scores > score_thresh], width=1)
        inf_img_pil = Image.fromarray(inf_image.permute(1, 2, 0).cpu().numpy())
        inf_img_pil.save('inf_img_%d.png' % (epoch + 1))
    return accuracy
def _build_optimizer(name, parameters, lr, wd):
    """Create the optimizer selected by ``MP['optimizer']`` ('sgd' or 'adamw')."""
    if name == 'sgd':
        return torch.optim.SGD(
            parameters,
            lr=lr,
            momentum=0.9,
            weight_decay=wd,
        )
    if name == 'adamw':
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=wd)
    # Fail fast: an unknown name used to leave the optimizer unbound and only
    # surface later as a NameError inside the training loop.
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'sgd' or 'adamw'")


def train_and_test_model(MP, i):
    """
    Main function for training and testing models

    Arguments:
        MP: dict
            parameters dictionary. Keys read: img_train_dir, img_val_dir,
            train_anno_dir, val_anno_dir, bs, pin_memory, device,
            num_channels, dim, n_train, n_dev, epoch_save, lr, wd,
            output_dir, optimizer, start_epoch, epochs, and optionally
            num_workers (DataLoader worker processes, default 0).
        i: int
            iteration number (when iterating over hyper params); currently
            unused, kept for caller compatibility.

    Returns:
        dict with per-epoch lists 'train_loss' (mean loss per batch) and
        'accuracy' (best per-image mean validation score).

    Raises:
        ValueError: if MP['optimizer'] is not 'sgd' or 'adamw'.
    """
    print('Loading Data')
    img_train_path = MP['img_train_dir']
    img_val_path = MP['img_val_dir']
    train_anno_path = MP['train_anno_dir']
    val_anno_path = MP['val_anno_dir']
    bs = MP['bs']
    PIN_MEMORY = MP['pin_memory']
    DEVICE = MP['device']
    num_channels = MP['num_channels']
    img_size = MP['dim']
    n_train = MP['n_train']
    n_val = MP['n_dev']
    epoch_save = MP['epoch_save']
    # With num_workers=0 (the previous, and still default, behavior) every
    # image is decoded and transformed in the main process, which can leave
    # the GPU idle between batches. Set MP['num_workers'] > 0 to load batches
    # in parallel; on Windows the script entry point must then be guarded by
    # `if __name__ == '__main__':`.
    num_workers = MP.get('num_workers', 0)

    train_dataset = SSD_Dataset(img_train_path, train_anno_path, n_train, transform=get_transform(True, img_size, num_channels))
    val_dataset = SSD_Dataset(img_val_path, val_anno_path, n_val, transform=get_transform(False, img_size, num_channels))

    print('Creating Dataloaders')
    loader_kwargs = dict(
        batch_size=bs,
        collate_fn=collate_fn,
        pin_memory=PIN_MEMORY,
        num_workers=num_workers,
        # Reuse worker processes across epochs instead of respawning them.
        persistent_workers=num_workers > 0,
    )
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print('Loading Model')
    lr = MP['lr']
    wd = MP['wd']
    output_dir = MP['output_dir']
    # SSDlite320 resizes every input to a fixed 320x320 internally (see the
    # torchvision model docs), so per-image activation memory is small;
    # modest GPU memory use at bs=16 is expected rather than a bug.
    model = ssdlite320_mobilenet_v3_large(
        weights=None,
        num_classes=2,
        weights_backbone=MobileNet_V3_Large_Weights.DEFAULT,
        trainable_backbone_layers=3,
    ).to(DEVICE)
    parameters = [p for p in model.parameters() if p.requires_grad]
    print(parameters[0].device)
    opt = _build_optimizer(MP['optimizer'], parameters, lr, wd)

    # The EarlyStopping callback was removed: its result was never used (the
    # calls were commented out) and the class is not imported here, so
    # constructing it raised NameError.

    print('Start Training')
    start_epoch = MP['start_epoch']
    epochs = MP['epochs']
    H = {'train_loss': [], 'accuracy': []}
    # len(dataloader) counts the final partial batch; len(dataset) // bs did
    # not, and was 0 (ZeroDivisionError) when the dataset is smaller than bs.
    train_steps = max(len(train_dataloader), 1)
    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        epoch_time = time.time()
        train_loss = train_one_epoch(model, opt, train_dataset, train_dataloader, DEVICE, epoch, epochs)
        accuracy = evaluate(model, val_dataloader, val_dataset, DEVICE, epoch, epochs)
        train_loss_avg = train_loss / train_steps
        print('Train_Loss for epoch %d: %f' % (epoch + 1, train_loss_avg))
        H['train_loss'].append(train_loss_avg)
        # torch.max propagates NaN instead of depending on list order as the
        # builtin max() does; NaN is also reported for an empty validation set.
        best_accuracy = torch.stack(accuracy).max().item() if accuracy else float('nan')
        H['accuracy'].append(best_accuracy)
        time_elapsed = time.time() - epoch_time
        print(f'Epoch Training completed in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Epoch Best Classification Accuracy: {best_accuracy:.4f}')
        if (epoch + 1) in epoch_save:
            print('saving model at epoch %d' % (epoch + 1))
            # output_dir is used as a file path; each save overwrites it.
            torch.save(model.state_dict(), output_dir)
    torch.save(model.state_dict(), output_dir)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    # timedelta already renders as H:MM:SS, so no "seconds" suffix.
    print(f"Training time {total_time_str}")
    return H