Hello all,
I was recently training a model and due to less memory, I shifted to mixed-precision training, once with the Apex library and other with the torch amp module. I am using PyTorch 1.6.0. I realized that Apex uses way less memory. Apex uses a maximum of 4GB whereas using amp, the model hardly fits in the GPU using around 7.8GB. I am concerned whether is the training correct/ equivalent when using the two different packages.
My code for training with Nvidia’s apex library is as follows -
model = fasterrcnn_resnet50_fpn(pretrained=False,
num_classes=9,
pretrained_backbone=False, trainable_backbone_layers=5)
model.cuda()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4, weight_decay=0.0001)
model, optimizer = amp.initialize(models=model, optimizers=optimizer,
opt_level='O2', keep_batchnorm_fp32=True)
for epoch in range(650):
loss_per_epoch = 0
for image, target in zip(images, targets):
optimizer.zero_grad()
target['boxes'] = target['boxes'].cuda()
target['labels'] = target['labels'].cuda()
image = image.cuda()
net_loss = 0
image_list = []
image_list.append(image)
target_list = []
target_list.append(target)
loss = model(image_list, target_list)
for i in loss.values():
net_loss+=i
with amp.scale_loss(net_loss, optimizer) as scaled_loss:
scaled_loss.backward()
loss_per_epoch+=scaled_loss
if not isfinite(scaled_loss):
raise RuntimeError("Nan loss")
optimizer.step()
and using torch amp module -
#when using amp I can only have 1 trainable backbone layer.
model = fasterrcnn_resnet50_fpn(pretrained=False,
num_classes=9,
pretrained_backbone=False, trainable_backbone_layers=1)
for epoch in range(650):
loss_per_epoch = 0
for image, target in zip(images, targets):
optimizer.zero_grad()
net_loss=0
target['boxes'] = target['boxes'].cuda()
target['labels'] = target['labels'].cuda()
with autocast():
loss = model(image_list, target_list)
for i in loss.values():
net_loss+=i
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Is there a flaw in one of the training paradigms which leads to almost half of the memory usage and could affect the results? My batch size is 1 in both cases.
TIA