Custom Softmax Function


I wanted to define a custom softmax function, for example one with a temperature term.
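For reference, this is how the temperature-scaled softmax is usually defined for logits x in R^n and a temperature T > 0, together with its Jacobian, which follows from the quotient rule. T = 1 recovers the ordinary softmax. Larger T flattens the distribution, and smaller T sharpens it toward the argmax.

```latex
s_i = \operatorname{softmax}_T(x)_i = \frac{e^{x_i / T}}{\sum_{j=1}^{n} e^{x_j / T}},
\qquad
\frac{\partial s_i}{\partial x_k} = \frac{1}{T}\, s_i \left(\delta_{ik} - s_k\right),
\qquad
\frac{\partial L}{\partial x_k} = \frac{1}{T}\, s_k \Big( g_k - \sum_{i} g_i s_i \Big), \quad g_i = \frac{\partial L}{\partial s_i}
```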
I wasn't sure where to start. Can I just define a plain function, like this example from another thread?

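import torch
def truncated_gaussian(x, mean=0, std=1, min=0.1, max=0.9):
    # Unnormalized Gaussian bump exp(-(x - mean)^2 / (2 * std^2)), built from differentiable torch ops
    gauss = torch.exp(-((x - mean) ** 2) / (2 * std ** 2))
    return torch.clamp(gauss, min=min, max=max)  # truncate to [min, max]; zero gradient outside that range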

And then use its output instead of the standard softmax? Or does it need to be an nn.Module?
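Here is a minimal sketch of the plain-function approach the question asks about. It assumes a temperature T > 0 applied along the last dimension. The name softmax_with_temperature is hypothetical and not a PyTorch API. Because it only composes differentiable tensor operations (division, amax, exp, sum), autograd can compute its gradient with backward() like any other expression.

```python
import torch


def softmax_with_temperature(x: torch.Tensor, T: float = 1.0, dim: int = -1) -> torch.Tensor:
    """Softmax of x / T along `dim` (hypothetical helper, not a PyTorch API)."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    z = x / T
    # Subtracting the max does not change the result but avoids overflow in exp().
    z = z - z.amax(dim=dim, keepdim=True)
    e = torch.exp(z)
    return e / e.sum(dim=dim, keepdim=True)


# Usage: autograd computes the gradient through the plain function.
logits = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
probs = softmax_with_temperature(logits, T=0.5)
loss = -torch.log(probs[0, 2])  # e.g. negative log-likelihood of class 2
loss.backward()
print(probs)
print(logits.grad)
```

With T < 1 the probabilities concentrate on the largest logit, and with T > 1 they spread out. T = 1 matches torch.softmax.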


It needs to be a class that inherits from torch.autograd.Function, and you need to define forward and backward (as static methods).
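Below is a minimal sketch of the torch.autograd.Function route the reply describes. It again computes a softmax over the last dimension with temperature T. TemperatureSoftmax and temperature_softmax are hypothetical names chosen for illustration. The backward implements the softmax vector-Jacobian product, grad_x = s * (g - sum(g * s)) / T, where s is the forward output and g is the incoming gradient. torch.autograd.gradcheck compares this hand-written backward against finite differences and works best with double-precision inputs.

```python
import torch


class TemperatureSoftmax(torch.autograd.Function):
    """Softmax of x / T over the last dim with a hand-written backward (illustrative sketch)."""

    @staticmethod
    def forward(ctx, x, T):
        z = x / T
        z = z - z.amax(dim=-1, keepdim=True)  # shift for numerical stability
        e = torch.exp(z)
        s = e / e.sum(dim=-1, keepdim=True)
        ctx.save_for_backward(s)  # outputs may be saved for use in backward
        ctx.T = T  # T is a plain Python float, so store it on ctx directly
        return s

    @staticmethod
    def backward(ctx, grad_out):
        (s,) = ctx.saved_tensors
        # Vector-Jacobian product of softmax, scaled by 1/T from the x / T input.
        grad_x = s * (grad_out - (grad_out * s).sum(dim=-1, keepdim=True)) / ctx.T
        # Return one gradient per forward input; T is a non-tensor hyperparameter, so None.
        return grad_x, None


def temperature_softmax(x, T=1.0):
    # Custom Functions are called through .apply, not by instantiating the class.
    return TemperatureSoftmax.apply(x, T)


# Usage: check the hand-written backward against numerical gradients.
x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(lambda inp: temperature_softmax(inp, 0.7), (x,))
print(ok)
```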