I’m trying to load a model and run it on GPU nodes. The code snippet is:
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # -- embedding params
        self.cn1 = nn.Conv2d(1, 256, kernel_size=3, stride=2)
        self.cn2 = nn.Conv2d(256, 256, kernel_size=3, stride=1)
        self.cn3 = nn.Conv2d(256, 256, kernel_size=3, stride=2)
        self.cn4 = nn.Conv2d(256, 256, kernel_size=3, stride=1)
        self.cn5 = nn.Conv2d(256, 256, kernel_size=3, stride=2)
        self.cn6 = nn.Conv2d(256, 256, kernel_size=3, stride=2)
        # -- prediction params
        self.fc1 = nn.Linear(2304, 1700)
        self.fc2 = nn.Linear(1700, 1200)
        self.fc3 = nn.Linear(1200, 964)

class Train:
    def __init__(self):
        self.path_pretrained = './model_stat.pth'
        self.model = self.load_model()

    def load_model(self):
        # -- init model and copy the pretrained weights into it
        model = MyModel()
        old_model = torch.load(self.path_pretrained)
        for old_key in old_model:
            dict(model.named_parameters())[old_key].data = old_model[old_key]
        return model.to('cuda')

my_train = Train()
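For reference, my understanding is that the manual copy loop above does roughly what load_state_dict does. Here is a minimal sketch of that more idiomatic path, assuming model_stat.pth is a plain state dict saved via torch.save(model.state_dict(), ...) with keys matching MyModel (the map_location argument is my addition):

    import torch

    model = MyModel()
    # map_location='cpu' keeps the checkpoint tensors on the CPU while loading;
    # without it, torch.load restores each tensor to the device it was saved from,
    # which can allocate GPU memory before .to('cuda') is ever called.
    state_dict = torch.load('./model_stat.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to('cuda')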
This runs with no issue on a CPU (using .to('cpu') instead). However, when I run it on a GPU, I get the following error:
Traceback (most recent call last):
  File "test.py", line 41, in <module>
    my_train = Train()
  File "test.py", line 28, in __init__
    self.model = self.load_model()
  File "test.py", line 38, in load_model
    return model.to('cuda')
  File "/…/lib/python3.7/site-packages/torch/nn/modules/module.py", line 907, in to
    return self._apply(convert)
  File "/…/lib/python3.7/site-packages/torch/nn/modules/module.py", line 578, in _apply
    module._apply(fn)
  File "/…/lib/python3.7/site-packages/torch/nn/modules/module.py", line 601, in _apply
    param_applied = fn(param)
  File "/…/lib/python3.7/site-packages/torch/nn/modules/module.py", line 905, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
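In case it helps, this is the kind of pre-flight check one can run just before the .to('cuda') call to see how much memory the process is already holding (torch.cuda.memory_allocated and torch.cuda.memory_reserved are the standard accounting calls; the device index 0 is an assumption about my node):

    import torch

    # Hypothetical pre-flight check before moving the model to the GPU.
    # Device index 0 is an assumption; adjust for the node's visible devices.
    device = torch.device('cuda:0')
    print('visible CUDA devices:', torch.cuda.device_count())
    print('allocated by this process (bytes):', torch.cuda.memory_allocated(device))
    print('reserved by the caching allocator (bytes):', torch.cuda.memory_reserved(device))

I can also rerun with CUDA_LAUNCH_BLOCKING=1 set, as the error message suggests, to get a synchronous stack trace.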
How can I resolve this issue? The model can be found here.