Is there a GPU-equivalent hash implementation for the Tensor wrapper below? I want to store unique multi-dimensional tensors in a set.
"""Wrapper to store tensors in a set."""
import torch
class HashTensor:
    """Hashable wrapper so that tensors can be stored in sets or used as dict keys.

    Two wrappers compare equal when their tensors have the same dtype, the
    same shape and element-wise equal values. The hash is built from the
    shape, the dtype and the raw element bytes, so wrappers that compare
    equal hash alike.

    Caveats:
        * The tensor is wrapped, not copied. Mutating it after the wrapper
          has been inserted into a set/dict changes its hash.
        * ``0.0`` and ``-0.0`` compare equal element-wise but have different
          byte patterns, so such tensors may hash differently.
    """

    def __init__(self, obj):
        """Store *obj*, the ``torch.Tensor`` to wrap."""
        self.obj = obj

    def __hash__(self):
        """Return a hash of the tensor's shape, dtype and raw element bytes."""
        # detach(): numpy() refuses tensors that require grad.
        # cpu(): numpy() needs host memory.
        host = self.obj.detach().cpu()
        # Shape and dtype are part of the key so that tensors with identical
        # bytes but a different layout or element type do not collide.
        return hash((tuple(host.shape), str(host.dtype), host.numpy().tobytes()))

    def __eq__(self, other):
        """Compare against another ``HashTensor`` or a plain ``torch.Tensor``.

        Returns False for any other type, and for tensors whose dtype or
        shape differs from the wrapped tensor.
        """
        if isinstance(other, HashTensor):
            other = other.obj
        if not isinstance(other, torch.Tensor):
            return False

        mine = self.obj
        # Must agree with __hash__, which distinguishes dtype and shape.
        if mine.dtype != other.dtype or mine.shape != other.shape:
            return False
        # __hash__ ignores the device, so equality must too; torch.equal
        # raises when its arguments live on different devices.
        if other.device != mine.device:
            other = other.to(mine.device)
        return torch.equal(mine, other)

    def __repr__(self):
        """Return the wrapped tensor's own representation."""
        return repr(self.obj)