BCE Loss Runtime Error


I’m trying to train an adversarial network using PyTorch’s BCELoss. I’ve included the discriminator network, the training code, and the error message below. I’m training for only 5 epochs: the first 4 epochs run without any error, but a runtime error is raised during the fifth. Any suggestions?

Discriminator Network

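import math

import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # (remaining layers of each Sequential are not shown in the post)
        self.Dconv1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 5)),
        )
        self.Dconv2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 3)),
        )
        self.Dfc1 = nn.Sequential(
            nn.Linear(in_features=1536, out_features=256),
            nn.Linear(in_features=256, out_features=1),
            nn.Sigmoid(),  # keeps outputs in [0, 1], as BCELoss requires
        )
        self._initialize_weights()  # assumed; the post does not show this call

    def _initialize_weights(self):
        # Branches that were empty in the post are filled with common defaults (assumed).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init, scaled by kernel area * output channels
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, input):
        out = self.Dconv1(input)
        out = self.Dconv2(out)
        # Flatten per sample; fixing dim 0 to the batch size makes a shape
        # mismatch raise an error instead of silently changing the batch size.
        out = out.reshape(out.size(0), 512 * 1 * 3)
        out = self.Dfc1(out)
        return out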
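As a sanity check on the `512*1*3` in `forward`: with stride 1 and no padding, a (1, k) kernel trims k - 1 columns, so the flattened size only works out if the input width is 9 (9 → 5 → 3). That width is inferred from the reshape, not stated in the post. A quick check using just the two convolutions:

```python
import torch
import torch.nn as nn


def conv_out_width(width, kernel):
    """Output width of a stride-1, unpadded convolution."""
    return width - kernel + 1


# Only the two convolutions from Dconv1/Dconv2; any other layers are omitted.
convs = nn.Sequential(
    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 5)),
    nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 3)),
)

x = torch.randn(4, 128, 1, 9)  # assumed input: batch 4, height 1, width 9
out = convs(x)
flat = out.reshape(out.size(0), -1)  # keep the batch dimension explicit

print(conv_out_width(conv_out_width(9, 5), 3))  # 3
print(tuple(out.shape))   # (4, 512, 1, 3)
print(tuple(flat.shape))  # (4, 1536)
```

If the real input width were different (say 15, giving 512 * 9 values per sample), `reshape(-1, 512*1*3)` would quietly produce three times as many rows instead of failing, which is why flattening with `out.size(0)` as the first dimension is the safer habit.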

Training Code Snippet

s1_length = len(s1_source)
s2_length = len(s2_source)
t_length = len(s1_target)
logging.warning("Iteration: %d, S1 length: %d, S2 length: %d, Target length: %d", i, s1_length, s2_length, t_length)
s1_error_fake = loss(s1_source, ones_target(s1_length))
s1_error_real = loss(s1_target, zeros_target(t_length))
s1_t_dis_loss = s1_error_fake + s1_error_real
s2_error_fake = loss(s2_source, ones_target(s2_length))
s2_error_real = loss(s2_target, zeros_target(t_length))
s2_t_dis_loss = s2_error_fake + s2_error_real
logging.warning("S1 Disc loss: %s, S2 Disc Loss: %s", s1_t_dis_loss.data, s2_t_dis_loss.data)
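A frequently recommended alternative, and a different setup from the one in this post: drop the final `nn.Sigmoid` and feed the raw discriminator scores to `nn.BCEWithLogitsLoss`. It fuses the sigmoid and the log into one numerically stable expression and places no [0, 1] range requirement on its input, whereas `BCELoss` clamps its log terms at -100 and so flattens very confident wrong predictions. The scores below are made up to show the difference:

```python
import torch
import torch.nn as nn

# Made-up raw discriminator scores (logits), all with target label 1.
logits = torch.tensor([[-200.0], [0.0], [35.0], [200.0]])
targets = torch.ones_like(logits)

# Alternative route: raw scores straight into BCEWithLogitsLoss.
loss_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Route used in the post: Sigmoid first, then BCELoss.
# sigmoid(-200) underflows to 0 in float32, and BCELoss clamps log(0) to -100.
loss_probs = nn.BCELoss()(torch.sigmoid(logits), targets)

print(loss_logits.item())  # about 50.17: the -200 score contributes its full ~200
print(loss_probs.item())   # about 25.17: the same score is capped at 100
```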

Error Message

146it [00:02, 53.81it/s]
146it [00:02, 56.52it/s]
146it [00:02, 57.55it/s]
42it [00:00, 58.23it/s]

RuntimeError                              Traceback (most recent call last)
     63 logging.warning("Iteration: %d, S1 length: %d, S2 length: %d, Target length: %d", i, s1_length, s2_length, t_length)
---> 65 s1_error_fake = loss(s1_source, ones_target(s1_length))
     66 s1_error_real = loss(s1_target, zeros_target(t_length))
     67 s1_t_dis_loss = s1_error_fake + s1_error_real

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491 result = self._slow_forward(*input, **kwargs)
    492 else:
--> 493 result = self.forward(*input, **kwargs)
    494 for hook in self._forward_hooks.values():
    495 hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    510 @weak_script_method
    511 def forward(self, input, target):
--> 512 return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2112 return torch._C._nn.binary_cross_entropy(
-> 2113 input, target, weight, reduction_enum)

RuntimeError: reduce failed to synchronize: device-side assert triggered


Hi! Maybe this helps:

You can also try running your code on the CPU; you will get a much more informative error message there.
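To illustrate the CPU suggestion: on the CPU, `BCELoss` checks every input value and raises a readable `RuntimeError` when one falls outside [0, 1]. The sketch below reproduces that with made-up tensors. Note that a NaN fails the same range check, so a diverging discriminator could plausibly trigger the assert even with a final Sigmoid. Alternatively, setting the environment variable `CUDA_LAUNCH_BLOCKING=1` before launching the script makes CUDA kernels run synchronously, so the traceback points at the operation that actually failed.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
target = torch.ones(3, 1)


def bce_error(pred):
    """Run BCELoss on the CPU; return the error text, or None if it succeeds."""
    try:
        bce(pred, target)
        return None
    except RuntimeError as err:
        return str(err)


# Made-up discriminator outputs.
ok = torch.tensor([[0.2], [0.5], [0.9]])
too_big = torch.tensor([[0.2], [0.5], [1.5]])           # e.g. no final Sigmoid
has_nan = torch.tensor([[0.2], [float("nan")], [0.9]])  # e.g. diverged weights

for name, pred in [("ok", ok), ("too_big", too_big), ("has_nan", has_nan)]:
    print(name, "->", bce_error(pred))
```

The exact error wording depends on the PyTorch version, but the out-of-range and NaN cases both fail while the valid one passes.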

Thank you for the link. I actually read through it yesterday, and as you can see in my discriminator class, I use a Sigmoid at the final output. I also create the target tensors with the following code:

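def ones_target(size):
    """Tensor containing ones, with shape = (size, 1)."""
    # torch.autograd.Variable is deprecated since PyTorch 0.4; plain tensors work.
    # gpu_id is defined elsewhere in the training script.
    data = torch.ones(size, 1).cuda(gpu_id)
    return data

def zeros_target(size):
    """Tensor containing zeros, with shape = (size, 1)."""
    data = torch.zeros(size, 1).cuda(gpu_id)
    return data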
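One way to rule out any size, dtype, or device mismatch between predictions and targets is to build the targets from the predictions themselves with `torch.ones_like` / `torch.zeros_like`, instead of passing lengths and `gpu_id` around. A sketch under that assumption; the helper name `labeled_bce` is hypothetical, and the tensors are made-up CPU stand-ins for discriminator outputs:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()


def labeled_bce(pred_ones, pred_zeros):
    """BCE with label 1 for pred_ones and label 0 for pred_zeros.

    ones_like / zeros_like copy shape, dtype and device from each prediction,
    so the targets always match it, whatever the batch size or GPU.
    """
    return (bce(pred_ones, torch.ones_like(pred_ones))
            + bce(pred_zeros, torch.zeros_like(pred_zeros)))


# Made-up discriminator outputs with different batch sizes (5 and 3).
s1_source = torch.full((5, 1), 0.9)
s1_target = torch.full((3, 1), 0.1)
s1_t_dis_loss = labeled_bce(s1_source, s1_target)
print(s1_t_dis_loss.item())  # about 0.2107, i.e. -log(0.9) - log(1 - 0.1)
```

In the training loop this would stand in for `loss(s1_source, ones_target(s1_length)) + loss(s1_target, zeros_target(t_length))`.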

I’m pretty sure I’ve applied all the fixes that link suggests, so I’m now confused about what could be wrong. What’s really annoying is that the error only appears after a few epochs!

Yes, it is confusing.
Did you try printing out the loss? Maybe it goes to infinity or vanishes, causing an overflow?
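A minimal sketch of that check, assuming it is called on the discriminator outputs (`s1_source`, `s2_source`, `s1_target`, `s2_target` in the training code) right before they reach `BCELoss`, and on the losses afterwards. The helper name `check_finite` is hypothetical. For the backward pass, `torch.autograd.set_detect_anomaly(True)` makes autograd report the forward operation behind a NaN gradient, at a noticeable speed cost.

```python
import torch


def check_finite(step, **tensors):
    """Raise FloatingPointError if any named tensor holds NaN or +/-inf.

    Hypothetical helper: the first offending iteration and tensor are named,
    instead of surfacing later as a CUDA device-side assert.
    """
    for name, t in tensors.items():
        finite = torch.isfinite(t)
        if not finite.all():
            n_bad = int((~finite).sum())
            raise FloatingPointError(
                f"iteration {step}: {name} has {n_bad} non-finite value(s)"
            )


# Made-up values standing in for discriminator outputs.
healthy = torch.tensor([[0.3], [0.7]])
diverged = torch.tensor([[0.3], [float("nan")]])

check_finite(1, s1_source=healthy)  # passes silently

message = None
try:
    check_finite(2, s1_source=healthy, s2_source=diverged)
except FloatingPointError as err:
    message = str(err)
print(message)  # iteration 2: s2_source has 1 non-finite value(s)
```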