# Creating custom nn.Module for Softmargin Softmax

Hello all,

I have been trying to recreate the Softmargin Softmax function found here: Softmargin Softmax paper.
I took the current Softmax function in Activation.py and replaced the forward step with my custom Python calculation.

```python
def forward(self, input):
    m = 0
    input = input.exp()
    sumexp = torch.sum(input, dim=1)
    expm = math.exp(-m)

    # calculate softmargin softmax
    for x in range(input.size(0)):
        for y in range(input.size(1)):
            input[x, y] = (input[x, y] * expm) / (sumexp[x] - input[x, y] + (input[x, y] * expm))

    # normalize the weights
    sumnorm = torch.sum(input, dim=1)
    for x in range(input.size(0)):
        for y in range(input.size(1)):
            input[x, y] = input[x, y] / sumnorm[x]

    return input
```

If I now run my model, it does not crash, but it seems to keep calculating without making any progress.
Did I overlook some important step? Or is the implementation just that inefficient?
Any help is appreciated.

edit

The function appears to run very, very slowly.
Does anybody know how to speed it up?

Prefer not to use for loops; try to vectorize your code as much as possible.
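For example (a minimal sketch of the idea, not the exact code from the paper): the nested loops in the post above can be replaced by a single broadcasted expression, since `sumexp` only needs to be reshaped into a column vector:

```python
import math
import torch

# Sketch: vectorized form of the double loop, assuming a 2D tensor of logits.
logits = torch.randn(4, 5)
m = 0
expm = math.exp(-m)

e = logits.exp()
sumexp = e.sum(dim=1, keepdim=True)         # shape (4, 1), broadcasts over columns
out = (e * expm) / (sumexp - e + e * expm)  # same formula, no Python loops
out = out / out.sum(dim=1, keepdim=True)    # normalize each row to sum to 1
```

Because every element only depends on its own value and its row sum, the broadcasted version computes exactly what the double loop does, in one pass.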

Refer to this old question I asked, where I checked my implementation of softargmax, which I believe is what you are looking for. It’s pretty decent and fast…


Thank you very much for the pointer, I was able to redo the function and it now runs (almost) as fast as the original softmax! For reference, here is the function (used for a 3D tensor and calculated over the rows):

```python
def forward(self, input):
    m = 0
    expm = math.exp(-m)  # margin factor exp(-m)
    input = torch.exp(input)

    # keep track of the original dimensions
    s0 = input.size(0)
    s1 = input.size(1)
    s2 = input.size(2)

    # input = input.transpose(1, 2).contiguous().view(-1, s1)  # use if you want to calculate over columns instead of rows
    input = input.view(-1, s2)

    # calculate softmargin softmax
    expsum = torch.sum(input, dim=1).view(-1, 1)  # sum and reshape for softmax
    input = (input * expm) / (expsum - input + (input * expm))

    # normalize
    normsum = torch.sum(input, dim=1).view(-1, 1)  # sum and reshape for normalization
    input = input / normsum

    input = input.view(s0, s1, s2)
    # input = input.view(s0, s2, s1).transpose(1, 2).contiguous()  # use if you want to calculate over columns instead of rows

    return input
```
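One quick way to sanity-check an implementation like this (my own suggestion, not from the paper): with m = 0 we have exp(-m) = 1, so the margin term (e · 1) / (s − e + e · 1) collapses to e / s, and the result should match ordinary softmax exactly. A standalone sketch of the forward pass above, with a hypothetical `softmargin_softmax` wrapper:

```python
import math
import torch

def softmargin_softmax(input, m=0.0):
    # standalone version of the forward pass above (3D input, computed over rows)
    expm = math.exp(-m)
    input = torch.exp(input)
    s0, s1, s2 = input.size(0), input.size(1), input.size(2)
    input = input.view(-1, s2)
    expsum = torch.sum(input, dim=1).view(-1, 1)
    input = (input * expm) / (expsum - input + (input * expm))
    normsum = torch.sum(input, dim=1).view(-1, 1)
    input = input / normsum
    return input.view(s0, s1, s2)

x = torch.randn(2, 3, 4)
# with m = 0 this should agree with the built-in softmax over the last dim
print(torch.allclose(softmargin_softmax(x, m=0.0), torch.softmax(x, dim=-1)))
```

For m > 0 the rows still sum to 1 after the normalization step, which makes a handy second check.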