Hi all,
I’m trying to write a function that computes the fast Walsh–Hadamard transform using ATen, and at one point I need a few lines that use advanced indexing. The only reference to advanced indexing in ATen I’ve found is this GitHub issue: https://github.com/zdevito/ATen/issues/78
I have no experience with C++, and https://pytorch.org/cppdocs/ has not been of much help so far.
My Python code looks like this:
# First stage; important: new tensors have to be created on the same device as x
temp = torch.zeros((N_samples, N // 2, 2), device=x.device)
temp[:, :, 0] = x[:, 0::2] + x[:, 1::2]  # sums of adjacent column pairs
temp[:, :, 1] = x[:, 0::2] - x[:, 1::2]  # differences of adjacent column pairs
res = temp.clone()  # clone() instead of torch.tensor(temp), which warns on tensor input
# Second and subsequent stages (G starts at N // 2, M at 2)
for nStage in range(2, int(log(N, 2)) + 1):
    temp = torch.zeros((N_samples, G // 2, M * 2), device=x.device)
    temp[:, 0:G // 2, 0:M * 2:4] = res[:, 0:G:2, 0:M:2] + res[:, 1:G:2, 0:M:2]
    temp[:, 0:G // 2, 1:M * 2:4] = res[:, 0:G:2, 0:M:2] - res[:, 1:G:2, 0:M:2]
    temp[:, 0:G // 2, 2:M * 2:4] = res[:, 0:G:2, 1:M:2] - res[:, 1:G:2, 1:M:2]
    temp[:, 0:G // 2, 3:M * 2:4] = res[:, 0:G:2, 1:M:2] + res[:, 1:G:2, 1:M:2]
    res = temp.clone()  # again, clone() rather than torch.tensor(temp, ...)
    G = G // 2
    M = M * 2
res = temp[:, 0, :]
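For completeness, here is the same code wrapped into a runnable function, with the `G` and `M` initialization I left out above (they start at `N // 2` and `2`, which follows from the tensor shapes) and `temp.clone()` instead of `torch.tensor(temp)`. As far as I can tell it produces the sequency-ordered transform:

```python
import torch
from math import log

def fwht(x):
    """Fast Walsh-Hadamard transform along dim 1 of x, shape (N_samples, N),
    with N a power of two. Appears to give sequency (Walsh) ordering."""
    N_samples, N = x.shape
    # First stage: sums and differences of adjacent column pairs
    temp = torch.zeros((N_samples, N // 2, 2), device=x.device)
    temp[:, :, 0] = x[:, 0::2] + x[:, 1::2]
    temp[:, :, 1] = x[:, 0::2] - x[:, 1::2]
    res = temp.clone()
    G = N // 2  # number of groups
    M = 2       # elements per group
    # Second and subsequent stages
    for nStage in range(2, int(log(N, 2)) + 1):
        temp = torch.zeros((N_samples, G // 2, M * 2), device=x.device)
        temp[:, 0:G // 2, 0:M * 2:4] = res[:, 0:G:2, 0:M:2] + res[:, 1:G:2, 0:M:2]
        temp[:, 0:G // 2, 1:M * 2:4] = res[:, 0:G:2, 0:M:2] - res[:, 1:G:2, 0:M:2]
        temp[:, 0:G // 2, 2:M * 2:4] = res[:, 0:G:2, 1:M:2] - res[:, 1:G:2, 1:M:2]
        temp[:, 0:G // 2, 3:M * 2:4] = res[:, 0:G:2, 1:M:2] + res[:, 1:G:2, 1:M:2]
        res = temp.clone()
        G = G // 2
        M = M * 2
    return temp[:, 0, :]

# fwht(torch.tensor([[1., 2., 3., 4.]])) -> tensor([[10., -4., 0., -2.]])
```

Applying it twice gives back `N * x`, which is a handy sanity check since the Walsh matrix is symmetric and orthogonal up to a factor of `N`.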
How do I handle this kind of indexing in ATen?
Thanks in advance for your help