Torch.compile() error with detach() and with torch.no_grad()

When using torch.compile(), I encountered the following issue.
The forward pass of my model is:

    def forward(self, x, forward_pass='default', sharpen=True):
        """Run the model in the requested mode.

        Args:
            x: Input batch, fed to ``self.backbone``.
            forward_pass: Which heads to evaluate:
                ``'dm'``  -> returns ``self.dm_head(features)``;
                ``'pcl'`` -> returns ``(q, logits_proto)``;
                ``'all'`` -> returns ``(out_dm, logits_proto, q)``, sharing a
                single backbone pass between both heads.
            sharpen: If True, the prototype logits are divided by
                ``self.temprature``.

        Returns:
            See ``forward_pass``. ``q`` is the projector output L2-normalised
            along dim 1, and ``logits_proto`` is ``q @ prototypes.T``.

        Raises:
            ValueError: If ``forward_pass`` is not ``'dm'``, ``'pcl'`` or
                ``'all'``. Such values, including the default ``'default'``,
                previously fell through every branch and returned None.
        """
        if forward_pass not in ('dm', 'pcl', 'all'):
            raise ValueError(
                f"forward_pass must be 'dm', 'pcl' or 'all', got {forward_pass!r}")

        features = self.backbone(x)
        if forward_pass == 'dm':
            return self.dm_head(features)

        q = F.normalize(self.projector(features), dim=1)
        # Use the prototypes as constants for this pass. detach() blocks
        # gradients into them, and clone() gives them their own storage.
        prototypes = self.prototypes.clone().detach()
        logits_proto = torch.mm(q, prototypes.t())
        if sharpen:
            logits_proto = logits_proto / self.temprature

        if forward_pass == 'pcl':
            return q, logits_proto

        # 'all': the dm head reuses the backbone features computed above.
        out_dm = self.dm_head(features)
        return out_dm, logits_proto, q

The backbone is ResNet18, and dm_head and projector each consist of several dense layers.
When the forward_pass argument is ‘all’, I want to compute the ‘dm’ and ‘pcl’ outputs simultaneously from a single backbone pass.


In the training procedure, I have two alternative ways of writing this step:
the first uses ‘all’, and the second calls ‘dm’ and ‘pcl’ separately.

# net and net2 share the same architecture: net is trained, net2 is kept fixed.
net.train()
net2.eval()

# Pseudo-label pass: no gradients are needed for these outputs.
# NOTE: no_grad() does not change module mode. net is in train mode, so
# (assuming the standard ResNet18 backbone with BatchNorm) its running
# statistics are still updated here. The 'dm' + 'pcl' variant runs net's
# backbone twice per step instead of once, which may contribute to the
# accuracy difference versus 'all'. Worth checking.
# NOTE(review): with torch.compile on PyTorch 2.0, calling net(..., forward_pass='dm')
# here under no_grad and later with grad enabled (in dividemix_train_step)
# raised the addmm(out=...) error reported in pytorch/pytorch#99616.
# Detaching the outputs instead of using no_grad avoided it.
with torch.no_grad():
    size_x1, size_x2, size_u1, size_u2 = inputs_x1.size(0), inputs_x2.size(0), inputs_u1.size(0), inputs_u2.size(0)
    inputs = torch.cat([inputs_x1, inputs_x2, inputs_u1, inputs_u2], dim=0)

    # using 'all':
    dm_outputs_1, proto_outputs_1, features_1 = net(inputs, forward_pass='all')
    dm_outputs_2, proto_outputs_2, features_2 = net2(inputs, forward_pass='all')

    # or using 'dm' + 'pcl' as following to replace the above two lines:
    # dm_outputs_1 = net(inputs, forward_pass='dm')
    # dm_outputs_2 = net2(inputs, forward_pass='dm')
    # features_1, proto_outputs_1 = net(inputs, forward_pass='pcl')
    # features_2, proto_outputs_2 = net2(inputs, forward_pass='pcl')

    # Split the concatenated outputs back into the four input groups.
    dm_o_x11, dm_o_x12, dm_o_u11, dm_o_u12 = torch.split(dm_outputs_1, [size_x1, size_x2, size_u1, size_u2], dim=0)
    dm_o_x21, dm_o_x22, dm_o_u21, dm_o_u22 = torch.split(dm_outputs_2, [size_x1, size_x2, size_u1, size_u2], dim=0)

