Hello Guys,
I am experimenting with a pytorch example and I am facing different behaviors when I try to train the model using CPU and GPU. When I try to load the model to GPU an error to the code appears but when the same procedure executed to CPU the code runs perfectly. Here is the main part of the code that changes:
import model as model
basis= 'splines'
param=8
net = model.create_net(basis, param)
#net.to('cuda')
for n,p in net.named_parameters():
print(p.device,'',n)
Here is the model.py
import torch
from torch.nn import functional as F
from utility import splines
class Ennet(torch.nn.Module):
def __init__(self, enhancement_module):
super().__init__()
momentum = 0.01
self.c1 = torch.nn.Conv2d(3, 8, kernel_size=5, stride=4, padding=0)
self.c2 = torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0)
self.b2 = torch.nn.BatchNorm2d(16, momentum=momentum)
self.c3 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0)
self.b3 = torch.nn.BatchNorm2d(32, momentum=momentum)
self.c4 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0)
self.b4 = torch.nn.BatchNorm2d(64, momentum=momentum)
self.downsample = torch.nn.AvgPool2d(7, stride=1)
self.fc = torch.nn.Sequential(
torch.nn.Linear(64, 64),
torch.nn.ReLU(True),
torch.nn.Linear(64, enhancement_module.parameters_count)
)
self.emodule = enhancement_module
#print("enhancement_module iss___",enhancement_module)
def forward(self, image, applyto=None):
x = image
print("Enter main model forward ")
#print("Tararamammm",image.device)
if (image.size(2), image.size(3)) != (256, 256):
x = _bilinear(x, 256, 256)
x = x - 0.5
x = F.relu(self.c1(x))
x = self.b2(F.leaky_relu(self.c2(x)))
x = self.b3(F.leaky_relu(self.c3(x)))
x = self.b4(F.leaky_relu(self.c4(x)))
x = self.downsample(x)
x = x.view(x.size(0),-1)
x = self.fc(x)
applyto = (image if applyto is None else applyto)
result = applyto + self.emodule(applyto, x)
if not self.training:
result = torch.clamp(result, 0, 1)
return result
class EnhancementModule(torch.nn.Module):
def __init__(self, parameters_count):
super().__init__()
self.parameters_count = parameters_count
print("Insert EnhancementModule" )
def forward(self, image, parameters):
print("Insert EnhancementModule forward" )
return image
class Splines(EnhancementModule):
def __init__(self, nodes):
super().__init__(nodes * 3)
print("enter Splines")
self.interpolator = splines.SplineInterpolator(nodes)
print("exit Splines")
def forward(self, image, parameters):
k = image.size(0) * 3
x = image.view(k, -1)
y = parameters.view(k, -1)
z = self.interpolator(y, x)
print("Enters Splines Forwrd")
return z.view_as(image)
def create_net(basis_name, basis_param):
return Ennet(BASIS[basis_name](basis_param))
BASIS = {
"splines" : Splines,
"poly": PolynomialBasis
}
Here is the splines.py
import torch
import numpy as np
class SplineInterpolator(torch.nn.Module):
def __init__(self, nodes, dtype=torch.float32):
super().__init__()
A = self._precalc(nodes)
self.register_buffer("A", torch.tensor(A, dtype=dtype))
print("Insert init__")
def _precalc(self, n):
print("Insert _precalc")
h = 1.0 / (n - 1)
mat = 4 * np.eye(n - 2)
np.fill_diagonal(mat[1:, :-1], 1)
np.fill_diagonal(mat[:-1, 1:], 1)
A = 6 * np.linalg.inv(mat) / (h ** 2)
z = np.zeros(n - 2)
A = np.vstack([z, A, z])
B = np.zeros([n - 2, n])
np.fill_diagonal(B, 1)
np.fill_diagonal(B[:, 1:], -2)
np.fill_diagonal(B[:, 2:], 1)
A = np.dot(A, B)
return A.T
def _coefficients(self, y):
n = self.A.size(1)
h = 1.0 / (n - 1)
M = torch.mm(y, self.A)
a = (M[:, 1:] - M[:, :-1]) / (6 * h)
b = M[:, :-1] / 2
c = (y[:, 1:] - y[:, :-1]) / h - (M[:, 1:] + 2 * M[:, :-1]) * (h / 6)
print("Insert coeff")
return (a, b, c, y[:, :-1])
def _apply(self, x, coeffs):
print("Insert _apply")
n = self.A.size(1)
xv = x.view(x.size(0), -1)
xi = torch.clamp(xv * (n - 1), 0, n - 2).long()
xf = xv - xi.float() / (n - 1)
a, b, c, d = (torch.gather(cc, 1, xi) for cc in coeffs)
z = d + c * xf + b * (xf ** 2) + a * (xf ** 3)
return z.view_as(x)
def forward(self, y, x):
print("Insert forward")
#print(self._coefficients(y))
return self._apply(x, self._coefficients(y))
These are the results when the model is on CPU
Insert EnhancementModule
enter Splines
Insert _precalc
Insert init__
exit Splines
cpu c1.weight
cpu c1.bias
cpu c2.weight
cpu c2.bias
cpu b2.weight
cpu b2.bias
cpu c3.weight
cpu c3.bias
cpu b3.weight
cpu b3.bias
cpu c4.weight
cpu c4.bias
cpu b4.weight
cpu b4.bias
cpu fc.0.weight
cpu fc.0.bias
cpu fc.2.weight
cpu fc.2.bias
These are the results when the .to(‘cuda’) command is used
Insert EnhancementModule
enter Splines
Insert precalc
Insert init_
exit Splines
File ~\anaconda3\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
exec(code, globals, locals)
=====>>>>>>>>>>Error to net.to(‘cuda’) line<<<<<<<<=======
File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1152 in to
return self._apply(convert)
File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:802 in _apply
module._apply(fn)
File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:802 in _apply
module._apply(fn)
TypeError: SplineInterpolator._apply() missing 1 required positional argument: ‘coeffs’
Any sugesstions?