Adam optimizes detached module with weight_decay and DataParallel

Hi, I've hit something strange when using nn.DataParallel over multiple GPUs (say, 2) and Adam with a non-zero weight_decay to optimize a simple network whose conv1 is detached from the loss. Adam still updates conv1's weights. With weight_decay set to 0 the bug disappears, and it also disappears on a single GPU. The code and results are below; can anyone help me figure this out?
(Reproducible on both PyTorch 0.4 and 1.0.)

Code:

#!/usr/bin/python
# coding=utf-8

import os
import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 3, 3, 1, 1)
        self.init()

    def init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.constant_(m.weight, 1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1.detach())  # conv1's output is detached, so no grad should flow back to conv1
        return conv1, conv2


BS = 20
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
bce_loss = lambda a, b: F.binary_cross_entropy_with_logits(a, b)

net = nn.DataParallel(Net().cuda())
opt = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=1e-3,
    betas=(0.9, 0.99), weight_decay=2e-4)

inp = torch.zeros((BS, 3, 64, 64)).cuda()
inp.fill_(1)
gt = torch.zeros((BS, 3, 64, 64)).cuda()
gt.fill_(100)  # targets > 1 are invalid for BCE and make the loss negative, but that is beside the point here

for _ in range(5):
    conv1, conv2 = net(inp)

    print()
    print('conv1.output:', conv1[0,0,0,0])
    print('conv1.weight:', net.module.conv1.weight[0,0,0,0])
    if net.module.conv1.weight.grad is not None:
        print('conv1.grad  :', net.module.conv1.weight.grad[0,0,0,0])
    print('conv2.weight:', net.module.conv2.weight[0,0,0,0])
    if net.module.conv2.weight.grad is not None:
        print('conv2.grad  :', net.module.conv2.weight.grad[0,0,0,0])

    loss = bce_loss(conv2, gt)  # already on the GPU; the extra .cuda() was a no-op
    print('loss        :', loss.data.item())

    opt.zero_grad()
    loss.backward()
    opt.step()

Result:

conv1.output: tensor(12.0751, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
loss        : -69638.2421875

conv1.output: tensor(12.0621, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(0.9990, device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0010, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-856.6614, device='cuda:0')
loss        : -69637.3984375

conv1.output: tensor(12.0491, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(0.9980, device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0020, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-855.7734, device='cuda:0')
loss        : -69636.3984375

conv1.output: tensor(12.0361, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(0.9970, device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0030, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.8976, device='cuda:0')
loss        : -69635.2890625

conv1.output: tensor(12.0231, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(0.9960, device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0040, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.0023, device='cuda:0')
loss        : -69633.96875
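
One thing I notice: under DataParallel, conv1.weight.grad is an all-zeros tensor rather than None, and the weight shrinks by almost exactly lr = 1e-3 per step. If I read torch.optim.Adam correctly, it only skips parameters whose .grad is None, and it folds weight decay into the gradient (grad = grad + weight_decay * p) before the moment updates; for a near-constant gradient m_hat / sqrt(v_hat) ≈ 1, so the step size is ≈ lr no matter how small weight_decay is. A standalone re-implementation of that arithmetic (my own sketch, not the library source) reproduces the 1.0 → 0.9990 → 0.9980 → … sequence above:

import math

# Scalar Adam step with weight decay, applied to a parameter whose true
# gradient is always zero (as conv1.weight's grad appears to be here).
p, lr, wd = 1.0, 1e-3, 2e-4
beta1, beta2, eps = 0.9, 0.99, 1e-8
m = v = 0.0

for step in range(1, 5):
    grad = 0.0 + wd * p                  # zero grad tensor + L2 term
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    p -= lr * m_hat / (math.sqrt(v_hat) + eps)
    print(step, round(p, 4))             # 0.999, 0.998, 0.997, 0.996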

Control Experiment: Single GPU

When only one GPU is visible:

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Result:

conv1.output: tensor(12.0137, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
loss        : -69860.8515625

conv1.output: tensor(12.0137, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1.0010, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.7103, device='cuda:0')
loss        : -69930.796875

conv1.output: tensor(12.0137, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1.0020, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.7103, device='cuda:0')
loss        : -70000.7890625

conv1.output: tensor(12.0137, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1.0030, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.7103, device='cuda:0')
loss        : -70070.734375

conv1.output: tensor(12.0137, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1.0040, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-854.7103, device='cuda:0')
loss        : -70140.640625
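
So on a single GPU the conv1.grad line is never printed, i.e. conv1.weight.grad stays None, and Adam leaves the parameter alone. If that reading is right, the whole effect reduces to a single parameter with no DataParallel at all; a minimal sketch under that assumption:

import torch

w = torch.ones(1, requires_grad=True)  # stands in for conv1.weight
opt = torch.optim.Adam([w], lr=1e-3, betas=(0.9, 0.99), weight_decay=2e-4)

# grad is None, as in the single-GPU run: Adam skips w entirely.
opt.step()
print(w.item())  # 1.0, untouched

# Force an all-zeros grad, which is what DataParallel seems to leave behind:
w.grad = torch.zeros_like(w)
opt.step()
print(w.item())  # ~0.999, decayed despite a zero gradient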

Control Experiment: weight_decay = 0

Back on two GPUs, but with weight_decay set to 0:

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
opt = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=1e-3,
    betas=(0.9, 0.99), weight_decay=0)

Result:

conv1.output: tensor(11.8135, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv2.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
loss        : -69543.546875

conv1.output: tensor(11.8135, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0010, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-848.2969, device='cuda:0')
loss        : -69613.203125

conv1.output: tensor(11.8135, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0020, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-848.2969, device='cuda:0')
loss        : -69682.859375

conv1.output: tensor(11.8135, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0030, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-848.2969, device='cuda:0')
loss        : -69752.5

conv1.output: tensor(11.8135, device='cuda:0', grad_fn=<SelectBackward>)
conv1.weight: tensor(1., device='cuda:0', grad_fn=<SelectBackward>)
conv1.grad  : tensor(0., device='cuda:0')
conv2.weight: tensor(1.0040, device='cuda:0', grad_fn=<SelectBackward>)
conv2.grad  : tensor(-848.2969, device='cuda:0')
loss        : -69822.1484375
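
Putting the three runs together: with weight_decay = 0 the effective gradient stays zero even though conv1.weight.grad exists, so both Adam moments stay zero and the step m_hat / (sqrt(v_hat) + eps) = 0 / (0 + eps) is exactly zero, which is why this run looks fine. My guess is that the real culprit is DataParallel materializing an all-zeros grad for the detached conv1 (perhaps from the backward of broadcasting the parameters to the replicas), but I'd appreciate confirmation. A scalar check of the weight_decay = 0 corner case:

import torch

w = torch.ones(1, requires_grad=True)
opt = torch.optim.Adam([w], lr=1e-3, betas=(0.9, 0.99), weight_decay=0)

# An existing all-zeros grad with weight_decay == 0: the effective gradient
# is zero, both moment estimates stay zero, and the update is exactly zero.
w.grad = torch.zeros_like(w)
opt.step()
print(w.item())  # still 1.0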