I started a discussion on this topic in issue #81950. Basically, I have designed a custom CrossEntropyLoss function to work with complex-valued data (remote sensing, signal processing, etc.), and I designed a simple FNN. However, I'm getting an error when autograd tries to propagate gradients: AttributeError: 'CrossEntropyLoss' object has no attribute 'backward'
Any thoughts?
The code is the following:
# Cross entropy loss function
class MyComplexCrossEntropyLoss(nn.Module):
    """Cross-entropy loss supporting complex-valued logits.

    For complex `inputs`, returns the average of the cross-entropy of the
    real part and of the imaginary part (each against the same `targets`);
    for real `inputs`, returns ordinary cross-entropy.
    """

    def __init__(self):
        super(MyComplexCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets):
        # BUG FIX: the original wrote `nn.CrossEntropyLoss(inputs, targets)`,
        # which passes the tensors to the *constructor* of the loss module and
        # returns a Module, not a loss tensor — hence the reported
        # "AttributeError: 'CrossEntropyLoss' object has no attribute 'backward'".
        # Use the functional form, which returns a differentiable scalar tensor.
        if torch.is_complex(inputs):
            real_loss = nn.functional.cross_entropy(inputs.real, targets)
            imag_loss = nn.functional.cross_entropy(inputs.imag, targets)
            return (real_loss + imag_loss) / 2
        else:
            return nn.functional.cross_entropy(inputs, targets)
# Training batch:
def train_batch(X, y, model, optimizer, criterion, **kwargs):
    """Run one optimization step on a single batch.

    Args:
        X: (n_examples x n_features) input batch.
        y: (n_examples,) gold labels.
        model: a PyTorch defined model.
        optimizer: optimizer used in gradient step.
        criterion: loss function returning a scalar tensor.
        **kwargs: forwarded to the model's forward call.

    Returns:
        The batch loss as a Python float.
    """
    # Fix: the original docstring was delimited with a single '"' on its own
    # line, which is a syntax error — replaced with a proper triple-quoted
    # docstring. The step logic itself is unchanged.
    optimizer.zero_grad()
    out = model(X, **kwargs)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    return loss.item()
# Main
...
# Build the model; FeedforwardNetwork and the hyperparameters (n_classes,
# n_feats, opt.*) are defined elsewhere in the file/CLI setup.
model = FeedforwardNetwork(
    n_classes,
    n_feats,
    opt.hidden_sizes,
    opt.layers,
    opt.activation,
    opt.dropout
)
# get an optimizer
optims = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
optim_cls = optims[opt.optimizer]
optimizer = optim_cls(
    model.parameters(), lr=opt.learning_rate, weight_decay=opt.l2_decay
)
# get a loss criterion
criterion = MyComplexCrossEntropyLoss()#nn.L1Loss()#
# training loop
epochs = torch.arange(1, opt.epochs + 1)
train_mean_losses = []
valid_accs = []
train_losses = []
for ii in epochs:
    print('Training epoch {}'.format(ii))
    for X_batch, y_batch in train_dataloader:
        loss = train_batch(
            X_batch, y_batch, model, optimizer, criterion)
        train_losses.append(loss)
    # NOTE(review): train_losses is never cleared between epochs, so this
    # prints a running mean over *all* batches seen so far, not the
    # per-epoch mean — confirm this is intended.
    mean_loss = torch.tensor(train_losses).mean().item()
    print('Training loss: %.4f' % (mean_loss))
    train_mean_losses.append(mean_loss)
    # evaluate() and dev_X/dev_y/test_X/test_y come from elsewhere in the file.
    valid_accs.append(evaluate(model, dev_X, dev_y))
    print('Valid acc: %.4f' % (valid_accs[-1]))
print('Final Test acc: %.4f' % (evaluate(model, test_X, test_y)))