I encountered a strange issue; a simplified version of my code is below:
```python
import torch
import torch.nn as nn
import math


class MyLinear(nn.Module):
    def __init__(self, nin, nout):
        super(MyLinear, self).__init__()
        self.nout = nout
        self.nin = nin
        self.weight = nn.Parameter(torch.randn(self.nout, self.nin))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # returns both the layer output and a scalar regularization term
        # my_regularization = torch.abs(self.weight).mean().reshape(1)
        my_regularization = torch.abs(self.weight).mean()
        return torch.nn.functional.linear(x, self.weight), my_regularization


model = MyLinear(10, 1).cuda()
model = nn.DataParallel(model)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01, momentum=0.1)

for i in range(100):
    data = torch.randn(100, 10).cuda()
    target = torch.randn(100, 1).cuda()
    output, my_regularization = model(data)
    print(output.shape, my_regularization.shape)
    loss = criterion(output, target)
    loss = loss + my_regularization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
I’ve implemented a custom version of the Linear layer with a customized regularization term. The layer returns both the output and the regularization term, which is optimized as part of the loss function.
When I use `nn.DataParallel` to perform multi-GPU training, the program raises an error:
```
RuntimeError: dimension specified as 0 but tensor has no dimensions
```
The error indicates that the `my_regularization` term has no dimensions, i.e. it is a 0-dim scalar tensor.
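If I understand the mechanics correctly, `nn.DataParallel` gathers the outputs of each replica by concatenating them along dim 0, and concatenation is undefined for 0-dim tensors. A minimal check, independent of my model (the exact error message may differ across PyTorch versions):

```python
import torch

# Each replica returns a 0-dim scalar tensor, like my_regularization...
a = torch.tensor(1.0)
b = torch.tensor(2.0)

# ...and concatenating 0-dim tensors along dim 0 fails, which seems to be
# what the gather step of nn.DataParallel runs into: this line raises a
# RuntimeError complaining about zero-dimensional tensors.
torch.cat([a, b], dim=0)
```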
But when I reshape it from a scalar to a 1-element vector with `.reshape(1)` (the commented-out line in `forward`), another error occurs:
```
RuntimeError: grad can be implicitly created only for scalar outputs
```
This happens because `my_regularization` strangely has shape `[2]`, which I think should be `[1]`. I have 2 GPUs on board, so `my_regularization` may be the concatenation of the values from both replicas.
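A quick sanity check, unrelated to my model, confirms that the second error is just PyTorch refusing to backpropagate from a non-scalar without an explicit gradient argument, so the shape-`[2]` regularization would explain it:

```python
import torch

t = torch.randn(2, requires_grad=True)
loss = t * 2.0  # shape [2], like my loss after broadcasting with my_regularization

# backward() can only implicitly create the output gradient for a scalar (0-dim)
# tensor, so this raises:
# RuntimeError: grad can be implicitly created only for scalar outputs
loss.backward()
```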
But why doesn’t the shape of `output` change from `[100, 1]` to `[200, 1]` as well? In my opinion they are both outputs of the `MyLinear` layer, so what causes their different behaviours?
Could you folks give me any hints about solving this issue?
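For completeness: keeping the commented-out `.reshape(1)` version and reducing the gathered values back to a scalar does let training run. Here is a sketch of what I mean, assuming the shape-`[2]` tensor really is the per-GPU concatenation, though I would still like to understand the underlying behaviour:

```python
output, my_regularization = model(data)  # my_regularization: shape [num_gpus]
loss = criterion(output, target)
loss = loss + my_regularization.mean()   # reduce the per-GPU values to a scalar
loss.backward()
```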