# One-hot encode the labels into shape (batch_size, num_classes).
labels_x_soft = torch.zeros(batch_size, args.num_classes, device=device).scatter_(1, labels_x.view(-1, 1), 1)
# Convert to float32 directly on `device`. `.type(torch.FloatTensor)` produces
# a CPU tensor, so a GPU-resident w_x used to be copied to the host and back.
w_x = w_x.view(-1, 1).to(device=device, dtype=torch.float32)

targets_u = co_guessing(dm_o_u11, dm_o_u12, dm_o_u21, dm_o_u22, args.T)
targets_x = co_refinement(dm_o_x11, dm_o_x12, labels_x_soft, w_x, args.T)

variable_dict = {'inputs_x1': inputs_x1, 'inputs_x2': inputs_x2, 'inputs_u1': inputs_u1,
                 'inputs_u2': inputs_u2, 'targets_x': targets_x, 'targets_u': targets_u}
loss_dm = dividemix_train_step(args, net, 'dm', variable_dict, dm_criterion, batch_size, batch_idx, num_iter, epoch, device)

The co_guessing, co_refinement, dm_criterion and dividemix_train_step functions are defined as follows (dm_criterion is an instance of class SemiLoss):

def co_guessing(outputs_u11, outputs_u12, outputs_u21, outputs_u22, T):
    """Co-guess soft labels for unlabelled samples from four logit tensors.

    The four softmax distributions (over dim 1) are averaged, sharpened by
    raising them to the power ``1 / T``, and renormalised so each row sums
    to one. Everything runs under ``torch.no_grad()`` and the result is
    returned detached.
    """
    with torch.no_grad():
        p11, p12, p21, p22 = (
            torch.softmax(logits, dim=1)
            for logits in (outputs_u11, outputs_u12, outputs_u21, outputs_u22)
        )
        mean_prob = (p11 + p12 + p21 + p22) / 4
        sharpened = mean_prob ** (1 / T)  # temperature sharpening
        row_totals = sharpened.sum(dim=1, keepdim=True)
        return (sharpened / row_totals).detach()


def co_refinement(outputs_x11, outputs_x12, labels_x_soft, w_x, T):
    """Refine the labels of labelled samples using two sets of logits.

    The two softmax distributions (over dim 1) are averaged, blended with
    ``labels_x_soft`` as ``w_x * labels + (1 - w_x) * mean_prob``, sharpened
    by the power ``1 / T``, and renormalised per row. Runs under
    ``torch.no_grad()``; the result is returned detached.
    """
    with torch.no_grad():
        prob_a = torch.softmax(outputs_x11, dim=1)
        prob_b = torch.softmax(outputs_x12, dim=1)
        mean_prob = (prob_a + prob_b) / 2
        blended = w_x * labels_x_soft + (1 - w_x) * mean_prob
        sharpened = blended ** (1 / T)  # temperature sharpening
        return (sharpened / sharpened.sum(dim=1, keepdim=True)).detach()


