Let’s say I have a tensor A of shape (1024,) and a tensor B of shape (5, 20, 1024), and B contains A as one of its 1024-element slices.

How can I find the index of tensor A in tensor B?

This should work:

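```
import torch

a = torch.zeros(1024)
b = torch.randn(5, 20, 1024)
b[2, 17].fill_(0.)  # plant a copy of a at position [2, 17]
# a broadcasts against b's last dim; keep positions where all 1024 elements match
idx = (a == b).all(dim=2).nonzero()  # shape (num_matches, 2), here [[2, 17]]
```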
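The `==` comparison only finds A if the copy inside B is bit-for-bit identical. If the slice in B was produced by some floating-point arithmetic, a tolerance-based check with `torch.isclose` may be more robust. Below is a minimal sketch, assuming A is 1-D and its length equals B's last dimension. `find_subtensor` is a hypothetical helper name, not a PyTorch function, and the default `rtol`/`atol` simply mirror `torch.isclose`'s defaults.

```python
import torch


def find_subtensor(a: torch.Tensor, b: torch.Tensor,
                   rtol: float = 1e-5, atol: float = 1e-8) -> list[tuple[int, ...]]:
    """Hypothetical helper: return the leading indices of every slice of b
    (taken along the last dim) that is approximately equal to the 1-D tensor a."""
    if a.dim() != 1 or b.shape[-1] != a.shape[0]:
        raise ValueError("expected a 1-D tensor a whose length equals b.shape[-1]")
    # Element-wise approximate comparison; expand_as broadcasts a without copying
    close = torch.isclose(b, a.expand_as(b), rtol=rtol, atol=atol)
    # A slice matches only if every element along the last dim is close
    matches = close.all(dim=-1)
    # nonzero() gives an (N, b.dim() - 1) index tensor in row-major order
    return [tuple(row.tolist()) for row in matches.nonzero()]


# Usage: plant a slightly perturbed copy of a, which exact == would miss
torch.manual_seed(0)
a = torch.randn(1024)
b = torch.randn(5, 20, 1024)
b[2, 17] = a * (1 + 1e-6)
print(find_subtensor(a, b))
```

Because the helper reduces over `dim=-1` instead of `dim=2`, it also works unchanged if B has a different number of leading dimensions.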