Hello,
I am currently creating a custom module inside an existing CNN using pytorch. I am doing this as part of my schools research, and so I have access to a super computer with multiple GPU devices. When I train my model, after running through the first validation set I run into the following error:
“RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)”
Now, this does not happen when I skip the convolution in my module and just put some dummy calculation in. This error also doesn’t raise when I use a single gpu or a cpu. Here is my custom module. The convolution takes place in the ‘conv_gauss’ function at the very bottom:
class DivisiveNormBlock(nn.Module):
def __init__(self, channel_num = 512, size = 56, ksize = 4):
super().__init__()
self.channel_num = channel_num
self.size = size
self.ksize = ksize
scale = 90 # Random scale factor I've been playing with.
self.theta = torch.nn.Parameter(scale * torch.abs(torch.randn(self.channel_num, self.channel_num, device="cuda",
requires_grad=True))) # 512 thetas for a channel, 512 channels, same goes for...
self.p = torch.nn.Parameter(
scale * torch.abs(torch.randn(self.channel_num, self.channel_num, device="cuda", requires_grad=True)))
self.sig = torch.nn.Parameter(
scale * torch.abs(torch.randn(self.channel_num, self.channel_num, device="cuda", requires_grad=True)))
self.a = torch.nn.Parameter(
scale * torch.abs(torch.randn(self.channel_num, self.channel_num, device="cuda", requires_grad=True)))
self.nI = torch.nn.Parameter(
torch.abs(torch.randn(self.channel_num, self.channel_num, device="cuda", requires_grad=True)))
self.nU = torch.nn.Parameter(torch.abs(torch.randn(self.channel_num, device="cuda", requires_grad=True)))
self.bias = torch.nn.Parameter(torch.abs(torch.randn(self.channel_num, device="cuda", requires_grad=True)))
self.gaussian_bank = torch.zeros(self.channel_num, self.channel_num, self.ksize * 2+ 1, self.ksize * 2+ 1,
device="cuda")
self.x = torch.linspace(-self.ksize, self.ksize, self.ksize * 2 + 1, device="cuda")
self.y = torch.linspace(-self.ksize, self.ksize, self.ksize * 2 + 1, device="cuda")
self.xv, self.yv = torch.meshgrid(self.x, self.y)
for i in range(self.channel_num):
for u in range(self.channel_num):
self.gaussian_bank[i, u, :, :] = self.get_gaussian(i, u)
p = int((self.ksize * 2) / 2)
conv_kernel_size = self.ksize * 2+ 1
self.conv = nn.Conv2d(self.channel_num, self.channel_num, padding=p, stride=1,
kernel_size=conv_kernel_size, bias=False)
def get_gaussian(self, cc, oc): #
xrot = self.xv * torch.cos(self.theta[cc, oc]) + self.yv * torch.sin(self.theta[cc, oc])
yrot = -self.xv * torch.sin(self.theta[cc, oc]) + self.yv * torch.cos(self.theta[cc, oc])
g_kernel = (self.a[cc, oc]) + self.yv * torch.cos(self.theta[cc, oc])
g_kernel = (self.a[cc, oc] / \
(2 * torch.pi * self.p[cc, oc] * self.sig[cc, oc])) * \
torch.exp(-0.5 * ((((xrot) ** 2) / self.p[cc, oc] ** 2) + ((yrot) ** 2) / self.sig[cc, oc] ** 2))
return g_kernel
def forward(self, x):
x_test = self.dn_f(x)
return x
def dn_f(self, x):
batch_size = x.shape[0]
under_sum = torch.zeros((self.channel_num, self.size, self.size), device="cuda")
normalized_channels = torch.zeros((batch_size, self.channel_num, self.size, self.size), device="cuda")
for b in tqdm(range(batch_size)):
for i in range(self.channel_num):
for u in range(self.channel_num):
under_sum[u] = self.conv_gauss(torch.pow(x[b, i], self.nI[i, u]), self.gaussian_bank[i, u])
#under_sum[u] = under_sum[u]
normalized_channels[b, i] = torch.pow(x[b, i], self.nU[i]) / (
torch.pow(self.bias[i], self.nU[i]) + torch.sum(under_sum, 0))
return normalized_channels
def conv_gauss(self, x_conv, gauss_conv):
x_conv = torch.reshape(x_conv, (1, 1, self.size, self.size))
gauss_conv = torch.reshape(gauss_conv, (1, 1, self.ksize * 2+ 1, self.ksize * 2+ 1))
p = int((self.ksize*2)/2)
self.conv.weight = nn.Parameter(gauss_conv)
output = self.conv(x_conv)
output = torch.reshape(output, (self.size, self.size))
return output
For some extra context, I also did tried this with F.conv2d in the function instead, but to no avail.