# How to vectorize this pytorch code over (at least) the batch dimension?

```python
import torch

A = torch.zeros((3, 3, 3), dtype=torch.float)
X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
for a, x in zip(A, X):
    for i, j in zip(x, x[1:]):
        a[i, j] = 1
```

Thanks!

Hi,

When you say the batch dimension, do you mean the first dimension in your example, or another dimension to add to your example?
Also, is this the formula you want?
`for every i in range(A.size(0)), for every j in range(X.size(1) - 1): A[i, X[i, j], X[i, j+1]] = 1`
That is, every pair of consecutive values in a row of X gives the indices where you should put a 1?
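The formula above maps almost one-to-one onto PyTorch advanced indexing. The sketch below is an alternative to the `scatter_` approach posted later in the thread, not that approach itself: a batch index of shape `(B, 1)` broadcasts against the two shifted slices of `X`, and a single assignment writes every 1.

```python
import torch

X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
A = torch.zeros((3, 3, 3), dtype=torch.float)

# Batch index of shape (B, 1); it broadcasts against the (B, L - 1) index tensors below
batch = torch.arange(X.size(0)).unsqueeze(1)
src = X[:, :-1]  # X[i, j]
dst = X[:, 1:]   # X[i, j + 1]

# A[i, X[i, j], X[i, j + 1]] = 1 for every i and j; repeated pairs just write 1 again
A[batch, src, dst] = 1
print(A)
```

Because every write stores the same value, duplicate pairs (such as `(0, 1)` appearing twice in the first row) are harmless here.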


Hi! That’s right! The batch dimension is the first.

Just to be sure before writing code for this (so that I don't write code that doesn't do what you want): what is the rationale for using consecutive numbers in x as coordinates? Why is each number used first as an index on the 2nd dimension and then as an index on the 1st dimension?
It would feel more natural to have half of them be indices in the 1st dimension and the other half be indices in the 2nd dimension, or at least to use each one only once, not twice.


Consider a single row of X: `[0, 1, 2, 0, 1, 0]`. What I want is to build an adjacency matrix such that:

- `A[0, 1] = 1`
- `A[1, 2] = 1`
- `A[2, 0] = 1`
- `A[0, 1] = 1`
- `A[1, 0] = 1`
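For one sequence, the pairs listed above are each element paired with its successor, so stacking the two shifted slices yields the same edge list as a tensor. A minimal sketch for the single row from the example, assuming 3 nodes:

```python
import torch

x = torch.tensor([0, 1, 2, 0, 1, 0])

# Each consecutive pair (x[k], x[k + 1]) is one edge of the adjacency matrix
edges = torch.stack([x[:-1], x[1:]], dim=1)
print(edges.tolist())  # [[0, 1], [1, 2], [2, 0], [0, 1], [1, 0]]

# Mark every edge in a 3 x 3 adjacency matrix
a = torch.zeros((3, 3), dtype=torch.float)
a[edges[:, 0], edges[:, 1]] = 1
print(a.tolist())  # [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
```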

The code I shared works, but it's very slow because I'm looping over each element in the batch. I'd like to vectorize it as much as possible. Thanks!

There you go, this will be much faster. Note that I use `.narrow(-1, 0, x_size-1)` out of habit, but `[:, :-1]` works as well if you prefer that notation.

```python
import torch

A = torch.zeros((3, 3, 3), dtype=torch.float)
X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
# Original loop version, for comparison
for a, x in zip(A, X):
    for i, j in zip(x, x[1:]):
        a[i, j] = 1

print(A)

A = torch.zeros((3, 3, 3), dtype=torch.float)
# This code assumes A is contiguous! If it is not, add
# A = A.contiguous()
# For indexing, collapse the last two dimensions of A
A_view = A.view(A.size(0), -1)
# Compute the flat index of (i, j) in each row of A_view: i * A.stride(1) + j * A.stride(2)
x_size = X.size(-1)
indices = X.narrow(-1, 0, x_size - 1) * A.stride(1) + X.narrow(-1, 1, x_size - 1) * A.stride(2)
# Put 1s at the computed indices (A_view shares memory with A, so A is updated too)
A_view.scatter_(1, indices, 1)

print(A)
```
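The index arithmetic works because, for a contiguous tensor of shape `(B, N, N)`, element `[b, i, j]` sits at position `i * N + j` inside row `b` of the `(B, N * N)` view (row-major order, so `A.stride(1) == N` and `A.stride(2) == 1`). Below is a sketch of the same flatten-and-`scatter_` idea wrapped in a function, assuming square adjacency matrices with `num_nodes` nodes. The names `build_adjacency` and `build_adjacency_loop` are hypothetical, and the loop from the question serves as a reference check.

```python
import torch


def build_adjacency(X: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: A[b, X[b, k], X[b, k + 1]] = 1 for every batch b and step k."""
    batch_size = X.size(0)
    A = torch.zeros((batch_size, num_nodes, num_nodes), dtype=torch.float)
    # Row-major flat position of (i, j) inside one num_nodes x num_nodes matrix
    flat = X[:, :-1] * num_nodes + X[:, 1:]
    # A is freshly allocated, hence contiguous, so view() shares its storage
    A.view(batch_size, -1).scatter_(1, flat, 1.0)
    return A


def build_adjacency_loop(X: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Reference implementation: the per-sample loop from the question."""
    A = torch.zeros((X.size(0), num_nodes, num_nodes), dtype=torch.float)
    for a, x in zip(A, X):
        for i, j in zip(x, x[1:]):
            a[i, j] = 1
    return A


X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
print(torch.equal(build_adjacency(X, 3), build_adjacency_loop(X, 3)))  # True
```

Using `num_nodes` as the row multiplier avoids relying on strides at all; for a freshly allocated contiguous tensor it equals `A.stride(1)`.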

Thanks! Works like a charm.