# Should I call backward() separately on each stochastic node?

I have a model which has a single sample from a multinomial pretty far upstream. Other than that it’s a standard supervised learning model, resulting in a single loss value at the end.

I’ve got my head around the reinforce function. I now do the following:

• Compute a loss vector L over the batch (i.e. the batch loss, but not summed)
• Pass the negative of this loss to reinforce() on the Variable returned by torch.multinomial().
• Call backward() on the sum of the loss.
• Call optimizer.step()
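
Concretely, those steps look roughly like this (a toy, self-contained sketch: the `policy`/`emb`/`head` modules and the squared-error loss are just stand-ins for my real model, and I'm not 100% sure about the exact shape/type reinforce() wants for the reward):

```python
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F

N, K, C = 4, 2, 3
policy = nn.Linear(K, C, bias=False)   # produces the multinomial probabilities
emb = nn.Embedding(C, 5)               # stand-in for the downstream part of the model
head = nn.Linear(5, 1)
optimizer = optim.SGD(
    list(policy.parameters()) + list(emb.parameters()) + list(head.parameters()),
    lr=0.1)

x = Variable(torch.randn(N, K))
y = Variable(torch.randn(N, 1))

probs = F.softmax(policy(x))
a = torch.multinomial(probs, 1)            # the stochastic node, shape (N, 1)
out = head(emb(a).squeeze(1))              # rest of the (ordinary) supervised model
loss = (out - y).pow(2).squeeze(1)         # loss vector over the batch, not summed

a.reinforce(-loss.data.unsqueeze(1))       # negative per-example loss as the reward
loss.sum().backward()                      # backward on the summed loss
optimizer.step()
```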

Doing this, the gradient over the parameters of the multinomial is None. Only if I call backward() also on the output of torch.multinomial() (after calling reinforce()) do I get a gradient. Is this the correct approach, or am I misunderstanding something?

That's correct. But personally, I just do the calculation by hand, since:

• in the next torch version, the API is very similar to doing it by hand, so migration will be easier (see the sketch at the end of this post)
• it makes the code easier for people who don't know PyTorch to read, too

Doing the calculation by hand, compared with reinforce(), looks something like the following. The idea is to build the REINFORCE surrogate loss yourself: take the log-probability of the action that was actually sampled, multiply it by the reward, negate, and sum over the batch; calling backward() on that gives the same gradient reinforce() would.

```python
import torch
from torch import autograd, nn, optim
from torch.autograd import Variable
import torch.nn.functional as F


def run_model(x, h1):
    torch.manual_seed(123)
    logits = h1(Variable(x))
    probs = F.softmax(logits)
    # a single sample per row from the multinomial; this is the stochastic node
    a = torch.multinomial(probs, 1)
    print('a', a)
    return logits, probs, a


def get_r():
    # fixed seed so both versions see the same fake per-example rewards
    torch.manual_seed(123)
    r = torch.rand(4, 1)
    return r


def run_by_hand(params, x, h1):
    print('')
    print('=======')
    print('by hand')
    logits, probs, a = run_model(x, h1)
    # probability of the action that was actually sampled, one per row
    g = torch.gather(probs, 1, Variable(a.data))
    log_g = g.log()

    r = get_r()
    # REINFORCE surrogate loss: -(reward * log prob of sampled action), summed
    r_loss = - (log_g * Variable(r)).sum()
    r_loss.backward()
    print('by-hand grad', params.grad)


def run_pytorch_reinforce(params, x, h1):
    print('')
    print('=======')
    print('pytorch reinforce')
    logits, probs, a = run_model(x, h1)
    r = get_r()
    a.reinforce(r)
    # backward() has to be called on the stochastic node itself as well,
    # otherwise the multinomial's parameters get no gradient
    a.backward()
    print('reinforce grad', params.grad)


def run():
    N = 4
    K = 1
    C = 3

    torch.manual_seed(123)
    x = torch.ones(N, K)
    h1 = nn.Linear(K, C, bias=False)
    params = list(h1.parameters())[0]

    run_by_hand(params, x, h1)
    # clear the accumulated gradient so the two runs can be compared
    h1.zero_grad()
    run_pytorch_reinforce(params, x, h1)


if __name__ == '__main__':
    run()
```
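
For reference, with the torch.distributions API (the "next torch version" style mentioned above, PyTorch 0.4+, no Variables needed), the by-hand version maps onto something like this; it's a sketch, and details may differ slightly between versions:

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(123)
N, K, C = 4, 1, 3
h1 = nn.Linear(K, C, bias=False)

x = torch.ones(N, K)
probs = F.softmax(h1(x), dim=1)
dist = Categorical(probs)
a = dist.sample()                       # sampled actions, shape (N,)

r = torch.rand(N)                       # per-example reward
r_loss = -(dist.log_prob(a) * r).sum()  # same surrogate loss as run_by_hand
r_loss.backward()
print('grad', h1.weight.grad)
```

The surrogate loss is the same -(log_prob * reward).sum() as in run_by_hand; Categorical.log_prob just replaces the gather/log step.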