Attempting to perform GradNorm with CSWin results in "Tensor not part of graph"

Hi! I’ve been trying to implement the proposed GradNorm method using a CSWin backbone.

Specifically, I’m trying to compute the gradients of my loss with respect to the parameters of stage3, as follows:

Option A torch.autograd.grad(loss_drinks, model.stage3.parameters(), retain_graph=True, create_graph=True)

Option B torch.autograd.grad(loss_drinks, model.stage3.parameters(), retain_graph=True)

However, both options result in the same error:

*** RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

As I understand it, some step must be detaching a tensor from the graph, but I’m not sure which operations can trigger this, so I don’t know how to work around it.

** Note: performing this against the head layer works as expected and computes the grads.
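
For reference, following the error’s suggestion, passing allow_unused=True at least shows which stage3 parameters autograd treats as unused (they come back as None). A minimal check, reusing loss_drinks and model.stage3 from above:

# Minimal diagnostic: list the stage3 parameters that autograd reports as
# unused for this loss (returned as None when allow_unused=True).
names = [name for name, _ in model.stage3.named_parameters()]
grads = torch.autograd.grad(loss_drinks, model.stage3.parameters(),
                            retain_graph=True, allow_unused=True)
print([name for name, g in zip(names, grads) if g is None])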

Any feedback is appreciated! Thanks!

I did some digging with torchviz and some debugging on my side.

I tracked down what I believe is the place where the tensor gets detached from the graph:
img = self.norm1(x)

Checking the output of norm1, img.requires_grad is False; however, before norm1 is executed, x.requires_grad is True.

As I understand it, a LayerNorm operation shouldn’t detach the tensor. I’m not sure if this is expected behavior; if it is, any suggestions for working around it would be helpful.
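
For comparison, a quick standalone check (a plain nn.LayerNorm outside the model, not the actual CSWin block) keeps the output attached to the graph:

# Sanity check: a plain nn.LayerNorm does not detach its output
# when the input requires grad.
import torch
import torch.nn as nn

norm = nn.LayerNorm(64)
x = torch.randn(2, 16, 64, requires_grad=True)
out = norm(x)
print(x.requires_grad, out.requires_grad)  # True True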

Thanks!

Yes, you are right that a plain nn.LayerNorm layer should not detach the input tensor.
Could you post a minimal and executable code snippet showing this behavior, please?

@ptrblck thanks for the reply, here is a minimal setup of what I currently have

In this example, the number of tasks is set to 2.
This is what my trainer looks like; the section where I get the error is marked as “FAILURE POINT”.

tasks=2

optimizer = timm.optim.optim_factory.create_optimizer(args, model)
# task weights w_i: a leaf Parameter so the grads land on the same tensor the optimizer updates
lambda_weights = torch.nn.Parameter(torch.ones(tasks, device='cuda'))
lambda_optim = timm.optim.optim_factory.create_optimizer(args, [lambda_weights])

for idx_epoch in epochs:
    for idx_batch, item in enumerate(ds_loader):
        img = item['img'].requires_grad_(True)
        target_a = item['target_a']
        target_b = item['target_b']

        out_a, out_b = model(img)

        loss_a = criterion(target_a, out_a)
        loss_b = criterion(target_b, out_b)

        loss_tasks = torch.stack([loss_a, loss_b])

        if idx_epoch == 0:
            initial_loss = loss_tasks.detach()  # L_i(0), used later for the loss ratios

        weighted_loss = lambda_weights * loss_tasks
        loss = weighted_loss.sum()

        lambda_weights.retain_grad()
        loss_tasks.retain_grad()

        loss.backward(retain_graph=True)

        norms = []
        # FAILURE POINT
        for w_i, l_i in zip(lambda_weights, loss_tasks):
            # per-task gradient norm w.r.t. the last shared stage
            local_grad = torch.autograd.grad(l_i, model.stage3.parameters(), retain_graph=True)
            norms.append(w_i * torch.norm(torch.cat([g.reshape(-1) for g in local_grad])))

        norms = torch.stack(norms)
        nw = norms.mean()
        with torch.no_grad():
            # loss ratios
            loss_ratios = loss_tasks / initial_loss
            # inverse training rate r(t)
            inverse_train_rates = loss_ratios / loss_ratios.mean()
            constant_term = nw * (inverse_train_rates ** alpha)
        
        # compute Lgrad
        lgrad = (norms - constant_term).abs().sum()
        lambda_weights.grad = torch.autograd.grad(lgrad, lambda_weights)[0]

        optimizer.step()
        lambda_optim.step()

        # Renormalize
        with torch.no_grad():
            renormalize = tasks / lambda_weights.sum()
            lambda_weights *= renormalize

My model keeps the same CSWin setup, with the difference that stage4 is replaced by an array of “stages”, one per task head.

class CSWinTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=[1000], embed_dim=96, depth=[2,2,6,2], split_size = [3,5,7],
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False):
        super().__init__()
        self.use_chk = use_chk
        self.head_class_counts = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        heads=num_heads

        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
            Rearrange('b c h w -> b (h w) c', h = img_size//4, w = img_size//4),
            nn.LayerNorm(embed_dim)
        )

        curr_dim = embed_dim
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))]  # stochastic depth decay rule
        self.stage1 = nn.ModuleList([
            CSWinBlock(
                dim=curr_dim, num_heads=heads[0], reso=img_size//4, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth[0])])

        self.merge1 = Merge_Block(curr_dim, curr_dim*2)
        curr_dim = curr_dim*2
        self.stage2 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[1], reso=img_size//8, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:1])+i], norm_layer=norm_layer)
            for i in range(depth[1])])
        
        self.merge2 = Merge_Block(curr_dim, curr_dim*2)
        curr_dim = curr_dim*2
        temp_stage3 = []
        temp_stage3.extend(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[2], reso=img_size//16, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:2])+i], norm_layer=norm_layer)
            for i in range(depth[2])])

        self.stage3 = nn.ModuleList(temp_stage3)
        
        self.merge3 = Merge_Block(curr_dim, curr_dim*2)
        curr_dim = curr_dim*2

        self.heads = nn.ModuleList([
            nn.ModuleList([
                Merge_Block(curr_dim, curr_dim * 2),
                nn.Sequential(
                    *[CSWinBlock(
                        dim=curr_dim * 2, num_heads=heads[3], reso=img_size // 32, mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1],
                        drop=drop_rate, attn_drop=attn_drop_rate,
                        drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True)
                        for i in range(depth[-1])]),
                norm_layer(curr_dim * 2),
                nn.Linear(curr_dim * 2, head_class_count)
            ])
            for head_class_count in self.head_class_counts])

        for head_unit in self.heads:
            trunc_normal_(head_unit[3].weight, std=0.02)  # init each task head's classifier
        self.apply(self._init_weights)

    ...

    def forward_features(self, x):
        B = x.shape[0]
        x = self.stage1_conv_embed(x)
        for blk in self.stage1:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        for pre, blocks in zip([self.merge1, self.merge2],
                               [self.stage2, self.stage3]):
            x = pre(x)
            for blk in blocks:
                if self.use_chk:
                    x = checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)

        results = []
        for head_idx, head_unit in enumerate(self.heads):
            merge, stage, norm, head = head_unit

            head_x = merge(x)
            head_x = stage(head_x)

            head_x = norm(head_x)

            head_x = torch.mean(head_x, dim=1)
            # if self.aggregate is True:
            #     logits += [head_x]

            head_x = head(head_x)
            # results[head_idx, :, :head_x.shape[1]] = head_x
            results += [head_x]

        return results

    def forward(self, x):
        x = self.forward_features(x)
        return x

Your code is unfortunately not executable.

@ptrblck Hi, I pushed a minimal setup to GitHub.

While putting the minimal setup together I noticed it initially worked, and I struggled to reproduce the issue we see in our trainer implementation, but I eventually found the main cause.

The model makes use of gradient checkpointing (torch.utils.checkpoint), so for the checkpointed blocks the autograd graph is not recorded during the forward pass (it is only rebuilt during the backward recomputation). One way to avoid this is to simply set the use_chk variable to False; a toy reproduction is sketched below.
https://pytorch.org/docs/stable/checkpoint.html
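
A toy reproduction, with a plain nn.Linear standing in for a checkpointed CSWin block (not the actual model; the use_reentrant flag only exists on newer PyTorch versions):

import torch
import torch.nn as nn
from torch.utils import checkpoint

block = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

# Reentrant checkpointing (used when use_reentrant is not specified): the block's
# parameters are not part of the recorded graph, so autograd.grad raises
# "One of the differentiated Tensors appears to not have been used in the graph".
out = checkpoint.checkpoint(block, x).sum()
# torch.autograd.grad(out, block.parameters(), retain_graph=True)  # fails as above

# Without checkpointing (use_chk=False in my model) the same call works.
out = block(x).sum()
grads = torch.autograd.grad(out, block.parameters(), retain_graph=True)

# The checkpoint docs suggest the non-reentrant variant does support
# torch.autograd.grad, so this might allow keeping checkpointing enabled;
# I haven't verified it with the full CSWin setup.
out = checkpoint.checkpoint(block, x, use_reentrant=False).sum()
grads = torch.autograd.grad(out, block.parameters(), retain_graph=True)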

Ideally I would like to keep using checkpointing, but I’m not sure whether there is a way to get both: computing gradients with respect to intermediate layers while keeping checkpointing enabled.