My use case is training a segmentation model using a pretrained/built model. My training loop looks like this:
# --- Model setup -------------------------------------------------------------
# Move to `device` once, *before* wrapping; the original also called
# model.to(device) again after wrapping, which is redundant.
model.to(device)
if is_distributed and use_cuda:
    # multi-machine multi-gpu case
    model = torch.nn.parallel.DistributedDataParallel(model)
else:
    # single-machine multi-gpu case or single-machine or multi-machine cpu case
    model = torch.nn.DataParallel(model)
torchinfo.summary(model)
model.train()

# --- Loss and optimizer ------------------------------------------------------
tv_loss_fn = tversky_focal_loss
# NOTE(review): `momentum` is ignored by AdamW (it only applies to SGD-family
# optimizers); kept for signature compatibility with the surrounding code.
# Keyword names `optimizer_name`/`learning_rate` match older timm releases —
# newer timm renamed them to `opt`/`lr`; verify against the installed version.
optimizer = timm.optim.create_optimizer_v2(
    model, optimizer_name='adamw', learning_rate=args.lr,
    momentum=args.momentum, amsgrad=True, weight_decay=1e-2)
lookahead_optimizer = timm.optim.Lookahead(optimizer, alpha=0.5, k=6)
print(lookahead_optimizer)
lookahead_optimizer.zero_grad()

torch.backends.cudnn.benchmark = True
scaler = torch.cuda.amp.GradScaler()

ACCUM_STEPS = 2  # gradient-accumulation window (optimizer step every 2 batches)

for epoch in range(args.epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dl):
        # Keep the whole step on the GPU. The original copied `outputs` and
        # `labels` to the CPU, computed the loss there, then moved the loss
        # back to `device`: that forces a full device synchronization every
        # batch, splits the autograd graph across devices, and computes the
        # loss outside of autocast for no benefit.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            # Only channel 1 participates in the loss; `1:2` keeps the
            # channel dimension (equivalent to `[:, 1, :, :].unsqueeze(1)`).
            loss = tv_loss_fn(outputs[:, 1:2, :, :], labels[:, 1:2, :, :])
            # Average over the accumulation window so the summed gradients
            # match a single step on a batch ACCUM_STEPS times larger.
            loss = loss / ACCUM_STEPS
        scaler.scale(loss).backward()
        if (i + 1) % ACCUM_STEPS == 0 or (i + 1) == len(train_dl):
            scaler.step(lookahead_optimizer)
            scaler.update()
            lookahead_optimizer.zero_grad(set_to_none=True)
        # One .item() per batch: a single host/device sync point, reused by
        # both the running total and the log message below.
        loss_val = loss.item()
        running_loss += loss_val
        if i % args.log_interval == 0:
            logger.info(
                "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                    epoch,
                    i * len(inputs),
                    len(train_dl.sampler),
                    100.0 * i / len(train_dl),
                    loss_val,
                )
            )
    logger.debug(f"Epoch {epoch} finished")
    lookahead_optimizer.sync_lookahead()
and the traceback is:
Traceback (most recent call last):
File "D:\Programs\Programming\seg5\SegTest3-local.py", line 457, in <module>
_train(args=parser.parse_args())
File "D:\Programs\Programming\seg5\SegTest3-local.py", line 309, in _train
outputs = model(inputs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\parallel\data_parallel.py", line 169, in forward
return self.module(*inputs[0], **kwargs[0])
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "D:\Programs\Programming\seg5\SegTest3-local.py", line 199, in forward
x = self.DLV3_model(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\segmentation_models_pytorch\base\model.py", line 29, in forward
features = self.encoder(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\segmentation_models_pytorch\encoders\timm_universal.py", line 30, in forward
features = self.model(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\timm\models\features.py", line 282, in forward
x = module(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\timm\models\xception_aligned.py", line 108, in forward
x = self.stack(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
input = module(input)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\timm\models\xception_aligned.py", line 74, in forward
x = self.conv_pw(x)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "C:\Users\camma\.virtualenvs\seg5\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.38 GiB (GPU 0; 12.00 GiB total capacity; 1007.72 MiB already allocated; 9.00 GiB free; 1020.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Process finished with exit code 1
I’m using windows 10, pytorch 1.13.0+cu117 and an NVIDIA RTX 3060 to train on.
Something to note: if I manually reset the GPU using PowerShell, it will work for a while — the failure is only intermittent.
EDIT: It turns out that isn't working anymore (or perhaps it never worked?). I now have to completely restart my system when this problem starts occurring in order to fix it.
Hope this helps elucidate things.