In an attempt to implement the log_prob() method for an Empirical distribution, I need to match the entries of a big tensor against a set of samples. My naive implementation is terribly slow:
samples = torch.randn(10)
value = torch.randn(10000)
num_matches = torch.zeros_like(value)
for i, v in enumerate(value):
    num_matches[i] = torch.isclose(samples, v).sum()
Does anybody have a suggestion for how I could speed it up?
You want to count how many samples each value is close to?
You can do
expanded_values = value.unsqueeze(1).expand(10000, 10)
expanded_samples = samples.unsqueeze(0).expand(10000, 10)
num_matches = torch.isclose(expanded_values, expanded_samples).sum(-1)
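As a side note, torch.isclose follows the usual broadcasting rules, so the explicit expand calls can be left out entirely; a minimal sketch of the same counting with implicit broadcasting:

```python
import torch

torch.manual_seed(0)
samples = torch.randn(10)
value = torch.randn(10000)

# value.unsqueeze(1) has shape (10000, 1); broadcasting against
# samples of shape (10,) yields a (10000, 10) boolean tensor.
num_matches = torch.isclose(value.unsqueeze(1), samples).sum(-1)
```

The result is identical to the expanded version, since expand itself only materializes the broadcast view.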
Perfect, thank you! I didn’t realize that torch.Tensor.expand creates a view on the existing tensor without allocating new memory. Very neat!
It does!
The result of isclose will be quite large in this case, though. But if it fits in memory, that is definitely going to be the fastest approach.
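If the (10000, 10) boolean intermediate ever becomes too large to hold in memory, one way out is to process value in chunks, trading a short Python loop for bounded peak memory. The helper name and chunk size below are hypothetical, not part of the original answer:

```python
import torch

def count_matches_chunked(value, samples, chunk_size=1024):
    """Count, for each entry of `value`, how many `samples` it is close to,
    without materializing the full (len(value), len(samples)) intermediate."""
    out = torch.empty(value.numel(), dtype=torch.long)
    for start in range(0, value.numel(), chunk_size):
        chunk = value[start:start + chunk_size]
        # broadcast (chunk_size, 1) against (len(samples),)
        out[start:start + chunk_size] = torch.isclose(
            chunk.unsqueeze(1), samples).sum(-1)
    return out
```

With a chunk size in the thousands, the per-iteration overhead is negligible while peak memory stays at chunk_size * len(samples) booleans.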