Hi, I need some help replacing a for-loop with a vectorized operation. Here’s a code snippet:

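```python
import torch

for _ in range(n_iter):
    # Fresh buffers each sweep; reusing them aliased the messages being read
    m_alpha_beta_k = torch.empty_like(m_alpha_beta)
    m_beta_alpha_k = torch.empty_like(m_beta_alpha)
    # Message passing
    for i in range(n):
        # Column-wise max over every row of m_alpha_beta except row i
        m_beta_alpha_k[:, :, i] = g[:, i, :] - torch.max(
            torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1), dim=1)[0]
        # Column-wise max over every row of m_beta_alpha except row i
        m_alpha_beta_k[:, :, i] = g[:, :, i] - torch.max(
            torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1), dim=1)[0]
    m_alpha_beta = m_alpha_beta_k
    m_beta_alpha = m_beta_alpha_k
```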
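Written element-wise, assuming every tensor is a square [batch, n, n] stack (b indexes the batch, r ranges over rows; the symbols $m^{\alpha\beta}$, $m^{\beta\alpha}$ and the helper $L$ are just notation for this note), one sweep of the loop computes:

$$
\begin{aligned}
m^{\beta\alpha}_{k}[b, j, i] &= g[b, i, j] - \max_{r \neq i} m^{\alpha\beta}[b, r, j],\\
m^{\alpha\beta}_{k}[b, j, i] &= g[b, j, i] - \max_{r \neq i} m^{\beta\alpha}[b, r, j].
\end{aligned}
$$

With the leave-one-out column max $L(m)[b, i, j] = \max_{r \neq i} m[b, r, j]$, this becomes $m^{\beta\alpha}_{k} = \big(g - L(m^{\alpha\beta})\big)^{\top}$ and $m^{\alpha\beta}_{k} = g - L(m^{\beta\alpha})^{\top}$, where $^{\top}$ swaps the last two axes. So the only non-trivial piece is computing $L$ without a loop over $i$.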

Each tensor is 3D with shape [batch, i, j]: every [i, j] slice is a matrix, and there are `batch` of them to process. As the loop shows, for each row index i I want to exclude row i of the matrix and take the column-wise `max` over the remaining rows. Is there a way to do this efficiently in torch, without the Python loop over i?