I wanted to define a custom softmax function, for example one with a temperature term, but I was not sure where to start.
Can I just define a plain function, like this example from another thread?
    import torch

    def truncated_gaussian(x, mean=0, std=1, min=0.1, max=0.9):
        # Unnormalized Gaussian bump: exp(-(x - mean)^2 / (2 * std^2))
        gauss = torch.exp(-((x - mean) ** 2) / (2 * std ** 2))
        return torch.clamp(gauss, min=min, max=max)  # truncate to [min, max]
And then use its output instead of the standard softmax? Or does it need to be an nn.Module?
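A minimal sketch, assuming the goal is softmax(logits / T) with a positive temperature T. The names `softmax_with_temperature` and `TemperatureSoftmax` are hypothetical. A plain function built from tensor operations is differentiated by autograd like any built-in op, so an `nn.Module` wrapper is only needed for convenience (for example inside `nn.Sequential`) or to hold learnable parameters or buffers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def softmax_with_temperature(logits: torch.Tensor, temperature: float = 1.0,
                             dim: int = -1) -> torch.Tensor:
    """Softmax of logits / temperature along `dim` (hypothetical helper).

    temperature > 1 flattens the distribution; temperature < 1 sharpens it.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Plain tensor ops are tracked by autograd, so no nn.Module is required.
    return F.softmax(logits / temperature, dim=dim)


class TemperatureSoftmax(nn.Module):
    """Optional module wrapper, e.g. to drop into nn.Sequential (hypothetical)."""

    def __init__(self, temperature: float = 1.0, dim: int = -1):
        super().__init__()
        self.temperature = temperature
        self.dim = dim

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return softmax_with_temperature(logits, self.temperature, self.dim)


# Usage: gradients flow back through the plain function.
logits = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
probs = softmax_with_temperature(logits, temperature=2.0)
probs[0, 0].backward()
print(probs.sum(dim=-1))  # each row sums to 1
print(logits.grad)
```

One caveat if the truncated Gaussian above is used the same way: its output is not normalized to sum to 1, and `torch.clamp` passes zero gradient wherever the value was clamped, so it is not a drop-in replacement for softmax.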