On both the nightly build and 2.6, the following code fails because of a data-dependent expression inside `cross_entropy`:
```python
import torch
import torch.nn.functional as F

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def loss(logits, target, mask):
    logits = logits[mask, :]
    target = target[mask]
    return F.cross_entropy(logits, target)

loss(torch.rand(10, 4), torch.randint(0, 4, (10,)), torch.ones(10, dtype=torch.bool))
```
Is it possible to fully compile this, i.e. as a single graph with `fullgraph=True`?
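One possible workaround, sketched here under the assumption that only the scalar loss value is needed (not the filtered `logits`/`target` tensors themselves): avoid boolean indexing entirely, so no output shape depends on the values in `mask`. Masked-out targets are replaced with `cross_entropy`'s `ignore_index` (default `-100`), and with the default `reduction="mean"` those positions are excluded from both the sum and the normalizing count. The names `masked_loss` and `reference_loss` are illustrative only. This sidesteps the data-dependent shape rather than answering whether the original formulation can be compiled as written.

```python
import torch
import torch.nn.functional as F

def masked_loss(logits, target, mask):
    # Keep every row instead of dropping masked-out ones, so all shapes are
    # static. Masked-out targets are set to -100 (cross_entropy's ignore_index).
    target = torch.where(mask, target, torch.full_like(target, -100))
    # With reduction="mean", ignored targets contribute neither to the sum nor to
    # the denominator, so this equals cross_entropy(logits[mask], target[mask]).
    return F.cross_entropy(logits, target, ignore_index=-100)

def reference_loss(logits, target, mask):
    # The original boolean-indexing formulation, run eagerly for comparison.
    return F.cross_entropy(logits[mask, :], target[mask])

# fullgraph=True makes Dynamo raise on a graph break instead of splitting the graph.
compiled_loss = torch.compile(masked_loss, fullgraph=True)

logits = torch.rand(10, 4)
target = torch.randint(0, 4, (10,))
mask = torch.tensor([True, False] * 5)  # fixed mask so the result is never empty
```

Because nothing here produces a data-dependent shape, `capture_dynamic_output_shape_ops` should not be needed for this version. If the mask selects no rows, both formulations divide by zero and return `nan`.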