Backward() compatibility with custom complex CrossEntropyLoss

I started discussing this topic in issue #81950. Basically, I designed a custom CrossEntropyLoss to work with complex-valued data (remote sensing, signal processing, etc.) together with a simple FNN. However, when autograd tries to backpropagate I get an error: AttributeError: CrossEntropyLoss object has no attribute backward

Any thoughts?

Here is the code:
# Cross-entropy loss function

class MyComplexCrossEntropyLoss(nn.Module):

    def __init__(self):
        super(MyComplexCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets):
        if torch.is_complex(inputs):
            # average the cross entropy over the real and imaginary parts
            real_loss = nn.CrossEntropyLoss(inputs.real, targets)
            imag_loss = nn.CrossEntropyLoss(inputs.imag, targets)
            return (real_loss + imag_loss) / 2
        else:
            return nn.CrossEntropyLoss(inputs, targets)

# Training batch:

def train_batch(X, y, model, optimizer, criterion, **kwargs):
    "
    X (n_examples x n_features)
    y (n_examples): gold labels
    model: a PyTorch defined model
    optimizer: optimizer used in gradient step
    criterion: loss function
    "
    optimizer.zero_grad()
    out = model(X, **kwargs)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    return loss.item()

# Main

 ...
    model = FeedforwardNetwork(
            n_classes,
            n_feats,
            opt.hidden_sizes,
            opt.layers,
            opt.activation,
            opt.dropout
        )

    # get an optimizer
    optims = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}

    optim_cls = optims[opt.optimizer]
    optimizer = optim_cls(
        model.parameters(), lr=opt.learning_rate, weight_decay=opt.l2_decay
    )

    # get a loss criterion
    criterion = MyComplexCrossEntropyLoss()  # nn.L1Loss()

    # training loop
    epochs = torch.arange(1, opt.epochs + 1)
    train_mean_losses = []
    valid_accs = []
    train_losses = []
    for ii in epochs:
        print('Training epoch {}'.format(ii))
        for X_batch, y_batch in train_dataloader:
            loss = train_batch(
                X_batch, y_batch, model, optimizer, criterion)
            train_losses.append(loss)

        mean_loss = torch.tensor(train_losses).mean().item()
        print('Training loss: %.4f' % (mean_loss))

        train_mean_losses.append(mean_loss)
        valid_accs.append(evaluate(model, dev_X, dev_y))
        print('Valid acc: %.4f' % (valid_accs[-1]))

    print('Final Test acc: %.4f' % (evaluate(model, test_X, test_y)))


You would have to create the object first and then call it. nn.CrossEntropyLoss is a module, so nn.CrossEntropyLoss(inputs, targets) constructs a new loss object (with inputs passed as the weight argument) rather than computing the loss, and that object has no backward method:

loss = nn.CrossEntropyLoss(inputs, targets)  # wrong

# right
criterion = nn.CrossEntropyLoss()
loss = criterion(inputs, targets)
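
The same fix applies inside your custom module. Equivalently, you can use the functional form torch.nn.functional.cross_entropy, which computes the loss directly. A minimal sketch of the corrected class, with a quick sanity check using made-up shapes (batch of 8, 4 classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyComplexCrossEntropyLoss(nn.Module):
    def forward(self, inputs, targets):
        if torch.is_complex(inputs):
            # average the cross entropy computed on the real and imaginary parts
            real_loss = F.cross_entropy(inputs.real, targets)
            imag_loss = F.cross_entropy(inputs.imag, targets)
            return (real_loss + imag_loss) / 2
        return F.cross_entropy(inputs, targets)

# quick sanity check (hypothetical shapes)
logits = torch.randn(8, 4, dtype=torch.complex64, requires_grad=True)
targets = torch.randint(0, 4, (8,))
loss = MyComplexCrossEntropyLoss()(logits, targets)
loss.backward()  # loss is now a tensor, so backward() works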