How to put tensors in a set?

I would like to use a Python set to check whether I have seen a given tensor before, as a termination condition. Of course, this doesn’t work out of the box, because tensors are only equal at that level if they are the same object. Will I have to write my own code to convert tensors into something I can put in a set? I get the feeling that moving everything to the CPU, for example as a tuple, is not the nicest way to do this.

In general this is tricky for a few reasons, one being what you mean by equality. Do you mean exact bitwise equality of the underlying data, or some kind of fuzzy floating-point “equality” like what torch.allclose checks?
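
To make the distinction concrete, here is a small sketch contrasting the two notions. It assumes torch.equal is an acceptable stand-in for "exact" equality (same shape and same element values) and uses torch.allclose with its default tolerances; the tensors are made up for illustration.

```python
import torch

# Integer tensors: exact comparison is the natural choice.
p = torch.tensor([0, 1, 2])
q = p.clone()
exact_int = torch.equal(p, q)          # same shape and same values

# Float tensors that differ by a tiny amount.
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
b = a + 1e-9
exact_float = torch.equal(a, b)        # exact comparison sees the difference
fuzzy_float = torch.allclose(a, b)     # within the default rtol/atol

# torch.equal also returns False (instead of raising) on a shape mismatch.
shape_mismatch = torch.equal(p, p[:2])

print(exact_int, exact_float, fuzzy_float, shape_mismatch)
```

Only the exact notion fits a hash-based set cleanly, since two tensors that are merely "close" would generally not produce the same hash.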

Thank you for such a quick reply. I am using integers (the problem is checking if an RL policy has been seen before, in case you are familiar with this). So exact bitwise equality is indeed what I am looking for.

In that case, does something like a simple wrapper class work for your use case?

import torch

class HashTensorWrapper():
    def __init__(self, tensor):
        self.tensor = tensor

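    def __hash__(self):
        # Hash the raw bytes; .cpu() is a no-op if the tensor is already on the CPU.
        return hash(self.tensor.cpu().numpy().tobytes())

    def __eq__(self, other):
        # torch.equal returns a plain bool and is False when shapes differ.
        return torch.equal(self.tensor, other.tensor)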

a = torch.randn(1000)
b = a.clone()
print(hash(a) == hash(b))
a_wrap = HashTensorWrapper(a)
b_wrap = HashTensorWrapper(b)
print(hash(a_wrap) == hash(b_wrap))

unwrapped_set = set()
unwrapped_set.add(a)
unwrapped_set.add(b)
wrapped_set = set()
wrapped_set.add(a_wrap)
wrapped_set.add(b_wrap)
print(len(unwrapped_set), len(wrapped_set))
$ python3 hash.py
False
True
2 1

Yes, this would work. I will time it against the naive solution (move the PyTorch tensor to the CPU, then convert it to a NumPy array and then to a tuple). But I think we still have the same problem: the bytes have to be copied to the CPU. To get around this, we would somehow need to implement a hash set on the GPU. I doubt this has been done, though, haha. But this also raises an interesting idea: maybe searching a hash table is (much?) faster on the GPU.
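
For reference, a minimal sketch of a naive CPU-side approach used as a termination condition. The helper name tensor_key and the toy update rule are hypothetical, chosen only for illustration; the key uses raw bytes rather than a tuple (tuple(t.flatten().tolist()) would be the tuple equivalent), and includes shape and dtype so that tensors with identical bytes but different layouts do not compare equal.

```python
import torch

def tensor_key(t: torch.Tensor):
    """Hypothetical helper: build a hashable key from the exact tensor contents."""
    t = t.detach().cpu().contiguous()
    # Shape and dtype are part of the key because the bytes alone are ambiguous.
    return (tuple(t.shape), str(t.dtype), t.numpy().tobytes())

seen = set()
policy = torch.tensor([0, 1, 2, 1])   # toy integer "policy"
steps = 0

# Stop as soon as a previously seen policy comes back.
while tensor_key(policy) not in seen:
    seen.add(tensor_key(policy))
    policy = (policy + 1) % 3         # stand-in for a real policy update
    steps += 1

print(steps, len(seen))
```

Every lookup here copies the whole tensor to the host, which is exactly the cost in question.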

Actually, I found that TensorFlow has a hash table implementation: tf.lookup.experimental.DenseHashTable (TensorFlow Core v2.5.0). Is there any sort of feature request system, so that someday this can exist for PyTorch as well?

You could consider opening an issue on GitHub for this. I am curious whether there is a workaround, such as computing a cruder hash with operations that run natively on the GPU, so that only a very small amount of data has to be copied to the CPU for hashing:

import time
import torch

class HashTensorWrapper():
    def __init__(self, tensor):
        self.tensor = tensor

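    def __hash__(self):
        # Copies the whole tensor to the CPU before hashing its bytes.
        return hash(self.tensor.cpu().numpy().tobytes())

    def __eq__(self, other):
        # torch.equal returns a plain bool and is False when shapes differ.
        return torch.equal(self.tensor, other.tensor)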

class HashTensorWrapper2():
    def __init__(self, tensor):
        self.tensor = tensor
        self.hashcrap = torch.arange(self.tensor.numel(), device=self.tensor.device).reshape(self.tensor.size())

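    def __hash__(self):
        if self.hashcrap.size() != self.tensor.size():
            self.hashcrap = torch.arange(self.tensor.numel(), device=self.tensor.device).reshape(self.tensor.size())
        # Reduce on the tensor's device, then copy only the scalar result to the host.
        # .item() is needed: hash() of a tensor is based on object identity, not its value.
        return hash(torch.sum(self.tensor*self.hashcrap).item())

    def __eq__(self, other):
        # torch.equal returns a plain bool and is False when shapes differ.
        return torch.equal(self.tensor, other.tensor)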


a = torch.randn(1000,1000).cuda()
b = a.clone()
print(hash(a) == hash(b))
a_wrap = HashTensorWrapper(a)
b_wrap = HashTensorWrapper(b)
a_wrap2 = HashTensorWrapper2(a)
b_wrap2 = HashTensorWrapper2(b)
print(hash(a_wrap2) == hash(b_wrap2))

unwrapped_set = set()
unwrapped_set.add(a)
unwrapped_set.add(b)
wrapped_set = set()
wrapped_set.add(a_wrap2)
wrapped_set.add(b_wrap2)
print(len(unwrapped_set), len(wrapped_set))
# Synchronize around each timed loop so that pending GPU work is counted correctly.
torch.cuda.synchronize()
t1 = time.perf_counter()
for i in range(10):
    hash(a_wrap)
torch.cuda.synchronize()
t2 = time.perf_counter()
t3 = time.perf_counter()
for i in range(10):
    hash(a_wrap2)
torch.cuda.synchronize()
t4 = time.perf_counter()
print(t2-t1, t4-t3)
# python hash.py
False
True
2 1
0.027219772338867188 0.0004017353057861328
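
One caveat with a crude hash like this is that different tensors can share a hash value. A small sketch of why that is still fine for set membership, assuming a position-weighted sum as the hash and exact comparison in __eq__ (the class name WeightedSumKey is hypothetical, and the example falls back to the CPU when CUDA is unavailable):

```python
import torch

class WeightedSumKey:
    """Hypothetical wrapper: cheap position-weighted-sum hash, exact equality."""

    def __init__(self, tensor):
        self.tensor = tensor

    def __hash__(self):
        weights = torch.arange(self.tensor.numel(), device=self.tensor.device)
        # Reduce on the tensor's device; .item() copies a single scalar to the host.
        return hash((self.tensor.flatten() * weights).sum().item())

    def __eq__(self, other):
        if not isinstance(other, WeightedSumKey):
            return NotImplemented
        return torch.equal(self.tensor, other.tensor)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([0, 2, 0], device=device)  # 0*0 + 2*1 + 0*2 = 2
y = torch.tensor([0, 0, 1], device=device)  # 0*0 + 0*1 + 1*2 = 2
z = x.clone()                               # same contents as x

same_hash = hash(WeightedSumKey(x)) == hash(WeightedSumKey(y))
seen = {WeightedSumKey(x), WeightedSumKey(y), WeightedSumKey(z)}
print(same_hash, len(seen))
```

Here x and y collide on the hash but are kept apart by __eq__, while z is recognized as a duplicate of x. Collisions cost extra comparisons rather than correctness, so a weak hash mainly hurts lookup speed. Note that the first element gets weight 0 in this scheme and therefore never influences the hash.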

Neat trick! I will probably use that, and also open an issue to request the feature. Thank you very much for your help and implementations.
