Bad behavior of `.multinomial()` function?

Hello,

I’m working on a model that uses dynamic masking to avoid resampling actions that have already been taken.
The mask function is:

def apply_mask(attentions, mask, prev_idxs):
    if mask is None:
        mask = torch.zeros(attentions.size()).byte().cuda()

    maskk = mask.clone()

    if prev_idxs is not None:
        # Mark the index sampled in the previous step for each row.
        for i, j in zip(range(attentions.size(0)), prev_idxs.data):
            maskk[i, j[0]] = 1

        # Masked logits become -inf so softmax assigns them zero probability.
        attentions[maskk] = -np.inf

    return attentions, maskk

When I apply `.multinomial()` to the probabilities, it occasionally samples actions that have zero probability.
For example, when I run the following function:

def count(n):
    k = 0
    for j in range(n):
        attentions = Variable(torch.Tensor(128, 50).uniform_(-10, 10).cuda())
        prev_actions = None
        mask = None
        actions = []
        for di in range(50):
            attentions, mask = apply_mask(attentions, mask, prev_actions)
            probs = F.softmax(attentions).cuda()
            prev_actions = probs.multinomial()
            # Compare the new indices against all previously sampled ones.
            for old_idxs in actions:
                if old_idxs.eq(prev_actions).data.any():
                    k += 1
                    print(' [!] resampling')

            actions.append(prev_actions)
    return k

I obtain a relative frequency of about 0.00043 for these bad samples over 100,000 runs.

Is there a problem with the `.multinomial()` function, or is there a better way to apply the mask?

Thank you in advance.

I ran your code, and I’m unable to replicate your results (nothing was resampled).

Could you please try some of the following? It would help with debugging:

  • building from source and testing again
  • running on CPU and seeing if the problem still persists
  • running the script below and seeing what happens
import torch

dist = torch.zeros(1000).cuda()
dist[0] = 1
x = torch.multinomial(dist, 1000, True)  # draw 1000 samples with replacement
x.sum()  # should equal 0, since only index 0 has nonzero probability

If it is of any help, I looked into it and isolated the issue a bit more. It happens consistently (on GPU) and is reproducible on two different machines, with PyTorch 0.3 and CUDA 9.0/9.1. I found that it also depends on the range of the logits: if I change logits_range to 1 or 100, it does not happen. By saving the random state, I can trigger the incorrect sampling immediately.

import torch
from torch.autograd import Variable

import torch.nn.functional as F
import numpy as np
from tqdm import tqdm


def test(n, hot=False, logits_range=10):
    torch.manual_seed(1234)

    logits = Variable(torch.Tensor(128, 50).uniform_(-logits_range, logits_range).cuda())

    # Randomly mask 40 elements per row; their logits become -inf, so their probability is 0
    mask = torch.zeros_like(logits).byte()
    _, idx_mask = Variable(torch.Tensor(128, 50).uniform_(0, 1).cuda()).topk(40, 1)
    mask.scatter_(1, idx_mask, True)

    logits[mask] = -np.inf

    probs = F.softmax(logits, dim=1)

    assert (probs[mask] == 0).all()
    assert (torch.abs(probs.sum(1) - 1) < 1e-6).all()

    if hot:
        with open('rng_state.pt', 'rb') as f:
            rng_state = torch.load(f)
        torch.cuda.set_rng_state(rng_state)

    for j in tqdm(range(n)):

        rng_state = torch.cuda.get_rng_state()

        sample = probs.multinomial(1).squeeze(-1)
        mask_sample = mask.gather(1, sample.unsqueeze(-1)).squeeze(-1)

        if mask_sample.any():
            print("Sampled value that was masked and had probability 0 in iteration {}".format(j))
            wrong = torch.nonzero(mask_sample).squeeze(-1)
            print("Wrong samples: indices {}, sampled {}, probs {}".format(
                wrong.data.cpu().numpy().tolist(),
                sample[wrong].data.cpu().numpy().tolist(),
                probs[wrong, sample[wrong]].data.cpu().numpy().tolist()
            ))

            if hot:
                break

            with open('rng_state.pt', 'wb') as f:
                torch.save(rng_state, f)


if __name__ == "__main__":
    with torch.cuda.device(0):
        test(100000, hot=False)

Thanks for the details, @wouter. There’s an issue open for this here; it might help to post there as well.

Thanks, I did not know of that issue!

I ran into the same issue today as well.
I see there was a fix. Is recompiling from source currently my only option to get that fix?

A fix was included in PyTorch 0.4, but unfortunately the problem is not fully resolved. See the discussion here:
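
In the meantime, one defensive workaround is to verify the drawn indices against the probabilities and resample any row that picked a zero-probability entry. Below is a minimal sketch under the assumption that `probs` is a `[batch, n]` tensor of row-normalized probabilities; `safe_multinomial` and `max_retries` are made-up names, not part of the PyTorch API.

import torch

def safe_multinomial(probs, max_retries=10):
    # Draw one index per row, then verify that no sampled index had zero probability.
    sample = probs.multinomial(1)                  # [batch, 1]
    for _ in range(max_retries):
        sampled_prob = probs.gather(1, sample)     # probability of each sampled index
        bad = (sampled_prob == 0).squeeze(1)       # rows that drew a zero-probability index
        if not bad.any():
            break
        # Resample only the offending rows.
        sample[bad] = probs[bad].multinomial(1)
    return sample.squeeze(1)

If the CPU path turns out not to be affected (see the earlier suggestion to test on CPU), moving just the sampling to CPU with `probs.cpu()` is another possible fallback, at the cost of a device transfer.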