I have the following incredibly strange bug with isin and assume_unique=True. Without compiling, the code runs fine for any n and either assume_unique value. When compiling, the code runs fine for any n when assume_unique=False. When assume_unique=True, the code runs fine for n < 12 and errors for n >= 12.
import torch
from torch import nn, Tensor

symbols = 2
bars = 3
all_false = torch.full((symbols, bars), False, device="cuda")  # (symbols, bars)

@torch.compile
class MyModel(nn.Module):
    def __init__(self, n: int, assume_unique: bool) -> None:
        super().__init__()
        self.ignored_symbols = torch.arange(n, device="cuda")
        self.assume_unique = assume_unique

    def forward(
        self,
        symbol_ids: Tensor,  # (symbols, 1)
    ) -> Tensor:
        mask_ignored_symbols = torch.isin(
            symbol_ids, self.ignored_symbols, assume_unique=self.assume_unique
        )  # (symbols, 1)
        return all_false | mask_ignored_symbols

# +50 to avoid the ignored symbols, not important
symbol_ids = torch.arange(symbols, device="cuda").unsqueeze(1) + 50
print(symbol_ids)

model = MyModel(n=12, assume_unique=True).cuda()
output: Tensor = model(symbol_ids)
Running this produces the following error:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Attempting to broadcast a dimension of length 2 at -1! Mismatching argument at index 1 had torch.Size([2]); but expected shape should be broadcastable to [2, 3]
While executing %or_ : [num_users=1] = call_function[target=operator.or_](args = (%all_false, %mask_ignored_symbols), kwargs = {})
Original traceback:
File "/home/ilan/monorepo/src/Ayin3/foo.py", line 28, in forward
return all_false | mask_ignored_symbols
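To narrow it down, isin can be isolated from the module and compiled on its own. This is a minimal diagnostic sketch; that the compiled call actually returns a squeezed (2,) tensor rather than (2, 1) is my inference from the broadcast error above, not something the traceback states directly:

import torch

symbol_ids = torch.arange(2, device="cuda").unsqueeze(1) + 50  # (2, 1)
ignored = torch.arange(12, device="cuda")  # n = 12 triggers the failure

# Eager mode keeps the input's shape.
eager_out = torch.isin(symbol_ids, ignored, assume_unique=True)
print(eager_out.shape)  # torch.Size([2, 1])

# The same call under compile.
compiled_isin = torch.compile(
    lambda elements, test: torch.isin(elements, test, assume_unique=True)
)
compiled_out = compiled_isin(symbol_ids, ignored)
print(compiled_out.shape)  # expected torch.Size([2, 1]); the error suggests torch.Size([2])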
So my guess is that some unexpected squeezing of the isin output is happening when assume_unique=True under compile. As for the n >= 12 boundary: isin internally switches from a brute-force broadcast comparison to a sort-based path once test_elements is large relative to elements, and if it follows the numpy-style heuristic (test_elements.numel() < 10 * elements.numel() ** 0.145), the crossover with elements.numel() == 2 lands exactly between n = 11 and n = 12, so the sort-based path with assume_unique=True may be what loses the input's shape when compiled.
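For anyone hitting the same thing in the meantime, here is a workaround sketch that avoids isin entirely: since the membership test here is just "does each id appear in ignored_symbols", it can be written as a broadcast comparison. Note this materializes a (symbols, 1, n) intermediate, so it assumes n is modest:

from torch import Tensor

def isin_via_broadcast(elements: Tensor, test_elements: Tensor) -> Tensor:
    # Compare every element against every test element, then reduce:
    # (symbols, 1, 1) == (n,) broadcasts to (symbols, 1, n); any(-1) -> (symbols, 1).
    return (elements.unsqueeze(-1) == test_elements).any(dim=-1)

Swapping this in for the torch.isin call in forward keeps the (symbols, 1) shape, which should sidestep the broadcast failure under compile.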