Indexing a batch of transition matrices with a block diagonal matrix

Hello,

I have a block diagonal matrix of shape NxN, obtained by diagonally stacking one adjacency matrix per sample in the batch, and a batch of 2x2 matrices containing edge-flipping probabilities (i.e. Q[0][1] is the probability of flipping an entry of the adjacency matrix from 0 to 1, and so on).
I would like to obtain an NxNx2 matrix containing the appropriate probabilities.
The attached figure gives a visual explanation of the problem: what I would like to obtain is the analogue of Q[A] for the batch scenario.

If I understood your problem correctly, you can do something like this:

A = torch.tensor([[0,1,0],[1,0,1],[0,1,0]])
Q = torch.tensor([[0.75,0.25],  # Zero->Zero | Zero-> One
                  [0.1,0.9]])   # One ->Zero | One -> One

QA = Q[A, :]

print(f"Probability that the values are 0:\n\n{QA[:,:,0]}\n\n")
print(f"Probability that the values are 1:\n\n{QA[:,:,1]}")
# Output:
Probability that the values are 0:

tensor([[0.7500, 0.1000, 0.7500],
        [0.1000, 0.7500, 0.1000],
        [0.7500, 0.1000, 0.7500]])


Probability that the values are 1:

tensor([[0.2500, 0.9000, 0.2500],
        [0.9000, 0.2500, 0.9000],
        [0.2500, 0.9000, 0.2500]])

Hope this helps :smile:

Thank you for your answer Matias,

you understood the problem correctly, but that only solves the single-sample scenario, i.e. the one depicted in the upper part of the figure.
I would like to do the same thing over a block matrix that results from diagonally stacking B adjacency matrices, using B different 2x2 transition matrices. Each adjacency matrix A_b among A_1, …, A_B should index its own transition matrix Q_b.

This answer assumes that your matrix is square and that the blocks are square as well.
I think the simplest option is a for loop (there may be a better way, but I'm not sure):

import torch

A_n = torch.randint(0, 2, (4, 3, 3))

# Diagonally stack the four adjacency matrices
A = torch.block_diag(*A_n)
print(A)

# One 2x2 transition matrix per sample
Q = torch.rand(4, 2, 2)

size_patch = 3
QA = torch.zeros(*A.shape, 2)

# Index each diagonal block with its own transition matrix
for i, idx in enumerate(range(0, A.shape[0], size_patch)):
    QA[idx:idx+size_patch, idx:idx+size_patch] = Q[i, A[idx:idx+size_patch, idx:idx+size_patch], :]

torch.set_printoptions(precision=2)
print(f"Probability that the values are 0:\n\n{QA[:,:,0]}\n\n")
print(f"Probability that the values are 1:\n\n{QA[:,:,1]}")

# Output:
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
Probability that the values are 0:

tensor([[0.99, 0.61, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.61, 0.61, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.99, 0.99, 0.99, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.96, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.96, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.96, 0.25, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.42, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.56, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.56, 0.42, 0.56, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.64, 0.64],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.64, 0.64],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.64, 0.96, 0.64]])


Probability that the values are 1:

tensor([[0.19, 0.47, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.47, 0.47, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.19, 0.19, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.29, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.29, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.29, 0.39, 0.39, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.98, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.29, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.29, 0.98, 0.29, 0.00, 0.00, 0.00],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.25, 0.25],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.25, 0.25],
        [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.74, 0.25]])

Thank you Matias,

I am using a similar approach to the one you suggested, but I was wondering whether there is a way to avoid the for loop.

You can instead apply every transition matrix in Q to the whole matrix A, stack the results along a new dimension, and then use a mask to keep, in each slice, only the values that belong to that slice's block.

import torch

# Your diagonal matrix A
A_n = torch.randint(0, 2, (4, 3, 3))
A = torch.block_diag(*A_n)

# Probability matrices Q
Q = torch.rand(4, 2, 2)

# The masks have to be built only ONCE
mask_0 = torch.zeros(A_n.shape)
mask_1 = torch.ones(A_n.shape)
mask_a = torch.block_diag(mask_1[0], mask_0[1], mask_0[2], mask_0[3])
mask_b = torch.block_diag(mask_0[0], mask_1[1], mask_0[2], mask_0[3])
mask_c = torch.block_diag(mask_0[0], mask_0[1], mask_1[2], mask_0[3])
mask_d = torch.block_diag(mask_0[0], mask_0[1], mask_0[2], mask_1[3])
Mask = torch.stack([mask_a, mask_b, mask_c, mask_d]) == 0  # True outside each block
Mask = Mask.unsqueeze(-1).repeat(1, 1, 1, 2)

# Apply every Q to the whole matrix, zero out the entries
# outside each block, then collapse the batch dimension
QA = Q[:, A, :]
QA[Mask] = 0
QA = torch.sum(QA, dim=0)

torch.set_printoptions(precision=2)
print(Q)
print("QA: ", QA[:, :, 0])
print("QA: ", QA[:, :, 1])
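For reference, the masks and the sum can also be avoided entirely: since all blocks have the same size, you can build a row-wise batch index once and use a single advanced-indexing call, zeroing the off-block entries with a comparison mask. This is just a sketch under that equal-block-size assumption; the names `bidx` and `same_block` are ones I made up.

```python
import torch

B, size_patch = 4, 3
A_n = torch.randint(0, 2, (B, size_patch, size_patch))
A = torch.block_diag(*A_n)          # (N, N) with N = B * size_patch
Q = torch.rand(B, 2, 2)

# bidx[i] = index of the block that row i of A belongs to, shape (N,)
bidx = torch.repeat_interleave(torch.arange(B), size_patch)

# Q[bidx[i], A[i, j]] for every (i, j): the (N, 1) row index broadcasts
# against the (N, N) entry index, giving a (N, N, 2) result
QA = Q[bidx[:, None], A]

# Zero out entries outside the diagonal blocks, as in the loop version
same_block = (bidx[:, None] == bidx[None, :]).unsqueeze(-1)
QA = QA * same_block
```

This replaces both the Python loop and the B stacked masks with one gather over `Q`, so memory stays at (N, N, 2) instead of (B, N, N, 2).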

This might solve the problem, I’ll try it out!

Thank you :grinning: