Creating custom nn.Module for Softmargin Softmax

Hello all,

I have been trying to recreate the Softmargin Softmax function found here: Softmargin Softmax paper.
I took the current Softmax function in Activation.py and changed the forward step to my own Python calculation.
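For reference, the per-element quantity I am trying to compute (as I read the paper, with margin $m$) for a row of logits $z$ is

$$p_i = \frac{e^{z_i - m}}{\sum_{j \neq i} e^{z_j} + e^{z_i - m}},$$

followed by a renormalization so that each row sums to 1.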

def forward(self, input):
        m = 0  # margin
        input = input.exp()
        sumexp = torch.sum(input, dim=1)
        expm = math.exp(-m)

        # calculate softmargin softmax element by element
        for x in range(input.size(0)):
            for y in range(input.size(1)):
                input[x, y] = (input[x, y] * expm) / (sumexp[x] - input[x, y] + (input[x, y] * expm))

        # normalize the weights
        sumnorm = torch.sum(input, dim=1)
        for x in range(input.size(0)):
            for y in range(input.size(1)):
                input[x, y] = input[x, y] / sumnorm[x]

        return input

If I now run my model, it does not crash, but it seems to keep computing something without making any progress through the model.
Did I overlook some important step, or is the function just that inefficient?
Any help is appreciated.

edit

The function appears to run very, very slowly.
Does anybody know how to speed it up?

Prefer not to use for loops; try to vectorize your code as much as possible (a rough sketch of the pattern is below).

Refer to this old question I asked about checking my implementation of softargmax, which I believe is what you are looking for. It’s pretty decent and fast…
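To illustrate the vectorization idea (just the general pattern of replacing nested Python loops with a broadcasted tensor operation, not your exact formula):

    import torch

    x = torch.rand(4, 10)                         # batch of 4 rows, 10 values each
    row_sums = torch.sum(x, dim=1).view(-1, 1)    # shape (4, 1), reshaped so it broadcasts over each row
    x_normalized = x / row_sums                   # one broadcasted division replaces the two nested loops

The same trick applies to your soft-margin numerator and denominator.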


Thank you very much for the pointer, I was able to redo the function and now it runs (almost) as fast as the original softmax! For reference, here is the function (used on a 3D tensor, with the softmax computed over the rows):

    def forward(self, input):
        m = 0
        expm = math.exp(-m)  # precompute exp(-m)
        input = torch.exp(input)

        # keep track of the original dimensions
        s0 = input.size(0)
        s1 = input.size(1)
        s2 = input.size(2)

        # input = input.transpose(1,2).contiguous().view(-1,s1)  # use if you want to calculate over columns instead of rows
        input = input.view(-1, s2)

        # calculate softmargin softmax
        expsum = torch.sum(input, dim=1).view(-1, 1)  # sum and reshape so it broadcasts over each row
        input = (input * expm) / (expsum - input + (input * expm))

        # normalize
        normsum = torch.sum(input, dim=1).view(-1, 1)  # sum and reshape for normalization
        input = input / normsum

        input = input.view(s0, s1, s2)
        # input = input.view(s0,s2,s1).transpose(1,2).contiguous()  # use if you want to calculate over columns instead of rows

        return input
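
For anyone who wants to try it: here is roughly how I wrap and sanity-check the layer. The class name SoftmarginSoftmax and the m constructor argument are just my own choices, and the forward body repeats the version above so the snippet runs on its own. With m = 0 the result should match the ordinary softmax:

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SoftmarginSoftmax(nn.Module):
        def __init__(self, m=0):
            super(SoftmarginSoftmax, self).__init__()
            self.m = m  # margin hyperparameter

        def forward(self, input):
            expm = math.exp(-self.m)
            input = torch.exp(input)
            s0, s1, s2 = input.size(0), input.size(1), input.size(2)
            input = input.view(-1, s2)
            expsum = torch.sum(input, dim=1).view(-1, 1)
            input = (input * expm) / (expsum - input + (input * expm))
            normsum = torch.sum(input, dim=1).view(-1, 1)
            input = input / normsum
            return input.view(s0, s1, s2)

    layer = SoftmarginSoftmax(m=0)
    x = torch.rand(2, 5, 7)
    out = layer(x)
    print(out.shape)                                 # torch.Size([2, 5, 7])
    print(torch.allclose(out, F.softmax(x, dim=2)))  # True: with m = 0 this reduces to plain softmax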