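Hi all, I am curious why the following code returns `False`. Selecting a subset of rows (`a[idx, :]`) before applying the linear layers seems to change the layers' output compared with applying them to the full batch and then selecting the same rows.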
import torch

# Question: does running this Linear/SiLU gated block on a subset of rows give
# the same numbers as running it on the full batch and then selecting rows?
#
# Mathematically yes; bitwise, not necessarily. float32 addition is not
# associative, and the matmul backend may choose a different kernel/blocking
# strategy for 50 rows than for 2 rows, which changes the order in which the
# 4096 products behind each output element are accumulated. PyTorch does not
# guarantee batched vs. sliced computations to be bitwise identical, so the
# meaningful check is a tolerance comparison (torch.allclose), not torch.equal.
torch.manual_seed(1234)
a = torch.randn((50, 4096), dtype=torch.float32)
idx = [0, 2]
# List indexing (advanced indexing) copies rows 0 and 2 into a new (2, 4096) tensor.
b = a[idx, :]

w1 = torch.nn.Linear(4096, 4096, bias=False)
w2 = torch.nn.Linear(4096, 4096, bias=False)
w3 = torch.nn.Linear(4096, 4096, bias=False)
act = torch.nn.SiLU()

# Inference-only experiment: skip autograd bookkeeping for the large activations.
with torch.no_grad():
    out_a = w3(act(w1(a)) * w2(a))
    out_b = w3(act(w1(b)) * w2(b))

selected = out_a[idx, :]
# Bitwise comparison: may legitimately be False because of accumulation order.
exact_match = torch.equal(selected, out_b)
# Tolerance comparison sized for float32 dot products of length 4096.
close_match = torch.allclose(selected, out_b, rtol=1e-4, atol=1e-5)
max_abs_diff = (selected - out_b).abs().max().item()
print(f"exact: {exact_match}, allclose: {close_match}, max |diff|: {max_abs_diff:.3e}")