Loss increases dramatically when running on multiple GPUs

Hello,
When I run my model on a single GPU, the results on the test set are normal. But when I switch to multiple GPUs, the metrics evaluated on the test set (loss, RMSE, …) are not stable; they increase dramatically over time. I wonder why this happens, since I expected the results to be the same in both cases.

PyTorch version: 0.4.1.

Best,
Khang Truong.

Depending on the batch size (and potentially other hyperparameters you’ve changed), the output might not be exactly the same.
I would recommend updating to the latest stable PyTorch release (1.5.0) and rerunning the script.
0.4.1 is quite old by now and a lot of fixes went into the code since then. :wink:

Thanks for your quick response. I’ll try to test with a newer version. Can I install PyTorch 1.5.0 alongside CUDA 9.0?

You would have to build from source for CUDA9.0, as the 1.5 binaries ship with CUDA9.2, 10.1, and 10.2.

I upgraded PyTorch to 1.1.0 and re-ran my code, but the problem still occurs. I didn’t change any hyperparameters; the only change was switching from a single GPU to multiple GPUs by using torch.nn.DataParallel(). Is there any reason for this?

If you didn’t change any hyperparameters, I assume the batch size is also the same as for the single GPU run.
If that’s the case, your training might suffer, if you are using batchnorm layers, as their running estimates might be off due to the lower batch size on each device.
SyncBatchNorm, which is meant to be used together with DistributedDataParallel, synchronizes these stats between the processes.
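
As a rough illustration of the SyncBatchNorm suggestion, the sketch below converts the BatchNorm layers of a toy model. The model itself is made up for this example, and `convert_sync_batchnorm` is only available in PyTorch releases newer than the ones mentioned so far in this thread. The conversion alone is not enough for training: the converted model would still have to be wrapped in DistributedDataParallel with an initialized process group, which is not shown here.

```python
import torch.nn as nn

# Toy model, invented for this example; any model containing BatchNorm layers works.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Replace every BatchNorm*d layer with SyncBatchNorm.
# The statistics are only synchronized once the model runs under
# DistributedDataParallel in a multi-process setup.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)
```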

PyTorch 1.1.0 is also not really new. Are you having trouble updating to the latest version?
If so, could you post the issues, so that we can help with it?

My model does not use any BatchNorm layers. However, I implemented a normalized convolution as shown below, and my network uses only this operation:

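import torch
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair

# The proposed Normalized Convolution Layer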
class NConv2d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, pos_fn='softplus',
                 init_method='k', stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Call _ConvNd constructor
        super(NConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, False, _pair(0), groups, bias, padding_mode)
        
        self.eps = 1e-20
        self.pos_fn = pos_fn
        self.init_method = init_method
        
        # Initialize weights and bias
        self.init_parameters()
        
        if self.pos_fn is not None:
            EnforcePos.apply(self, 'weight', pos_fn)
        
    def forward(self, data, conf):
        
        # Normalized Convolution
        denom = F.conv2d(conf, self.weight, None, self.stride,
                        self.padding, self.dilation, self.groups)        
        nomin = F.conv2d(data*conf, self.weight, None, self.stride,
                        self.padding, self.dilation, self.groups)        
        nconv = nomin / (denom+self.eps)

        # Add bias
        b = self.bias
        sz = b.size(0)
        b = b.view(1,sz,1,1)
        b = b.expand_as(nconv)
        nconv += b
        
        # Propagate confidence
        cout = denom
        sz = cout.size()
        cout = cout.view(sz[0], sz[1], -1)
        
        k = self.weight
        k_sz = k.size()
        k = k.view(k_sz[0], -1)
        s = torch.sum(k, dim=-1, keepdim=True)        

        cout = cout / s
        cout = cout.view(sz)
        
        return nconv, cout
        
        
# Non-negativity enforcement class        
class EnforcePos(object):
    def __init__(self, pos_fn, name):
        self.name = name
        self.pos_fn = pos_fn


    @staticmethod
    def apply(module, name, pos_fn):
        fn = EnforcePos(pos_fn, name)
        
        module.register_forward_pre_hook(fn)                    

        return fn

    def __call__(self, module, inputs):
        # Re-apply the positivity function to the weights, but only in training mode
        if module.training:
            weight = getattr(module, self.name)
            weight.data = self._pos(weight).data

    def _pos(self, p):
        pos_fn = self.pos_fn.lower()
        if pos_fn == 'softmax':
            p_sz = p.size()
            p = p.view(p_sz[0],p_sz[1], -1)
            p = F.softmax(p, -1)
            return p.view(p_sz)
        elif pos_fn == 'exp':
            return torch.exp(p)
        elif pos_fn == 'softplus':
            return F.softplus(p, beta=10)
        elif pos_fn == 'sigmoid':
            return torch.sigmoid(p)
        else:
            raise ValueError('Undefined positive function: ' + pos_fn)

As you can see, this implementation uses the class EnforcePos to force the weights to be non-negative. I suspect it might cause the problem I described in this post, but I’m not sure. Do you have any suggestions for implementing a non-negativity constraint?

I have had some problems building PyTorch 1.5.0 from source because my CUDA version is 9.0. Anyway, I’ll try to update to the latest version and test again!

Using the .data attribute is not recommended and might yield unwanted side effects, so this line of code might be dangerous:

weight.data = self._pos(weight).data

You could probably use .copy_ and wrap it in a with torch.no_grad() block, if Autograd raises an issue.

Besides that I cannot see any obvious errors.
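
A minimal sketch of that suggestion, assuming a plain parameter on a single device (the tensor shape and the softplus choice are only illustrative, taken from the layer posted above). It shows the in-place update without touching .data; it says nothing about how the same code behaves inside DataParallel replicas.

```python
import torch
import torch.nn.functional as F

# Stand-in for the convolution weight; the shape is only illustrative.
weight = torch.nn.Parameter(torch.randn(2, 1, 5, 5))
before = weight.detach().clone()

# Overwrite the parameter in place without recording the operation in autograd.
with torch.no_grad():
    weight.copy_(F.softplus(weight, beta=10))

# The parameter is still a leaf that requires gradients.
print(weight.is_leaf, weight.requires_grad)
```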

I did as you suggested. It raised this error
RuntimeError: diff_view_meta->output_nr_ == 0 ASSERT FAILED at /opt/conda/conda-bld/pytorch_1556653183467/work/torch/csrc/autograd/variable.cpp:209, please report a bug to PyTorch.
How can I fix this?

Could you post an executable code snippet to reproduce this error?

In the EnforcePos class above, I changed the line weight.data = self._pos(weight).data to:

with torch.no_grad():
    weight.copy_(self._pos(weight))

Next I made a simple network as follows:

class NormCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.nconv1 = NConv2d(1, 2, (5, 5), pos_fn='softplus', padding=2)
        self.nconv2 = NConv2d(2, 2, (5, 5), pos_fn='softplus', padding=2)

    def forward(self, x, c):
        xout, cout = self.nconv1(x, c)
        xout, cout = self.nconv2(xout, cout)
        return xout, cout

The testing code is given by:

model = NormCNN()
model = nn.DataParallel(model)
model = model.cuda()

x = torch.rand(4, 1, 480, 752).cuda()
c = (x > 0.5).float().cuda()
xo, co = model(x, c)

And then, I got the error as above!

Thanks for the code.
It seems your approach is similar to weight_norm, which uses setattr and getattr to modify the weight parameter, so you could try to adapt WeightNorm to your use case.
Alternatively, it might be easier to use the functional API: define a self.weight parameter, manipulate it in forward, and pass the result to F.conv2d.

Let me know, if this would work.
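
The functional alternative could look roughly like the sketch below. The class name `PositiveNConv2d`, the parameter name `raw_weight`, and the initialization are assumptions made for this example, and the bias term of the original layer is left out. Note one difference from the posted code: the positivity function is applied to an unconstrained parameter on every forward pass instead of overwriting the stored weight, so gradients flow through softplus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositiveNConv2d(nn.Module):
    """Hypothetical normalized convolution with softplus-constrained weights."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, eps=1e-20):
        super().__init__()
        self.padding = padding
        self.eps = eps
        # Unconstrained parameter; the positive weight is derived from it in forward.
        self.raw_weight = nn.Parameter(
            0.1 * torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )

    def forward(self, data, conf):
        # Recompute the positive weight on every call, on whatever device
        # the parameter currently lives on.
        weight = F.softplus(self.raw_weight, beta=10)
        denom = F.conv2d(conf, weight, padding=self.padding)
        numer = F.conv2d(data * conf, weight, padding=self.padding)
        nconv = numer / (denom + self.eps)
        # Propagate confidence: normalize by the sum of each output filter.
        cout = denom / weight.sum(dim=(1, 2, 3)).view(1, -1, 1, 1)
        return nconv, cout


layer = PositiveNConv2d(1, 2, kernel_size=5, padding=2)
x = torch.rand(4, 1, 16, 16)
c = (x > 0.5).float()
xo, co = layer(x, c)
xo.sum().backward()
print(xo.shape, co.shape)
```

Because the constrained weight is a local tensor computed from `raw_weight`, no hook or in-place parameter update is involved.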

I followed the implementation of weight_norm to correct my function. It works well now: I can train on multiple GPUs and the loss no longer increases.

Thanks for your help!
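
The corrected code was not posted in this thread. The following is only a sketch of what a weight_norm-style hook might look like; the class name `PositiveWeight` and the `_raw` parameter suffix are hypothetical. It follows the pattern of the classic `torch.nn.utils.weight_norm` implementation: the original parameter is replaced by an unconstrained one, and a forward pre-hook recomputes the constrained weight as a plain attribute before each forward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class PositiveWeight:
    """Hypothetical hook modeled on the weight_norm pattern."""

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        raw = getattr(module, self.name + "_raw")
        return F.softplus(raw, beta=10)

    @staticmethod
    def apply(module, name="weight"):
        fn = PositiveWeight(name)
        weight = getattr(module, name)
        # Replace the original parameter with an unconstrained "<name>_raw" parameter.
        del module._parameters[name]
        module.register_parameter(name + "_raw", Parameter(weight.detach().clone()))
        # Expose the constrained weight as a plain (non-parameter) attribute.
        setattr(module, name, fn.compute_weight(module))
        module.register_forward_pre_hook(fn)
        return fn

    def __call__(self, module, inputs):
        # Recompute the constrained weight before every forward pass.
        setattr(module, self.name, self.compute_weight(module))


conv = nn.Conv2d(1, 2, kernel_size=5, padding=2)
PositiveWeight.apply(conv)
out = conv(torch.rand(4, 1, 16, 16))
out.sum().backward()
print(sorted(name for name, _ in conv.named_parameters()))
```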
