Why does slicing affect the nn.Linear output?

Hi all, I am curious why the following code returns False. Slicing seems to affect the linear layer's output.

import torch

torch.manual_seed(1234)
a = torch.randn((50, 4096)).float()
idx = [0, 2]
b = a[idx,:]
w1 = torch.nn.Linear(4096, 4096, bias=False)
w2 = torch.nn.Linear(4096, 4096, bias=False)
w3 = torch.nn.Linear(4096, 4096, bias=False)

act = torch.nn.SiLU()
out_a = w3(act(w1(a)) * w2(a))
out_b = w3(act(w1(b)) * w2(b))
print(torch.equal(out_a[idx,:], out_b))

Changing the input shape, stride, or memory layout can produce numerically different results due to the limited floating-point precision and a different order of operations. In addition, different mathematical algorithms (e.g., different matmul kernels) can be selected internally based on this tensor metadata, so a bitwise comparison such as torch.equal is expected to fail even though the results agree within floating-point tolerance.
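The order-of-operations effect can be shown without any tensors at all: floating-point addition is not associative, so summing the same values in a different order can change the result. A minimal sketch in plain Python:

```python
# Floating-point addition is not associative: the same three values
# summed in a different order give different results, because 1.0 is
# smaller than the rounding unit (ulp) of 1e16 and gets absorbed.
s1 = (1e16 + 1.0) - 1e16   # 1.0 is lost in the first addition -> 0.0
s2 = (1e16 - 1e16) + 1.0   # the large terms cancel first -> 1.0

print(s1, s2)       # 0.0 1.0
print(s1 == s2)     # False
```

For the same reason, comparing the two outputs in your snippet with torch.allclose(out_a[idx, :], out_b) instead of torch.equal should report that they match up to the expected floating-point tolerance.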