I am trying to rewrite the `index_put` op with `scatter`. In this procedure, I need to use some shape information of the input tensors:
```python
import torch

def index_put_pattern(data, indices, values, accumulate):
    return torch.ops.aten.index_put(data, indices, values, accumulate)

def index_put_replacement(data, indices, values, accumulate):
    index = indices[0]        # (Q, P) row indices into data
    Q, P = index.shape
    _, N = values.shape       # values: (P, N)
    # Repeat each row index across the N columns, then merge Q and P.
    bidx = torch.broadcast_to(index.reshape(Q, P, 1), (Q, P, N)).reshape(Q * P, N)
    # Repeat the P value rows once for each of the Q index rows.
    values = torch.broadcast_to(values.reshape(1, P, N), (Q, P, N)).reshape(Q * P, N)
    return data.scatter_(0, bidx, values, reduce="add")
```
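For reference, the decomposition itself can be checked eagerly on real tensors, where `.shape` yields plain ints. The sketch below reuses the same reshaping in a hypothetical helper `index_put_via_scatter_add`. As an assumption, it swaps `scatter_(..., reduce="add")` for the out-of-place `scatter_add` so `data` is not mutated, and, like the replacement above, it ignores `accumulate` and behaves like `accumulate=True`.

```python
import torch

def index_put_via_scatter_add(data, indices, values):
    # Hypothetical helper: same reshaping as the replacement above,
    # assuming accumulate=True (the original uses reduce="add").
    index = indices[0]              # (Q, P) row indices into data
    Q, P = index.shape              # plain ints in eager mode
    _, N = values.shape             # values: (P, N)
    # Repeat each row index across the N columns, then flatten Q and P.
    bidx = torch.broadcast_to(index.reshape(Q, P, 1), (Q, P, N)).reshape(Q * P, N)
    # Repeat the P value rows for each of the Q index rows.
    bvals = torch.broadcast_to(values.reshape(1, P, N), (Q, P, N)).reshape(Q * P, N)
    # Out-of-place: returns a new tensor instead of mutating data.
    return data.scatter_add(0, bidx, bvals)


data = torch.zeros(5, 3)
index = torch.tensor([[0, 2], [2, 4]])                       # Q=2, P=2
values = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # P=2, N=3

out = index_put_via_scatter_add(data, (index,), values)
expected = data.index_put((index,), values, accumulate=True)
print(torch.equal(out, expected))  # True
```

Row 2 is hit twice (by `values[1]` and then `values[0]`), so this example also exercises the accumulation path.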
But an error occurs:

```
    bidx = torch.broadcast_to(index.reshape(Q, P, 1), (Q, P, N)).reshape(Q * P, N)
TypeError: broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found element of type Proxy at pos 0
```
Any solution for this?
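One possible direction, offered only as an untested sketch: the error says `broadcast_to` needs a tuple of real ints for `size`, but the unpacked shape values arrive as `Proxy` objects. The hypothetical variant below never reads dimensions from `.shape` at all; it lets `torch.broadcast_tensors` and `flatten(0, 1)` derive the sizes from the tensors themselves. It is checked here only eagerly against `index_put`. Whether it traces cleanly depends on the tracer producing the `Proxy`, which the post does not name. Like the original, it assumes `accumulate=True`, and it uses the out-of-place `scatter_add`.

```python
import torch

def index_put_via_scatter_add_shape_free(data, indices, values):
    # Hypothetical variant: never reads Q, P, N from .shape, so no size
    # tuple has to be built. Assumes accumulate=True, like the original.
    index = indices[0]                                   # (Q, P)
    # (Q, P, 1) and (1, P, N) broadcast to a common (Q, P, N) shape.
    bidx, bvals = torch.broadcast_tensors(index.unsqueeze(-1), values.unsqueeze(0))
    # flatten(0, 1) merges Q and P into one dim of size Q * P.
    return data.scatter_add(0, bidx.flatten(0, 1), bvals.flatten(0, 1))


data = torch.zeros(5, 3)
index = torch.tensor([[0, 2], [2, 4]])
values = torch.arange(6, dtype=torch.float32).reshape(2, 3)

out = index_put_via_scatter_add_shape_free(data, (index,), values)
print(torch.equal(out, data.index_put((index,), values, accumulate=True)))  # True
```

The design choice is to express the broadcast in terms of tensor operations only, so the only non-tensor arguments left are constant dims (`-1`, `0`, `1`) rather than shape values computed at trace time.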