How to create a tensor of masks where only one entry is 1

Say we have a tensor T of size [s, s].
How do I create a mask tensor of size [s*s, s, s] where each [s, s] slice has exactly one entry equal to 1?
E.g. for s = 3, the mask tensor would look like:

[
[[1, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 1, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 1], [0, 0, 0], [0, 0, 0]],
...
[[0, 0, 0], [0, 0, 0], [0, 0, 1]]
]
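To pin down the requested layout, here is a plain-loop sketch (the helper name `one_hot_masks` is hypothetical, not a PyTorch API). It assumes the 1s walk through T's entries in row-major order, as the example above suggests, so slice i has its 1 at row i // s, column i % s. It is slow for large s and only meant as a reference to check faster versions against.

```python
import torch

def one_hot_masks(s: int) -> torch.Tensor:
    # Hypothetical reference helper: slice i gets a single 1 at (i // s, i % s),
    # i.e. the entries of an [s, s] grid are visited in row-major order.
    mask = torch.zeros(s * s, s, s)
    for i in range(s * s):
        mask[i, i // s, i % s] = 1
    return mask

mask = one_hot_masks(3)
print(mask.shape)  # torch.Size([9, 3, 3])
print(mask[1])     # single 1 in the first row, middle column
```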

Thanks!

mask_setup = torch.ones(s, s) #Shape -> [s, s]

mask = torch.diag(mask_setup) #Shape -> [s*s, s*s]

mask = mask.reshape(s*s, s, s) #Shape -> [s*s, s, s]

This should give you what you need. There may be a simpler way to get it done though.

In my case s = 14, and the reshape fails:

mask = mask.reshape(s * s, s, s)
RuntimeError: shape '[196, 14, 14]' is invalid for input of size 14

mask = torch.diag(mask_setup)  # -> torch.Size([14])

So it does not work. Could you explain your idea?

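I apologize. The first line is wrong: on a 2-D input, torch.diag returns the diagonal (a 1-D tensor of length s), while on a 1-D input it builds a 2-D diagonal matrix. So mask_setup has to be a 1-D tensor of length s*s: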
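A quick check of why the first version failed, using s = 3 for brevity: `torch.diag` changes behavior with the number of input dimensions, so an [s, s] input yields only s elements, which cannot be reshaped to [s*s, s, s].

```python
import torch

s = 3
square = torch.ones(s, s)  # what the first version passed in
flat = torch.ones(s * s)   # what the corrected version passes in

# 2-D input: torch.diag extracts the main diagonal as a 1-D tensor.
print(torch.diag(square).shape)  # torch.Size([3])

# 1-D input: torch.diag builds a 2-D matrix with the vector on its diagonal.
print(torch.diag(flat).shape)    # torch.Size([9, 9])
```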

import torch

mask_setup = torch.ones(s * s)     # Shape -> [s*s]

mask = torch.diag(mask_setup)      # Shape -> [s*s, s*s]

mask = mask.reshape(s * s, s, s)   # Shape -> [s*s, s, s]

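The idea is to create a diagonal (identity) matrix of size (s*s, s*s) and then reshape it to (s*s, s, s). Row i of that matrix has a single 1 at column i, and after the reshape that 1 lands at position (i // s, i % s) of the i-th [s, s] slice.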
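As for the "simpler way" mentioned earlier, one option (a swapped-in alternative, not what the answer above used) is `torch.eye`, which builds the identity matrix directly. The sketch below, assuming s = 3 and an example T filled with 0..8, checks that it matches the diag-based result and that multiplying the mask against T and summing picks out T's entries in row-major order.

```python
import torch

s = 3

# torch.eye(n) is the n x n identity, the same as torch.diag(torch.ones(n)).
mask = torch.eye(s * s).reshape(s * s, s, s)

# Identical to the diag-based construction from the answer.
diag_mask = torch.diag(torch.ones(s * s)).reshape(s * s, s, s)
print(torch.equal(mask, diag_mask))  # True

# Example T (assumed values): each mask slice selects one entry of T.
T = torch.arange(s * s, dtype=torch.float32).reshape(s, s)
picked = (mask * T).sum(dim=(1, 2))  # broadcasts [9, 3, 3] * [3, 3], then sums each slice
print(torch.equal(picked, T.flatten()))  # True
```

If you need an integer or boolean mask instead of floats, `torch.eye` also takes a `dtype` argument.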
