Should I call backward() separately on each stochastic node?

I have a model which has a single sample from a multinomial pretty far upstream. Other than that it’s a standard supervised learning model, resulting in a single loss value at the end.

I’ve got my head around the reinforce function. I now do the following:

  • Compute a loss vector L over the batch (i.e. the per-example losses, not summed)
  • Pass the negative of this loss to reinforce() on the Variable returned by torch.multinomial().
  • Call backward() on the sum of the loss.
  • Call optimizer.step()

Doing this, the gradient for the parameters feeding the multinomial is None. Only if I also call backward() on the output of torch.multinomial() (after calling reinforce()) do I get a gradient. Is this the correct approach, or am I misunderstanding something?
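To make the steps concrete, here is roughly what I'm doing, as a toy sketch (h1, h2, x, and the squared-error loss are made-up placeholders, not my real model):

import torch
from torch import nn, optim, autograd
from torch.autograd import Variable
import torch.nn.functional as F

# toy stand-ins for my real model
h1 = nn.Linear(1, 3, bias=False)   # produces the multinomial logits (upstream)
h2 = nn.Linear(3, 1, bias=False)   # downstream part, consumes the sampled action
optimizer = optim.SGD(list(h1.parameters()) + list(h2.parameters()), lr=0.1)

x = Variable(torch.ones(4, 1))
probs = F.softmax(h1(x))
a = torch.multinomial(probs, 1)    # the single stochastic node, far upstream

# downstream supervised part; the sample enters as a detached one-hot,
# so no gradient flows back into h1 through this path
onehot = Variable(torch.zeros(4, 3).scatter_(1, a.data, 1.0))
pred = h2(onehot)
target = Variable(torch.zeros(4, 1))
loss_vector = (pred - target) ** 2          # per-example loss, not summed

optimizer.zero_grad()
a.reinforce(-loss_vector.data)              # negative per-example loss as the "reward"
loss_vector.sum().backward()                # h1.weight.grad is still None after this
autograd.backward([a], [None])              # only with this extra call does h1 get a gradient
optimizer.step()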

That's correct. But personally, I just do the calculation myself, since:

  1. then I don't have to think about this so much: it resembles REINFORCE more closely (see the estimator below)
  2. in the next torch version, the API is very similar to doing it by hand, so migration will be easier
  3. it makes it easier for non-PyTorch people to read, too
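For reference, the gradient both versions end up computing is the usual score-function / REINFORCE estimator:

$$
\nabla_\theta \, \mathbb{E}_{a \sim p_\theta}[r(a)]
  = \mathbb{E}_{a \sim p_\theta}\big[\, r(a)\, \nabla_\theta \log p_\theta(a) \,\big]
  \approx \frac{1}{N}\sum_{i=1}^{N} r_i \, \nabla_\theta \log p_\theta(a_i)
$$

so minimizing the surrogate loss $-\sum_i r_i \log p_\theta(a_i)$ (which is what r_loss in the code below computes, just without the 1/N) pushes the parameters in the same direction.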

Doing the calculation by hand, compared with reinforce(), looks something like this:

import torch
from torch import autograd, nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


def run_model(x, h1):
    # re-seed so both runs draw the same action from the multinomial
    torch.manual_seed(123)

    logits = h1(Variable(x))
    probs = F.softmax(logits)
    # sample one action index per row: this is the stochastic node
    a = torch.multinomial(probs, 1)
    print('a', a)
    return logits, probs, a


def get_r():
    # fixed fake "rewards"; the shared seed makes them identical for both runs
    torch.manual_seed(123)
    r = torch.rand(4, 1)
    return r


def run_by_hand(params, x, h1):
    print('')
    print('=======')
    print('by hand')
    h1.zero_grad()

    logits, probs, a = run_model(x, h1)
    # probability assigned to the action that was actually sampled, per row
    g = torch.gather(probs, 1, Variable(a.data))
    log_g = g.log()

    r = get_r()
    # REINFORCE surrogate loss: - sum_i r_i * log p(a_i)
    r_loss = - (log_g * Variable(r)).sum()
    r_loss.backward()
    print('params.grad', params.grad)


def run_pytorch_reinforce(params, x, h1):
    print('')
    print('=======')
    print('pytorch reinforce')
    h1.zero_grad()
    logits, probs, a = run_model(x, h1)
    r = get_r()
    # attach the reward to the stochastic node, then backprop through it
    a.reinforce(r)
    autograd.backward([a], [None])
    print('params.grad', params.grad)


def run():
    N = 4  # batch size
    K = 1  # input features
    C = 3  # number of categories in the multinomial

    torch.manual_seed(123)
    x = torch.ones(N, K)
    h1 = nn.Linear(K, C, bias=False)
    params = list(h1.parameters())[0]
    run_by_hand(params, x, h1)
    run_pytorch_reinforce(params, x, h1)


if __name__ == '__main__':
    run()
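If everything lines up, both functions should print the same params.grad (they draw the same action and the same r because of the shared seeds), which is a quick sanity check that the hand-rolled surrogate loss and reinforce() compute the same gradient.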