import torch
import torch.nn as nn
from torch.nn import BatchNorm3d

class Model(nn.Module):
    """Two-stage 3D conv network returning a 1-channel mask and a 6-channel map.

    Stage 1 (``conv1`` -> ``e1``) maps the 1-channel input volume to 6
    channels. Those 6 channels are summed into a single channel, which
    stage 2 (``convM1`` -> ``d1``) maps to a 1-channel output.
    """

    def __init__(self):
        # Must run before any submodule is assigned; otherwise nn.Module raises
        # AttributeError ("cannot assign module before Module.__init__() call").
        super(Model, self).__init__()
        self.conv1 = nn.Sequential(
            # conv1: 1-channel input volume
            nn.Conv3d(1, 64, 3, padding=1),
            nn.Conv3d(64, 64, 3, padding=1),
        )
        self.convM1 = nn.Sequential(
            # convM1: consumes the 1-channel sum of the e1 head
            nn.Conv3d(1, 64, 3, padding=1),
            nn.Conv3d(64, 64, 3, padding=1),
        )
        self.d1 = nn.Conv3d(64, 1, 1)
        self.e1 = nn.Conv3d(64, 6, 1)

    def forward(self, x):
        """Run both stages.

        Args:
            x: 5-D input (N, C, D, H, W) with C == 1, as required by Conv3d(1, ...).

        Returns:
            dict with ``'mask'``: [d1] (1 channel) and ``'exp'``: [e1] (6 channels).
        """
        conv1 = self.conv1(x)
        e1 = self.e1(conv1)
        # Collapse the 6 e1 channels into one, keeping the channel dim so the
        # result can feed convM1's Conv3d(1, ...).
        e1_sum = e1.sum(dim=1, keepdim=True)
        convM1 = self.convM1(e1_sum)
        d1 = self.d1(convM1)
        # The module-level script reads outs['exp'][-1] and compares it against
        # a 6-channel target, so the 6-channel e1 head is exposed under 'exp'.
        return {'mask': [d1], 'exp': [e1]}

class DiceLoss(nn.Module):
    """Soft Dice loss computed on logits.

    ``forward`` applies a sigmoid to ``predictions`` and returns
    ``1 - dice``, optionally reduced by ``final_reduction``.

    Args:
        reduce_axes: dims summed over when computing each Dice coefficient;
            the default (1, 2, 3, 4) leaves only dim 0 of a 5-D tensor.
        smooth: additive smoothing applied to numerator and denominator.
        epsilon: extra denominator term guarding against division by zero.
        final_reduction: callable applied to the unreduced losses (e.g.
            ``torch.mean``); pass a falsy value such as None to skip it.
    """

    def __init__(self, reduce_axes=(1, 2, 3, 4), smooth=1.0, epsilon=1e-7, final_reduction=torch.mean):
        super(DiceLoss, self).__init__()
        # Copy into a fresh list so instances never share a mutable default.
        self.reduce_axes = list(reduce_axes)
        self.smooth = smooth
        self.eps = epsilon
        self.final_reduction = final_reduction

    def forward(self, predictions, labels):
        """Return the soft Dice loss of sigmoid(predictions) against labels."""
        probabilities = torch.sigmoid(predictions)
        dice_loss = 1.0 - self.dice_fn(probabilities, labels, self.reduce_axes, self.smooth, self.eps)

        if self.final_reduction:
            dice_loss = self.final_reduction(dice_loss)

        return dice_loss

    def dice_fn(self, predictions, labels, reduce_axes, smooth, eps):
        """Soft Dice coefficient summed over ``reduce_axes``.

        Accepts a soft (probability) relaxation or true binary inputs. Uses
        the squared-denominator form, as Milletari demonstrated that its
        gradient behaves better.
        """
        intersection = (labels * predictions).sum(dim=reduce_axes)
        labels_sq = (labels * labels).sum(dim=reduce_axes)
        preds_sq = (predictions * predictions).sum(dim=reduce_axes)

        return (2.0 * intersection + smooth) / (labels_sq + preds_sq + smooth + eps)

class DicePlusBCE(nn.Module):
    """Sum of soft Dice loss and binary cross-entropy, both computed on logits.

    Args:
        reduce_axes, smooth, epsilon, final_reduction: forwarded to ``DiceLoss``.
        pos_weight: optional tensor weighting positive examples in the BCE
            term (see ``nn.BCEWithLogitsLoss``); None means unweighted.
    """

    def __init__(self, reduce_axes=(1, 2, 3, 4), smooth=1.0, epsilon=1e-7, final_reduction=torch.mean, pos_weight=None):
        super(DicePlusBCE, self).__init__()
        # Fresh list so instances never share a mutable default.
        self.reduce_axes = list(reduce_axes)
        self.smooth = smooth
        self.eps = epsilon
        self.final_reduction = final_reduction
        self.pos_weight = pos_weight
        self.dice_loss = DiceLoss(self.reduce_axes, smooth, epsilon, final_reduction)
        # Plain BCE-with-logits despite the attribute name (there is no focal
        # term); the name is kept so existing attribute access keeps working.
        # pos_weight must be forwarded here, otherwise it has no effect.
        self.focal_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def forward(self, predictions, labels):
        """Return Dice loss + BCE loss for logits ``predictions`` against ``labels``."""
        dice_loss = self.dice_loss(predictions, labels)
        bce_loss = self.focal_loss(predictions, labels)
        return dice_loss + bce_loss

