# How to use F.softmax

Hi,

I have a tensor and I want to calculate softmax along the rows of the tensor.

action_values = t.tensor([[-0.4001, -0.2948, 0.1288]])

As I understand it, to apply softmax along the rows of the tensor I need to specify dim as 1.
However, I got unexpected results.

Below is what I tried, but none of it gave the result I expected.

``````
F.softmax(action_values - max(action_values), dim = 0)
Out[15]: tensor([[1., 1., 1.]])

F.softmax(action_values - max(action_values), dim = 1)
Out[16]: tensor([[0.3333, 0.3333, 0.3333]])

F.softmax(action_values - max(action_values), dim = -1)
Out[17]: tensor([[0.3333, 0.3333, 0.3333]])
``````

Hi Granth!

The short answer is that you are calling python’s built-in `max()`
function, rather than pytorch’s `torch.max()` tensor function. As a
result, you are calculating `softmax()` of a tensor that is all zeros.

You have two issues:

First is the use of python’s built-in `max()`. `max()` doesn’t
understand tensors; it simply iterates over the first dimension of
`action_values`, and since there is only one row, it returns that
single row (that is, `action_values` with the singleton dimension
removed). Subtracting this from `action_values` gives all zeros.

The second is that there is no need to subtract a scalar from your
tensor before calling `softmax()`. Any such constant shift drops out
of the `softmax()` calculation anyway.

This script illustrates what is going on:

``````
import torch
torch.__version__
action_values = torch.tensor([[-0.4001, -0.2948, 0.1288]])
action_values
max (action_values)         # this is python's max, not pytorch's
torch.max (action_values)   # pytorch's tensor-version of max
action_values - max (action_values)
action_values - torch.max (action_values)
tzeros = torch.zeros ((1, 3))
tzeros
torch.nn.functional.softmax (tzeros, dim = 0)
torch.nn.functional.softmax (tzeros, dim = 1)
torch.nn.functional.softmax (action_values, dim = 1)         # what you want
torch.nn.functional.softmax (action_values - 2.3, dim = 1)   # shift drops out
``````

Here is its output:

``````
>>> import torch
>>> torch.__version__
'1.6.0'
>>> action_values = torch.tensor([[-0.4001, -0.2948, 0.1288]])
>>> action_values
tensor([[-0.4001, -0.2948,  0.1288]])
>>> max (action_values)         # this is python's max, not pytorch's
tensor([-0.4001, -0.2948,  0.1288])
>>> torch.max (action_values)   # pytorch's tensor-version of max
tensor(0.1288)
>>> action_values - max (action_values)
tensor([[0., 0., 0.]])
>>> action_values - torch.max (action_values)
tensor([[-0.5289, -0.4236,  0.0000]])
>>> tzeros = torch.zeros ((1, 3))
>>> tzeros
tensor([[0., 0., 0.]])
>>> torch.nn.functional.softmax (tzeros, dim = 0)
tensor([[1., 1., 1.]])
>>> torch.nn.functional.softmax (tzeros, dim = 1)
tensor([[0.3333, 0.3333, 0.3333]])
>>> torch.nn.functional.softmax (action_values, dim = 1)         # what you want
tensor([[0.2626, 0.2918, 0.4456]])
>>> torch.nn.functional.softmax (action_values - 2.3, dim = 1)   # shift drops out
tensor([[0.2626, 0.2918, 0.4456]])
``````
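As an aside, if you do want the max-subtraction trick explicitly (say, for a tensor with several rows), subtract the *per-row* maximum with `keepdim = True` so that it broadcasts correctly. A sketch (the second row here is just made-up example data):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[-0.4001, -0.2948,  0.1288],
                  [ 2.0000,  1.0000, -1.0000]])

# per-row maximum, kept as a (2, 1) column so it broadcasts over each row
row_max = x.max(dim=1, keepdim=True).values

# subtracting a per-row constant leaves softmax unchanged
a = F.softmax(x, dim=1)
b = F.softmax(x - row_max, dim=1)
print(torch.allclose(a, b))  # True
```

But, as noted above, you don’t need to do this yourself just to get correct results from `softmax()`.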

Best.

K. Frank