Ardeal, May 9, 2022, 8:58am
Hi,
For the following code, the shape of outputs is torch.Size([2205, 7]) and the shape of labels is torch.Size([2205]):
t0 = torch.zeros_like(outputs)
t0[range(outputs.size()[0]), labels] = 1
I am getting this error:
The shape of the mask [2205] at index 0 does not match the shape of the indexed tensor [2205, 7] at index 1
Can you post how labels is defined?
I tried to reproduce the error with this code, but it works.
import torch

outputs = torch.rand(2205, 7)
labels = torch.randint(0, 7, (2205,))  # default dtype is int64
t0 = torch.zeros_like(outputs)
t0[range(outputs.size(0)), labels] = 1
Ardeal, May 9, 2022, 9:26am
labels is a tensor with dtype = torch.uint8. Its minimum value is 0 and its maximum value is 6.
It seems to be the dtype that is causing the error.
When I define labels = torch.randint(0, 7, (2205,), dtype=torch.uint8), I get the same error as you.
A possible workaround is to cast it to long (torch.int64).
This should work:
import torch

outputs = torch.rand(2205, 7)
labels = torch.randint(0, 7, (2205,), dtype=torch.uint8)
t0 = torch.zeros_like(outputs)
t0[range(outputs.size(0)), labels.long()] = 1  # cast uint8 -> int64
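Assuming the goal of t0 is a one-hot encoding of labels (the thread does not say so explicitly), the same matrix can also be built with `torch.nn.functional.one_hot` or `Tensor.scatter_`. Both expect int64 indices, so the `.long()` cast is still needed. A minimal sketch, with random data standing in for the real outputs and labels:

```python
import torch
import torch.nn.functional as F

# Random stand-ins for the tensors from the question.
outputs = torch.rand(2205, 7)
labels = torch.randint(0, 7, (2205,), dtype=torch.uint8)

idx = labels.long()  # int64 values are read as positions, not as a mask

# 1) Advanced indexing, as in the workaround above.
t0 = torch.zeros_like(outputs)
t0[torch.arange(outputs.size(0)), idx] = 1

# 2) one_hot returns an int64 tensor; cast it to the dtype of outputs.
t1 = F.one_hot(idx, num_classes=outputs.size(1)).to(outputs.dtype)

# 3) scatter_ writes 1.0 into column idx[i] of row i.
t2 = torch.zeros_like(outputs).scatter_(1, idx.unsqueeze(1), 1.0)

print(torch.equal(t0, t1), torch.equal(t0, t2))  # expected: True True
```

All three produce the same float matrix with exactly one 1 per row; which one to use is mostly a matter of taste.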
Ardeal, May 9, 2022, 9:30am
Matias_Vasquez:
labels.long()
Thank you!
This error is very weird!
So torch.uint8 and torch.long behave differently as indices???
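Apparently uint8 is used for masking and int64 for indexing.
I believe this means that an int64 value such as 3 is read as a position: index 3, i.e. the fourth element when counting from 0. With uint8, the tensor is instead read as a mask, where each element only says whether the corresponding entry is selected (nonzero) or not (zero), so the mask must have as many elements as the dimension it indexes. (But I might be wrong.)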
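A small sketch of the two readings, using toy values that are not from the thread. An int64 index tensor holds positions, while a mask holds one flag per element of the dimension it indexes. torch.bool is used for the mask here; the linked issues show that, in the PyTorch versions discussed, a uint8 tensor received the same mask treatment. This also explains why the error mentions index 1: range(...) consumes dimension 0, so labels is applied to dimension 1, which has size 7, while the mask has 2205 entries.

```python
import torch

x = torch.arange(5) * 10  # values 0, 10, 20, 30, 40

# int64 index: each value is a position (counting from 0); repeats are allowed.
positions = torch.tensor([3, 0, 3])
print(x[positions])  # values 30, 0, 30

# Boolean mask: one flag per element, so its length must equal x.size(0).
mask = torch.tensor([True, False, False, True, True])
print(x[mask])  # values 0, 30, 40

# A mask whose length does not match the indexed dimension is rejected.
mismatch_rejected = False
try:
    x[torch.tensor([True, False])]
except IndexError as err:
    mismatch_rejected = True
    print("IndexError:", err)
```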
GitHub issue, opened 01:54PM - 06 Jan 22 UTC (labels: triaged, module: python array api)
Integer, scalar tensors should behave like integers when used as index. Tensors of dtype `torch.uint8` deviate from that:
```python
import torch

t_1d_single = torch.empty(1)
t_1d_multi = torch.empty(2)

for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
    print("single", dtype, t_1d_single[torch.tensor(0, dtype=dtype)].shape)
    print("multi1", dtype, t_1d_multi[torch.tensor(0, dtype=dtype)].shape)
    print("multi2", dtype, t_1d_multi[torch.tensor(1, dtype=dtype)].shape)
    print("#" * 50)
```
```
single torch.uint8 torch.Size([0, 1])
multi1 torch.uint8 torch.Size([0, 2])
multi2 torch.uint8 torch.Size([1, 2])
##################################################
single torch.int8 torch.Size([])
multi1 torch.int8 torch.Size([])
multi2 torch.int8 torch.Size([])
##################################################
single torch.int16 torch.Size([])
multi1 torch.int16 torch.Size([])
multi2 torch.int16 torch.Size([])
##################################################
single torch.int32 torch.Size([])
multi1 torch.int32 torch.Size([])
multi2 torch.int32 torch.Size([])
##################################################
single torch.int64 torch.Size([])
multi1 torch.int64 torch.Size([])
multi2 torch.int64 torch.Size([])
##################################################
```
GitHub issue, opened 07:41AM - 21 Feb 19 UTC, closed 08:39PM - 21 Feb 19 UTC
## 🐛 Bug
## To Reproduce
Steps to reproduce the behavior:
import torch
number = torch.randn(6, 3, 11)
p = torch.arange(10)
print(0, 0, p[0], number[0, 0, p[0]])
p = torch.arange(10, dtype=torch.uint8)
print(0, 0, p[0], number[0, 0, p[0]])
print(0, 0, p[0], number[0, 0, int(p[0])])
output:
0 0 tensor(0) **tensor(-0.2130)**
0 0 tensor(0, dtype=torch.uint8) **tensor([], size=(0, 11))**
0 0 tensor(0, dtype=torch.uint8) **tensor(-0.2130)**
torch.uint8 is not OK for index!
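As the last line of the issue's output shows, converting the uint8 value to a plain Python integer restores normal indexing. A minimal sketch of that conversion, next to the `.long()` cast suggested earlier in the thread, using the issue's shapes with random data:

```python
import torch

number = torch.randn(6, 3, 11)
p = torch.arange(10, dtype=torch.uint8)

# Python int index: selects a single element, giving a 0-d tensor.
a = number[0, 0, int(p[0])]

# int64 scalar tensor: behaves like the integer it holds.
b = number[0, 0, p[0].long()]

print(a.shape, b.shape, torch.equal(a, b))
```

Either conversion avoids the mask interpretation; `.long()` has the advantage of also working for non-scalar index tensors such as labels.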
Ardeal, May 9, 2022, 9:55am