class SemiLoss(nn.Module):
    """Semi-supervised loss with a linearly ramped unlabelled weight.

    Returns a soft-target cross-entropy on labelled logits, a mean-squared
    error between unlabelled softmax probabilities and their targets, and
    the ramp-up weight for the unlabelled term.
    (Original note: based on the implementation of SupContrast.)
    """

    def __init__(self):
        super().__init__()

    def linear_rampup(self, lambda_u, current, warm_up, rampup_length=16):
        """Scale ``lambda_u`` by ``(current - warm_up) / rampup_length`` clipped to [0, 1]."""
        progress = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
        return lambda_u * float(progress)

    def forward(self, outputs_x, targets_x, outputs_u, targets_u, lambda_u, epoch, warm_up):
        """Return ``(Lx, Lu, weight)`` for the given logits and soft targets.

        ``Lx`` is the batch mean of ``-sum(targets_x * log_softmax(outputs_x))``
        over dim 1. ``Lu`` is the mean squared difference between
        ``softmax(outputs_u)`` and ``targets_u``. ``weight`` is
        ``linear_rampup(lambda_u, epoch, warm_up)``.
        """
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -(log_probs_x * targets_x).sum(dim=1).mean()

        probs_u = torch.softmax(outputs_u, dim=1)
        Lu = ((probs_u - targets_u) ** 2).mean()

        weight = self.linear_rampup(lambda_u, epoch, warm_up)
        return Lx, Lu, weight

def _mixup(inputs, targets, lam):
    """Mix every row with a randomly permuted partner row.

    A single ``torch.randperm`` permutation is shared by inputs and targets,
    so each mixed input stays aligned with its mixed target.

    Returns:
        ``(lam * inputs + (1 - lam) * inputs[perm],
          lam * targets + (1 - lam) * targets[perm])``.
    """
    perm = torch.randperm(inputs.size(0))
    mixed_inputs = lam * inputs + (1 - lam) * inputs[perm]
    mixed_targets = lam * targets + (1 - lam) * targets[perm]
    return mixed_inputs, mixed_targets


def dividemix_train_step(args, net, forward, variable_dict, criterion, batch_size, batch_idx, num_iter, epoch, device):
    """Run one mixup training step and return the loss.

    Args:
        args: Namespace providing ``alpha`` (Beta parameter for the mixing
            coefficient), ``lambda_u``, ``warm_up`` and ``num_classes``.
        net: Model called as ``net(x, forward_pass=forward)``.
        forward: ``'dm'`` (``net`` returns logits) or ``'pcl'`` (``net``
            returns a pair whose second element is the logits). Any other
            value raises ``NotImplementedError``.
        variable_dict: Dict with keys ``'inputs_x1'``, ``'inputs_x2'``,
            ``'inputs_u1'``, ``'inputs_u2'``, ``'targets_x'``, ``'targets_u'``.
        criterion: Callable ``(logits_x, targets_x, logits_u, targets_u,
            lambda_u, current, warm_up) -> (Lx, Lu, lamb)``.
        batch_size: Kept for backward compatibility. The labelled/unlabelled
            split is derived from ``inputs_x1`` and ``inputs_x2``. This equals
            ``2 * batch_size`` when each of them has ``batch_size`` rows.
        batch_idx, num_iter, epoch: Combined into the fractional epoch
            ``epoch + batch_idx / num_iter`` passed to ``criterion``.
        device: Device on which the uniform class prior is created.

    Returns:
        ``Lx + lamb * Lu + penalty``, where ``penalty`` is
        ``sum(prior * log(prior / mean_prediction))`` for a uniform prior,
        i.e. KL(prior || mean prediction).

    Raises:
        NotImplementedError: If ``forward`` is neither ``'dm'`` nor ``'pcl'``.
    """
    inputs_x1 = variable_dict['inputs_x1']
    inputs_x2 = variable_dict['inputs_x2']
    inputs_u1 = variable_dict['inputs_u1']
    inputs_u2 = variable_dict['inputs_u2']
    targets_x = variable_dict['targets_x']
    targets_u = variable_dict['targets_u']

    # Mixing coefficient from Beta(alpha, alpha). Folding it to
    # max(lam, 1 - lam) keeps at least half of each sample's own row.
    lam = np.random.beta(args.alpha, args.alpha)
    lam = max(lam, 1 - lam)

    # Row order is [x1, x2, u1, u2]; the targets are repeated once per input tensor.
    all_inputs = torch.cat([inputs_x1, inputs_x2, inputs_u1, inputs_u2], dim=0)
    all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
    mixed_input, mixed_target = _mixup(all_inputs, all_targets, lam)

    if forward == 'dm':
        logits = net(mixed_input, forward_pass=forward)
    elif forward == 'pcl':
        _, logits = net(mixed_input, forward_pass=forward)
    else:
        raise NotImplementedError

    # The labelled rows (x1 then x2) come first in the concatenation above.
    n_labeled = inputs_x1.size(0) + inputs_x2.size(0)
    logits_x, logits_u = logits[:n_labeled], logits[n_labeled:]

    Lx, Lu, lamb = criterion(logits_x, mixed_target[:n_labeled], logits_u, mixed_target[n_labeled:],
                             args.lambda_u, epoch + batch_idx / num_iter, args.warm_up)

    # Regularisation towards a uniform class prior. The prior is built
    # directly on `device` instead of being created on the CPU and copied.
    prior = torch.ones(args.num_classes, device=device) / args.num_classes
    pred_mean = torch.softmax(logits, dim=1).mean(0)
    penalty = torch.sum(prior * torch.log(prior / pred_mean))

    return Lx + lamb * Lu + penalty

