Indexing a batch of transition matrices with a block diagonal matrix

Hello,

I have a block-diagonal matrix of shape NxN, obtained by diagonally stacking the adjacency matrices of the samples in the batch, and a batch of 2x2 matrices containing edge-flipping probabilities (i.e. Q[0][1] is the probability of flipping an entry of the adjacency matrix from 0 to 1, and so on).
I would like to obtain an NxNx2 matrix containing the appropriate probabilities.
I tried to give a visual explanation of the problem in the attached figure; what I would like to obtain is the analogue of Q[A] for the batch scenario.

If I understood your problem correctly, then you can do something like this:

``````
A = torch.tensor([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])
Q = torch.tensor([[0.75, 0.25],  # Zero -> Zero | Zero -> One
                  [0.10, 0.90]]) # One  -> Zero | One  -> One

QA = Q[A, :]

print(f"Probability that the values are 0:\n\n{QA[:,:,0]}\n\n")
print(f"Probability that the values are 1:\n\n{QA[:,:,1]}")
``````
``````
# Output:
Probability that the values are 0:

tensor([[0.7500, 0.1000, 0.7500],
        [0.1000, 0.7500, 0.1000],
        [0.7500, 0.1000, 0.7500]])

Probability that the values are 1:

tensor([[0.2500, 0.9000, 0.2500],
        [0.9000, 0.2500, 0.9000],
        [0.2500, 0.9000, 0.2500]])
``````
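For reference, `Q[A]` is plain advanced indexing: every entry of `A` selects a row of `Q`, so the result has shape `A.shape + (2,)`. A minimal check of that equivalence, using the same tensors as above:

```python
import torch

A = torch.tensor([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])
Q = torch.tensor([[0.75, 0.25],
                  [0.10, 0.90]])

QA = Q[A, :]  # shape (3, 3, 2): one row of Q per entry of A

# Equivalent element-by-element construction
QA_manual = torch.stack([torch.stack([Q[A[i, j]] for j in range(3)])
                         for i in range(3)])
assert torch.equal(QA, QA_manual)
```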

Hope this helps

You understood the problem correctly, but that only solves the single-sample scenario, i.e. the one depicted in the upper part of the figure.
I would like to do the same thing over a block matrix that results from diagonally stacking B adjacency matrices, using B different 2x2 transition matrices. Of course, each adjacency matrix A_b in A_1, …, A_B should index its own transition matrix Q_b.

This answer assumes that your matrix is square and that the blocks are also square.
I think the simplest option is a for loop (maybe there is a better way, but I don't know of one):

``````
import torch

A_n = torch.randint(0, 2, (4, 3, 3))

# Stack the four adjacency matrices into one block-diagonal matrix
A = torch.block_diag(A_n[0], A_n[1], A_n[2], A_n[3])
print(A)

Q = torch.rand(4, 2, 2)

size_patch = 3
QA = torch.zeros(*A.shape, 2)

# Index each diagonal block with its own transition matrix Q[i]
for i, idx in enumerate(range(0, A.shape[0], size_patch)):
    QA[idx:idx+size_patch, idx:idx+size_patch] = Q[i, A[idx:idx+size_patch, idx:idx+size_patch], :]

torch.set_printoptions(precision=2)
print(f"Probability that the values are 0:\n\n{QA[:,:,0]}\n\n")
print(f"Probability that the values are 1:\n\n{QA[:,:,1]}")
``````
``````
# Output:
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
Probability that the values are 0:

tensor([[0.99, 0.61, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.61, 0.61, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.99, 0.99, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.96, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.96, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.42, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.56, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.56, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.64, 0.64],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.64, 0.64],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.96, 0.64]])

Probability that the values are 1:

tensor([[0.19, 0.47, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.47, 0.47, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.19, 0.19, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.29, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.29, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.39, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.98, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.29, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.29, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.25, 0.25],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.25, 0.25],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.74, 0.25]])
``````
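If the loop ever becomes a bottleneck, the same result can be computed without it. This is only a sketch (the sizes `B = 4` and `s = 3` mirror the example above): each row picks its own transition matrix through a per-row block index, and a mask zeroes everything outside the diagonal blocks.

```python
import torch

B, s = 4, 3                               # batch size and block size (illustrative)
A_n = torch.randint(0, 2, (B, s, s))
A = torch.block_diag(*A_n)                # same block-diagonal matrix as above
Q = torch.rand(B, 2, 2)
N = A.shape[0]

# Row i of A belongs to block i // s, so pick that block's Q for every row
row_block = torch.arange(N) // s                    # shape (N,)
QA = Q[row_block.unsqueeze(1), A, :]                # shape (N, N, 2)

# Zero the entries outside the diagonal blocks, matching the loop's output
same_block = row_block.unsqueeze(1) == row_block.unsqueeze(0)   # (N, N)
QA = QA * same_block.unsqueeze(-1)
```

This also avoids materializing a (B, N, N, 2) intermediate, since each row indexes only its own Q_b.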

Thank you Matias,

I am using a similar approach to the one you suggested, but maybe there is a way to avoid the for loop.

You can try applying every `Q` to the whole matrix `A`, stacking the results along a new dimension, and then using a mask to keep, for each block, only the values coming from its own transition matrix.

``````
import torch

# This has to be done only ONCE

A_n = torch.randint(0, 2, (4, 3, 3))
A = torch.block_diag(A_n[0], A_n[1], A_n[2], A_n[3])

# Probability matrices Q
Q = torch.rand(4, 2, 2)

# Apply every Q_b to the whole matrix: shape (B, N, N, 2)
QA_all = Q[:, A, :]

# Row/column i belongs to block i // 3
block = torch.arange(A.shape[0]) // 3

# member[b, i] is True when index i belongs to block b
member = block.unsqueeze(0) == torch.arange(4).unsqueeze(1)

# mask[b] is True exactly on the b-th diagonal block
mask = member[:, :, None] & member[:, None, :]

# Keep each block's own probabilities and collapse the batch dimension
QA = (QA_all * mask.unsqueeze(-1)).sum(dim=0)   # shape (N, N, 2)
``````