Custom loss function: NN not training between epochs


I have a neural network. If I use nn.CrossEntropyLoss() as the objective function, the network trains as intended.

I am trying to implement the geometric mean (see the "Geometric mean" Wikipedia article) as a loss function on my own:

class Geometric_mean(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(Geometric_mean, self).__init__()

    def forward(self, output, target):
        # output: raw neural network output, shape (batch, 2)
        def add(loss, output, target, u_left, no_of_uniques):
            if u_left < 0:
                # n-th root of the accumulated product -> geometric mean
                return loss ** (1 / no_of_uniques)

            # split NN output into two columns
            a1, a2 = output.split(split_size=1, dim=1)
            pred = (a2 - a1).floor() + 1  # <<<----------

            pred_target = torch.column_stack((pred, target))

            # rows of pred_target whose target equals the indexed unique class
            class_rows = pred_target[pred_target[:, 1] == target.unique()[u_left]]  # <<<----------
            pred, class_target = class_rows.split(1, dim=1)

            percentage_class_correct = torch.sum(1 - (pred - class_target)) / len(class_target)  # <<<----------

            # accumulate as a product (a geometric mean is the n-th root of a product)
            return add(loss * percentage_class_correct, output, target, u_left - 1, no_of_uniques)

        no_of_unique = len(target.unique())
        # -1 because indexing starts at 0, while len() returns a count
        u_left = no_of_unique - 1
        return add(1.0, output, target, u_left, no_of_unique)

I assume this is where I made my mistake, as this is the only thing I changed.

It does seem to train within an epoch, but it kind of resets itself between epochs: training for 10 epochs is like running the same epoch with the same losses 10 times rather than building on top of each other.

E.g.

Epoch 001: | Train Loss: 1.21563 | Val Loss: 0.73585 | Train Acc: 40.167% | Val Acc: 49.057%
Epoch 002: | Train Loss: 1.21563 | Val Loss: 0.73585 | Train Acc: 40.167% | Val Acc: 49.057%
Epoch 003: | Train Loss: 1.21563 | Val Loss: 0.73585 | Train Acc: 40.167% | Val Acc: 49.057%


It does compute the loss and everything within an epoch, so at least that part is OK.
I have another loss function I tried implementing that has the same problem, but this is the simpler of the two.

Earlier on, I had made the mistake of breaking the computation graph, but then Python complained very loudly and the code would not run at all. This code does not throw any error, so I am not sure whether that is what is happening here.

Any suggestions as to what I can do to fix the code?

Embarrassing, but my mistake was due to using the floor function: it has derivative zero almost everywhere, and thus whatever composite function it is part of will also have derivative zero by the chain rule.
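For anyone hitting the same symptom, here is a minimal sketch of the problem (values chosen just for illustration): torch.floor participates in autograd without raising any error, but its gradient is zero everywhere it is defined, so everything upstream of it receives zero gradient and the optimizer never updates the weights.

```python
import torch

# A toy "prediction" built with floor, mimicking pred = (a2 - a1).floor() + 1
x = torch.tensor([2.3], requires_grad=True)
loss = ((x.floor() + 1.0) - 2.0) ** 2  # some loss downstream of floor

loss.backward()
print(x.grad)  # tensor([0.]) -- floor kills the gradient, silently
```

The fix is to keep the quantity differentiable, e.g. by working with the continuous score (a2 - a1) or with softmax probabilities instead of hard, rounded class predictions.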