class JointsMSELoss(nn.Module):
    """Mean-squared error between predicted and target per-joint maps.

    Args:
        use_target_weight: when True and ``forward`` receives a
            ``target_weight``, prediction and target are both multiplied by
            it before the MSE.
    """

    def __init__(self, use_target_weight=True):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight=None):
        """Compute the MSE averaged over all joints.

        Args:
            output: predictions with joints on dim 1; trailing dims are flattened.
            target: ground truth with the same shape as ``output``.
            target_weight: optional weights reshapeable to ``(batch, joints, -1)``
                and broadcastable against the flattened maps, e.g. one weight
                per joint ``(B, J, 1)`` or one per element (shape of ``target``).
                Ignored when ``use_target_weight`` is False or it is None.

        Returns:
            Scalar loss tensor.
        """
        batch_size, num_joints = output.shape[0], output.shape[1]
        pred = output.reshape(batch_size, num_joints, -1)
        gt = target.reshape(batch_size, num_joints, -1)
        if self.use_target_weight and target_weight is not None:
            weight = target_weight.reshape(batch_size, num_joints, -1)
            pred = pred * weight
            gt = gt * weight
        # Every joint contributes the same number of elements, so one mean over
        # all elements equals the average of the per-joint MSEs.
        return self.criterion(pred, gt)

# Repro script: forward pass plus both losses on a (2, 1, 48, 256, 256) volume.
# CUDA kernels run asynchronously, so an error like CUDNN_STATUS_MAPPING_ERROR
# may be reported at a later call than the one that caused it; running with the
# environment variable CUDA_LAUNCH_BLOCKING=1 makes the traceback point at the
# failing op.
model = Model()
model_parallel = torch.nn.DataParallel(model).cuda()
mask_loss = DicePlusBCE()
exp_loss = JointsMSELoss()
# torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
# plain tensors carry autograd state themselves.
x = torch.randn(2, 1, 48, 256, 256).cuda()
gt = torch.zeros_like(x)  # x is already on the GPU, so gt is too
exp_gt = torch.zeros(2, 6, 48, 256, 256).cuda()
outs = model_parallel(x)

loss_m = mask_loss(outs['mask'][-1], gt)
loss_exp = exp_loss(outs['exp'][-1], exp_gt)
loss = loss_exp + loss_m

When I execute this code, I get a `CUDNN_STATUS_MAPPING_ERROR`, but when I change `convM1 = self.convM1(e1_sum)` to `convM1 = self.convM1(x)`, I get no error and the code executes normally.
How do I fix this?
PyTorch version: 1.4.0
CUDA version: 10.0
cuDNN version: (not provided)

I tried to reproduce this issue with CUDA 10.2 + cuDNN 7.6.5.32 using a PyTorch build from master, as well as with the 1.4.0 binaries for CUDA 10.1 and CUDA 10.0, and all setups ran correctly.

How did you install PyTorch? Do you get any error when you disable cuDNN via `torch.backends.cudnn.enabled = False`?

I installed PyTorch using conda. When I disable cuDNN, I get this error: `CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling cublasCreate(handle)`.

I have 4× Quadro RTX 8000 GPUs (48 GB each), which should be more than enough to handle a batch size of 2.

I don’t think it’s an OOM issue.
Could you try to create a new virtual (conda) environment and reinstall the binaries?
I’ll try to reproduce this issue on an RTX 8000.

I could not reproduce it on a machine with two RTX 8000s, using either PyTorch master built with CUDA 10.2 or the nightly binaries with CUDA 10.1.

I created a new virtual environment with conda and installed PyTorch from source, and I still get the same error. I ran the same code on a Titan V machine (2 GPUs, 12 GB each) and it executes with no issues.

Is the Titan V plugged into the same box as the RTX 8000?
Could you try to remove the ~/.nv/ folder?

No, the Titan V is in a different box. I removed the `~/.nv/` folder and the error remains.

When I replace `convM1 = self.convM1(e1_sum)` with `convM1 = self.convM1(x)`, I don’t get any error. I am trying to understand why the error occurs only when I pass `e1_sum` to `self.convM1`.

It’s hard to tell what went wrong when we cannot reproduce the error and thus cannot debug it.
Based on the CUBLAS_STATUS_NOT_INITIALIZED error I assumed something might be wrong with the installation, but since you reinstalled the binaries as well as built from source, this is unlikely.

This happens only on the RTX 8000 GPUs; it works fine on the Titan V. Should I create an issue on the PyTorch GitHub repository to keep track of this?

Sure, you can create an issue and tag me there so that we can continue the discussion there.