I’m working on a model that uses dynamic masking to avoid resampling actions that have already been taken.
The mask function is:
def apply_mask(attentions, mask, prev_idxs):
    """Mask already-chosen actions by writing ``-inf`` into their scores.

    Args:
        attentions: 2-D scores indexed as ``attentions[row, action]``. It is
            modified in place and also returned.
        mask: Byte mask with the same shape as ``attentions`` marking actions
            already taken, or ``None`` to start with nothing masked. It is not
            modified; an updated copy is returned.
        prev_idxs: Integer indices of the actions chosen at the previous step,
            one per row. Shape ``(rows, k)`` (only column 0 is used) or
            ``(rows,)``. Pass ``None`` when there is no previous step.

    Returns:
        ``(attentions, new_mask)``. Every position set in ``new_mask`` holds
        ``-inf`` in ``attentions``, so a subsequent softmax assigns it 0.
    """
    if mask is None:
        # Allocate on the same device as ``attentions`` instead of forcing the
        # current CUDA device, so CPU tensors and non-default GPUs both work.
        mask = torch.zeros_like(attentions.data).byte()
    maskk = mask.clone()
    if prev_idxs is not None:
        idx = prev_idxs.data.long()
        if idx.dim() == 1:
            idx = idx.unsqueeze(1)
        # Vectorised ``maskk[i, idx[i, 0]] = 1`` for each row i. Like the
        # original row-by-row loop, it uses only column 0 and at most
        # ``attentions.size(0)`` rows.
        idx = idx[:attentions.size(0), :1]
        maskk.scatter_(1, idx, 1)
    attentions[maskk] = -np.inf
    return attentions, maskk
When I apply `.multinomial()` to the probabilities, it occasionally samples actions that have zero probability.
For example, when I run the following function:
def count(n, batch_size=128, n_actions=50, logits_range=10):
    """Count how often ``multinomial`` re-samples an action that was masked out.

    For each of ``n`` trials, draws logits of shape ``(batch_size, n_actions)``
    uniformly from ``[-logits_range, logits_range]`` on the current CUDA device.
    It then takes ``n_actions`` sequential samples, using ``apply_mask`` to mask
    each previous sample before the next softmax. A new sample equal to an
    earlier one from the same trial means a zero-probability action was drawn.

    Args:
        n: Number of independent trials.
        batch_size: Rows per trial (default 128).
        n_actions: Number of actions per row. It is also the number of
            sampling steps per trial (default 50).
        logits_range: Half-width of the uniform logits range (default 10).

    Returns:
        The number of (earlier step, new step) pairs in which at least one row
        re-sampled the earlier action, summed over all trials.
    """
    k = 0
    for _ in range(n):
        attentions = Variable(
            torch.Tensor(batch_size, n_actions)
            .uniform_(-logits_range, logits_range)
            .cuda()
        )
        prev_actions = None
        mask = None
        actions = []
        # One step per action. The step count is tied to the width because an
        # extra step would mask every column and the softmax would yield NaNs.
        for _step in range(n_actions):
            attentions, mask = apply_mask(attentions, mask, prev_actions)
            # Explicit dim=1 normalises over actions. The implicit-dim form is
            # deprecated. ``attentions`` is already on the GPU, so no ``.cuda()``.
            probs = F.softmax(attentions, dim=1)
            # Explicit num_samples: newer PyTorch requires it; 1 was the old default.
            prev_actions = probs.multinomial(1)
            for old_idxs in actions:
                # compare new idxs against every earlier step in this trial
                if old_idxs.eq(prev_actions).data.any():
                    k += 1
                    print(' [!] resampling')
            actions.append(prev_actions)
    return k
Over 100,000 runs, I get a relative frequency of 0.00043 for these bad samples.
Is there a problem with the `.multinomial()` function, or is there a better way to apply the mask?
If it helps, I looked into it further and isolated the issue a bit more. It happens consistently (on GPU) and is reproducible on two different machines, with PyTorch 0.3 and CUDA 9.0/9.1. It also depends on the range of the logits: if I change logits_range to 1 or 100, it does not happen. By saving the random state, I can trigger the incorrect sampling immediately.
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
def test(n, hot=False, logits_range=10):
    """Search for ``multinomial`` draws that land on zero-probability entries.

    Builds a seeded 128x50 batch of logits on the current CUDA device. It sets
    40 randomly chosen entries per row to ``-inf``, so their softmax probability
    is exactly 0, and then calls ``probs.multinomial(1)`` ``n`` times. Every
    draw that selects a masked entry is reported.

    Args:
        n: Number of sampling iterations.
        hot: If True, restore the CUDA RNG state from ``rng_state.pt`` before
            sampling and stop at the first bad draw (replay mode).
        logits_range: Logits are drawn uniformly from
            ``[-logits_range, logits_range]``.

    Side effects:
        Reads ``rng_state.pt`` when ``hot`` is True. Always writes
        ``rng_state.pt`` with the CUDA RNG state captured just before the first
        bad draw, or, if no bad draw was seen, the state before the last
        iteration (the current state when ``n == 0``).
    """
    torch.manual_seed(1234)
    logits = Variable(torch.Tensor(128, 50).uniform_(-logits_range, logits_range).cuda())
    # Randomly pick 40 entries per row and give them probability 0 (logit -inf).
    mask = torch.zeros_like(logits).byte()
    _, idx_mask = Variable(torch.Tensor(128, 50).uniform_(0, 1).cuda()).topk(40, 1)
    mask.scatter_(1, idx_mask, True)
    logits[mask] = -np.inf
    probs = F.softmax(logits, dim=1)
    # Sanity checks: masked entries are exactly 0 and every row sums to 1.
    assert (probs[mask] == 0).all()
    assert (torch.abs(probs.sum(1) - 1) < 1e-6).all()
    if hot:
        # NOTE: torch.load unpickles; only load a file this script wrote itself.
        with open('rng_state.pt', 'rb') as f:
            rng_state = torch.load(f)
        torch.cuda.set_rng_state(rng_state)
    # Initialise so the save below works even when n == 0.
    rng_state = torch.cuda.get_rng_state()
    # State that reproduces the first bad draw. In non-hot mode the loop keeps
    # going after a failure, so the last iteration's state would not replay it.
    failing_state = None
    for j in tqdm(range(n)):
        rng_state = torch.cuda.get_rng_state()
        sample = probs.multinomial(1).squeeze(-1)
        mask_sample = mask.gather(1, sample.unsqueeze(-1)).squeeze(-1)
        if mask_sample.any():
            if failing_state is None:
                failing_state = rng_state
            print("Sampled value that was masked and had probability 0 in iteration {}".format(j))
            wrong = torch.nonzero(mask_sample).squeeze(-1)
            print("Wrong samples: indices {}, sampled {}, probs {}".format(
                wrong.data.cpu().numpy().tolist(),
                sample[wrong].data.cpu().numpy().tolist(),
                probs[wrong, sample[wrong]].data.cpu().numpy().tolist()
            ))
            if hot:
                break
    with open('rng_state.pt', 'wb') as f:
        torch.save(rng_state if failing_state is None else failing_state, f)
if __name__ == "__main__":
    # Script entry: run the full (non-replay) scan with all CUDA work pinned
    # to the first GPU.
    gpu_index = 0
    with torch.cuda.device(gpu_index):
        test(n=100000, hot=False)