The gradient of a leaf node returns None

Hey guys,

I am new here and looking for some help.
I am currently trying to make a slight modification to the set abstraction module of PointNet++. The original code is available in Group-free-3D. I want to replace the shared MLP with an attention mechanism to test whether it works better.
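
The rough idea is to let each FPS-sampled centroid attend over its grouped neighbours instead of running a shared MLP and max-pooling. A minimal shape-only sketch of that single attention step (with made-up sizes, nothing taken from the real model, using nn.MultiheadAttention with batch_first=True) looks like this:

import torch
import torch.nn as nn

# toy sizes just to illustrate the shapes, not taken from the real model
B, npoint, nsample, C = 2, 4, 8, 6

attn = nn.MultiheadAttention(embed_dim=C, num_heads=1, batch_first=True)

ref = torch.randn(B * npoint, 1, C)          # one query per sampled centroid
neigh = torch.randn(B * npoint, nsample, C)  # its grouped neighbour features

# local cross attention: query = centroid, key/value = neighbours
out, weights = attn(ref, neigh, neigh)
print(out.shape)      # torch.Size([8, 1, 6])
print(weights.shape)  # torch.Size([8, 1, 8])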

Here is the modified set abstraction module:

import torch
import torch.nn as nn
from typing import List, Tuple

import pointnet2_utils  # CUDA ops from the PointNet++ / Group-free-3D code base (exact import path depends on the repo layout)


class _PointnetAAModuleBase(nn.Module):

    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.multi_head_atten = None
        self.linears = None
        self.norms = None
        self.dropouts = None

    def forward(self, xyz: torch.Tensor,
                features: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Base class for the modified Attention Abstraction module

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, N, C) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
        """

        B, N, C = features.shape

        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        features_flipped = features.transpose(1,2).contiguous()

        ref_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)

        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped,
            ref_idx
        ).transpose(1, 2).contiguous() if self.npoint is not None else None
        # new_xyz: [B, npoint, 3]

        ref_features = pointnet2_utils.gather_operation(
            features_flipped,
            ref_idx
        ).transpose(1, 2).unsqueeze(2).contiguous().view(B*self.npoint, 1, C) if self.npoint is not None else None
        # ref_features: [B, C, npoint] -> [B, npoint, C] -> [B, npoint, 1, C] -> [B*npoint, 1, C]

        for i in range(len(self.groupers)):
            neighbor_features = self.groupers[i](
                xyz, new_xyz, features_flipped
            ).permute(0,2,3,1).contiguous().view(B*self.npoint, -1, C)
            # new_features: [B, C, npoint, nsample] -> [B, npoint, nsample, C] -> [B*npoint, nsample, C]

            # Local cross attention
            attn_features, attn_weights = self.multi_head_atten[i](
                ref_features,
                neighbor_features,
                neighbor_features
            )
            new_features = attn_features.squeeze(1).contiguous().view(B, self.npoint, C)
            # new_features: [B*npoint, 1, C] -> [B*npoint, C] -> [B, npoint, C]

            new_features = self.linears[i](
                new_features
            ) # new_features: [B, npoint, C] -> [B, npoint, mlp[-1]]

            new_features = self.norms[i](
                new_features
            )

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=-1)

class PointnetAAModuleMSG(_PointnetAAModuleBase):
    def __init__(
        self,
        *,
        npoint: int,
        radii: List[float],
        nsamples: List[int],
        mlps: List[List[int]],
        ln: bool = True,
        use_xyz: bool = False, 
        sample_uniformly: bool = False
    ):
        r"""
        Multi-Scale grouping version of Attention Abstraction module
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.linears = nn.ModuleList()
        self.multi_head_atten = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(
                    radius, 
                    nsample,
                    use_xyz=use_xyz, 
                    sample_uniformly=sample_uniformly
                )
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )

            mlp_spec = mlps[i]
            
            if use_xyz:
                mlp_spec[0] += 3

            self.multi_head_atten.append(nn.MultiheadAttention(embed_dim=mlp_spec[0], num_heads=1, batch_first=True))
            self.linears.append(nn.Linear(mlp_spec[0], mlp_spec[1], bias=False))
            self.norms.append(nn.LayerNorm(mlp_spec[1]))
            self.dropouts.append(nn.Dropout(0.1))

Here is the test code:

xyz = torch.randn(4, 9, 3, device="cuda", requires_grad=True)        # leaf tensor
xyz_feats = torch.randn(4, 9, 6, device="cuda", requires_grad=True)  # leaf tensor

test_aa_module = PointnetAAModuleMSG(
    npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[6, 6], [6, 9]]
)

test_aa_module.cuda()
print(test_aa_module)

for _ in range(1):
    _, new_features = test_aa_module(xyz, xyz_feats)
    print(new_features.shape)

    # out = new_features.sum()
    # out.backward()
    # print(out.grad)

    # backward with an explicit gradient of ones (same as summing and calling backward)
    new_features.backward(torch.ones_like(new_features))
    print(new_features.shape)
    print(new_features)
    print(xyz.is_leaf)
    print(xyz.grad)

    # out = new_features.sum()
    # out.backward()
    # print(out.grad)

The gradient of xyz returns None; however, when I test the original set abstraction module, it returns the gradient normally. Moreover, when I try to backward the output new_features (the commented-out out = new_features.sum() lines) and check its gradient, it also returns None and warns that out is not a leaf node. I don't know what causes these problems, because I am a little unfamiliar with the autograd part of PyTorch.
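
For comparison, a tiny standalone example (nothing to do with my model, just how I understand the docs) behaves the way I expected: the leaf tensor gets its .grad populated after backward(), and the non-leaf result only keeps a grad if I call retain_grad() on it first.

import torch

a = torch.randn(3, requires_grad=True)  # leaf tensor
b = (a * 2).sum()                       # non-leaf result of an op
b.retain_grad()                         # needed to keep .grad on a non-leaf tensor
b.backward()
print(a.is_leaf, a.grad)  # True tensor([2., 2., 2.])
print(b.is_leaf, b.grad)  # False tensor(1.)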

Thanks in advance for your advice and help!