Let's say I have a tensor A of shape (1024) and a tensor B of shape (5, 20, 1024), and B contains A as one of its length-1024 vectors.
How can I find the index of A in B?
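This should work. Compare A against every length-1024 vector in B using broadcasting, then keep the positions where all elements match: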
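import torch

a = torch.zeros(1024)
b = torch.randn(5, 20, 1024)
b[2, 17].fill_(0.)  # plant a copy of a at index (2, 17)
# a broadcasts over b's last dim; all(dim=2) is True where every element matches
idx = (a == b).all(dim=2).nonzero()  # one row per match, here [[2, 17]]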
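If B's values come out of a computation rather than being copied directly from A, an exact == test can miss the match because of floating-point rounding. Here is a minimal sketch, assuming a tolerance-based match is acceptable, that swaps == for torch.isclose. The helper name find_vector and the atol default are hypothetical choices for illustration, not part of any PyTorch API.

```python
import torch

def find_vector(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-6) -> list:
    """Hypothetical helper: return every (i, j) where b[i, j] matches a within atol.

    Assumes a has shape (D,) and b has shape (N, M, D).
    """
    # expand_as views a as (N, M, D) without copying memory
    close = torch.isclose(b, a.expand_as(b), rtol=0.0, atol=atol)
    # a position counts as a match only if all D elements are close
    mask = close.all(dim=-1)  # shape (N, M)
    rows, cols = torch.nonzero(mask, as_tuple=True)
    return list(zip(rows.tolist(), cols.tolist()))

torch.manual_seed(0)
a = torch.randn(1024)
b = torch.randn(5, 20, 1024)
b[2, 17] = a + 1e-7  # near-copy: small rounding differences from a
print(find_vector(a, b))
```

Setting rtol=0.0 makes the test a pure absolute tolerance; if your values span very different magnitudes, a nonzero rtol may suit better. The result is a list of (row, col) pairs, so an empty list means no match was found.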