When I use ‘all’ instead of ‘dm’ + ‘pcl’, the code runs successfully. But when using ‘dm’ + ‘pcl’, it fails with the following error:

Traceback (most recent call last):
  File "/home/lxy/Documents/ClusterTeaching/ClusterTeaching.py", line 194, in <module>
    main()
  File "/home/lxy/Documents/ClusterTeaching/ClusterTeaching.py", line 165, in main
    uniform_proto_train(args, epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader, semi_loss, ce_loss, info_nce_loss, meta_info, device)
  File "/home/lxy/Documents/ClusterTeaching/utils/train_utils.py", line 692, in uniform_proto_train
    loss_dm = dividemix_train_step(args, net, 'dm', variable_dict, dm_criterion, batch_size, batch_idx, num_iter, epoch, device)
  File "/home/lxy/Documents/ClusterTeaching/utils/train_utils.py", line 102, in dividemix_train_step
    logits = net(mixed_input, forward_pass=forward)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/lxy/Documents/ClusterTeaching/models/models.py", line 76, in forward
    def forward(self, x, forward_pass='default', sharpen=True):
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2819, in forward
    return compiled_fn(full_args)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1222, in g
    return f(*args)
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1898, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1247, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/opt/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 248, in run
    return model(new_inputs)
  File "/tmp/torchinductor_lxy/wo/cwoko77dmlrgqfxjyv3jpalqkko7ufkqxln7tjwcrwacduz4psjx.py", line 1531, in call
    extern_kernels.addmm(arg55_1, as_strided(buf109, (256, 512), (512, 1)), as_strided(arg54_1, (512, 10), (1, 512)), alpha=1, beta=1, out=buf110)
RuntimeError: addmm(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

By the way, there is no error without torch.compile().
Also, when I call detach() on dm_outputs_1 and dm_outputs_2 instead of using torch.no_grad(), there is no error either.
But I still think something is wrong: although the code runs successfully with either of these two workarounds, the accuracy is lower than the baseline.

Hmm, that seems like a bug. Can you file an issue on GitHub (https://github.com/pytorch/pytorch/issues)?

Context: inductor should probably be disabling autograd when it runs the generated code.

Also - do you have a runnable, E2E repro that can be used for debugging?

If you’re able to share the codegen’d file, that would probably be helpful too (from your stack trace it’s at /tmp/torchinductor_lxy/wo/cwoko77dmlrgqfxjyv3jpalqkko7ufkqxln7tjwcrwacduz4psjx.py)

Thank you for your response!
Here is the GitHub issue: “Torch.compile() error with detach() and with torch.no_grad()” (pytorch/pytorch issue #99616).
And here is the codegen’d file, cwoko77dmlrgqfxjyv3jpalqkko7ufkqxln7tjwcrwacduz4psjx.py, hosted on GitHub.