Hi all, I need a little help figuring out where my implementation went wrong. I have an algorithm that runs on a 3D tensor of shape [batch, row, col]. I used to use a for loop to iterate over each index i in range(n), where n is the number of rows/cols (each batch element is an n x n square matrix). I figured out how to use torch.Tensor.scatter_add_ to remove the for-loop.
Here’s the input to the function:
import torch

input = torch.Tensor([[[ 7,  8,  7, 20,  7],
                       [15,  2,  2,  4,  8],
                       [ 6, 17,  6,  7,  8],
                       [ 1,  6, 12,  6,  2],
                       [ 8,  1,  7,  2, 10]]])
Here’s the for-loop implementation:
def iter_simp_min_sum_batch(m_alpha_beta, m_beta_alpha, weights, n_iter):
    n = weights.size(1)  # number of rows/cols of each square matrix
    for _ in range(n_iter):
        # Message passing
        for i in range(n):
            # Max over dim 1 with row i left out, shape [batch, n]
            alpha_beta_max = torch.max(torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1), dim=1)[0]
            beta_alpha_max = torch.max(torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1), dim=1)[0]
            # Build the new messages one slice at a time along dim 2
            if i == 0:
                m_beta_alpha_k = (weights[:, i, :] - alpha_beta_max).unsqueeze(2)
                m_alpha_beta_k = (weights[:, :, i] - beta_alpha_max).unsqueeze(2)
            else:
                m_beta_alpha_k = torch.cat((m_beta_alpha_k, (weights[:, i, :] - alpha_beta_max).unsqueeze(2)), dim=2)
                m_alpha_beta_k = torch.cat((m_alpha_beta_k, (weights[:, :, i] - beta_alpha_max).unsqueeze(2)), dim=2)
        # Swap in the new messages only after all n slices are built
        m_alpha_beta = m_alpha_beta_k
        m_beta_alpha = m_beta_alpha_k
    return m_alpha_beta, m_beta_alpha
Here’s my scatter implementation of the above:
def iter_simp_min_sum_batch_scatter(m_alpha_beta, m_beta_alpha, weights, n_iter):
    for _ in range(n_iter):
        # Message passing
        # Two largest values (and their row indices) along dim 1, each of shape [batch, 2, n]
        beta_alpha_maxes, beta_alpha_indices = torch.topk(m_beta_alpha, 2, dim=1)
        print(beta_alpha_maxes[:, 0, :].requires_grad)
        # Subtract the largest value everywhere ...
        m_alpha_beta_k = weights.permute(0, 2, 1) - beta_alpha_maxes[:, 0, :].unsqueeze(1)
        # ... then, at the argmax row only, add it back and subtract the second largest instead
        m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1), src=beta_alpha_maxes[:, 0, :].unsqueeze(1))
        m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1), src=beta_alpha_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
        # Same steps for the other message direction
        alpha_beta_maxes, alpha_beta_indices = torch.topk(m_alpha_beta, 2, dim=1)
        m_beta_alpha_k = weights - alpha_beta_maxes[:, 0, :].unsqueeze(1)
        m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1), src=alpha_beta_maxes[:, 0, :].unsqueeze(1))
        m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1), src=alpha_beta_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
        m_alpha_beta = m_alpha_beta_k
        m_beta_alpha = m_beta_alpha_k
    return m_alpha_beta, m_beta_alpha
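To spell out the idea the scatter version relies on, here is a minimal plain-Python sketch on a single column (not the batched code above; both helper names are made up for illustration). The max over all rows except i equals the overall max, unless i is the argmax row, in which case it is the second-largest value. That is why two top-k values per column should be enough.

```python
def leave_one_out_max_bruteforce(col):
    """For each i, the max of col with entry i removed (what the for-loop computes)."""
    return [max(col[:i] + col[i + 1:]) for i in range(len(col))]


def leave_one_out_max_top2(col):
    """Same result from only the two largest values (the idea behind the topk version)."""
    # Index of the largest value; on ties, the first occurrence wins here.
    top1_idx = max(range(len(col)), key=lambda i: col[i])
    top1 = col[top1_idx]
    # Largest value among the remaining entries.
    top2 = max(v for i, v in enumerate(col) if i != top1_idx)
    return [top2 if i == top1_idx else top1 for i in range(len(col))]


# Column 2 of the input above.
col = [7, 2, 6, 12, 7]
print(leave_one_out_max_bruteforce(col))  # [12, 12, 12, 7, 12]
print(leave_one_out_max_top2(col))        # [12, 12, 12, 7, 12]
```

Note that with row 3 removed, the maximum 7 is attained twice (rows 0 and 4), so the value is well defined but the argmax is not.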
Finally, here’s how I test the above implementations:
# Scatter implementation
a = torch.nn.Parameter(input)
scatter_b_a, scatter_b_b = iter_simp_min_sum_batch_scatter(a, a.permute(0, 2, 1), a.clone().detach(), 3)
scatter_b_b.backward(gradient=torch.ones_like(scatter_b_b))
# For-loop implementation
b = torch.nn.Parameter(input)
iter_b_a, iter_b_b = iter_simp_min_sum_batch(b, b.permute(0, 2, 1), b.clone().detach(), 3)
iter_b_b.backward(gradient=torch.ones_like(iter_b_b))
Here are the results:
a.grad
tensor([[[  0.,   0.,  -2.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,  -2.],
         [  0.,   0.,  -4.,   0.,   0.],
         [ -1.,   0.,   0.,   0., -16.]]])
b.grad
tensor([[[  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,  -2.],
         [  0.,   0.,  -4.,   0.,   0.],
         [ -1.,   0.,  -2.,   0., -16.]]])
scatter_b_b - iter_b_b
tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], grad_fn=<SubBackward0>)
As scatter_b_b - iter_b_b shows, the two implementations produce the same outputs. The gradients (a.grad vs. b.grad) differ, though, and I don’t understand where that difference comes from.
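One thing that can be checked from the printed output alone is where exactly the two gradients disagree. The sketch below is only a diagnostic, using plain Python lists copied from the tensors above (batch dimension dropped; the variable names are made up for illustration). It lists the differing positions and the input values at those positions.

```python
# Values copied from the printed tensors above (batch dimension dropped).
inp = [[7, 8, 7, 20, 7],
       [15, 2, 2, 4, 8],
       [6, 17, 6, 7, 8],
       [1, 6, 12, 6, 2],
       [8, 1, 7, 2, 10]]
a_grad = [[0, 0, -2, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, -2],
          [0, 0, -4, 0, 0],
          [-1, 0, 0, 0, -16]]
b_grad = [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, -2],
          [0, 0, -4, 0, 0],
          [-1, 0, -2, 0, -16]]

# Positions (row, col) where the two gradients disagree.
diff = [(r, c) for r in range(5) for c in range(5) if a_grad[r][c] != b_grad[r][c]]
print(diff)                          # [(0, 2), (4, 2)]
# Input values at those positions.
print([inp[r][c] for r, c in diff])  # [7, 7]
# Total gradient mass in each version.
print(sum(map(sum, a_grad)), sum(map(sum, b_grad)))  # -25 -25
```

The only mismatch is a -2 that lands on row 0 in one version and on row 4 in the other, both in column 2, where the input holds the same value 7. The totals also match. This suggests, but does not prove, that the two versions break a tie between equal values differently. I have not confirmed how torch.max and torch.topk order tied entries, so treat that as an assumption.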
Thanks a lot for any help!