How to pick the diag values for selected dimensions?

Suppose I have a tensor of shape [10, 2, 2]. I want to pick the diagonal over dims 1 and 2, i.e., turn the tensor into one of shape [10, 2]. Is there an easy way to do this? The .diag method doesn't seem to work unless I iterate over the 10 instances.
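For reference, PyTorch has a built-in for this case: torch.diagonal takes the two dimensions to read the diagonal from (dim1, dim2) and needs no loop. A minimal sketch on the same [10, 2, 2] shape:

```python
import torch

n, k = 10, 2
x = torch.rand(n, k, k)

# Take the diagonal over dims 1 and 2; the diagonal becomes the last dim.
y = torch.diagonal(x, dim1=1, dim2=2)  # shape [n, k]

# Same thing as a method, with negative dims so it works for any number
# of leading batch dimensions.
y2 = x.diagonal(dim1=-2, dim2=-1)

print(y.shape)  # torch.Size([10, 2])
```

torch.diagonal returns a view into x, so call .clone() on the result if you need an independent copy.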

Compared with using a for loop, I have a faster implementation:

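import torch

n, k = 10, 2
x = torch.rand(n, k, k)
# Flatten each k x k matrix into a row of k * k values; the diagonal
# entries are then every (k + 1)-th element, starting at index 0.
y = x.reshape(n, -1)[:, 0::(k + 1)]

# Check against torch.diag applied to each matrix separately
for i in range(n):
    assert (y[i] == torch.diag(x[i])).all()
print("finish")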
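Why the stride of k + 1 works: in row-major order, element (i, j) of each k x k matrix lands at flat index i * k + j, so the diagonal element (i, i) lands at i * (k + 1). A small worked check with k = 3 and easy-to-read values from torch.arange (the sizes here are just for illustration):

```python
import torch

n, k = 2, 3
x = torch.arange(n * k * k).reshape(n, k, k)

# Each row of `flat` is one 3 x 3 matrix laid out row by row:
# indices 0..8, and the diagonal sits at 0, 4, 8 (a stride of k + 1 = 4).
flat = x.reshape(n, -1)      # shape [n, k * k]
y = flat[:, 0::(k + 1)]      # shape [n, k]

print(y.tolist())  # [[0, 4, 8], [9, 13, 17]]
```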

The following code, which indexes with a boolean identity mask, also works:

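import torch

n, k = 10, 2
x = torch.rand(n, k, k)
# Boolean mask that is True on the diagonal of every k x k matrix;
# x[mask] returns the n * k selected entries in row-major order.
y = x[torch.eye(k).repeat(n, 1, 1).bool()].reshape(n, -1)

# Check against torch.diag applied to each matrix separately
for i in range(n):
    assert (y[i] == torch.diag(x[i])).all()
print("finish")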
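Another loop-free option is torch.einsum with a repeated index: writing i twice on the input and once on the output selects the entries x[..., i, i]. A minimal sketch comparing it with the boolean-mask version (the variable names are just for illustration):

```python
import torch

n, k = 10, 2
x = torch.rand(n, k, k)

# "...ii->...i": repeated i on the input picks the diagonal,
# "..." carries any leading batch dims through unchanged.
y = torch.einsum("...ii->...i", x)  # shape [n, k]

# Boolean-mask version, with the batch size taken from n.
mask = torch.eye(k).repeat(n, 1, 1).bool()
y_mask = x[mask].reshape(n, k)

assert torch.allclose(y, y_mask)
print(y.shape)  # torch.Size([10, 2])
```

Unlike the mask approach, einsum does not need to build an n x k x k mask tensor.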