Diag Matrix for images

How would you make a diagonal matrix with a size of 64x64x3?

import torch
torch.stack([torch.diagflat(numbers) for _ in range(3)], dim=2)  # numbers: 1-D tensor holding the 64 diagonal values

https://pytorch.org/docs/stable/generated/torch.diagflat.html?highlight=diag#torch.diagflat
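A minimal runnable sketch of the answer above, assuming a hypothetical `numbers` tensor of 64 float values (any 1-D tensor of length 64 works). torch.diagflat turns it into one 64x64 diagonal matrix, and stacking three copies along dim=2 gives the 64x64x3 result. It also shows an alternative not taken from the answer: torch.diag_embed on a (3, 64) tensor builds all three 64x64 diagonals at once as (3, 64, 64), which is then permuted to channels-last.

```python
import torch

# Hypothetical diagonal values: any 1-D tensor of length 64 works here.
numbers = torch.arange(1, 65, dtype=torch.float32)

# Approach from the answer: one 64x64 diagonal matrix from diagflat,
# repeated 3 times and stacked along a new last dimension -> (64, 64, 3).
stacked = torch.stack([torch.diagflat(numbers) for _ in range(3)], dim=2)

# Alternative: diag_embed turns each row of a (3, 64) tensor into a 64x64
# diagonal matrix, giving (3, 64, 64); permute moves the channel axis last.
per_channel = numbers.expand(3, 64)
embedded = torch.diag_embed(per_channel).permute(1, 2, 0)

print(stacked.shape)                   # torch.Size([64, 64, 3])
print(torch.equal(stacked, embedded))  # True
```

If you want a different diagonal per channel, pass a (3, 64) tensor with distinct rows to diag_embed instead of expanding one vector. If your images are channels-first (3x64x64), skip the permute.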