C++ API for broadcasting on certain axes

Hi! :wave:

We are writing libtorch bindings for Elixir’s Nx. The initial design was built on top of XLA, and one of the operations XLA provides is BroadcastInDim, which takes an explicit mapping from each axis of the input tensor to an axis of the target shape.

Here is an example:

    iex> t = Nx.tensor([1, 2, 3])
    iex> Nx.broadcast(t, {2, 3, 2}, axes: [1])
    #Nx.Tensor<
      s64[2][3][2]
      [
        [
          [1, 1],
          [2, 2],
          [3, 3]
        ],
        [
          [1, 1],
          [2, 2],
          [3, 3] 
        ]
      ]
    >

In the example above, we broadcast a tensor of shape (3) to shape (2, 3, 2). Under the default broadcast semantics, which follow NumPy, this would not be possible because the trailing dimensions do not match (2 != 3). However, I can pass the axes option, which maps each axis of the input tensor to an axis of the target shape. In this case, I map the tensor's only axis (axis 0) to axis 1 of the target shape, which allows the tensor to be broadcast to (2, 3, 2).
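To make the `axes` semantics concrete, here is a small C++ sketch that uses only the standard library (no libtorch). `broadcast_in_dim_ref` is a hypothetical name, not an existing API: it visits every index of the output shape and reads the input element reached by following `axes` back, with size-1 input dimensions always read at index 0. It assumes the arguments are already valid.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical reference BroadcastInDim on flat row-major data (not a libtorch API).
// axes[i] is the output axis that input axis i maps to; all other output axes are new.
// Assumes valid arguments: one distinct, in-range axis per input dim, and each
// input dim is either 1 or equal to the output dim it maps to.
template <typename T>
std::vector<T> broadcast_in_dim_ref(const std::vector<T>& data, const Shape& in,
                                    const Shape& out, const Shape& axes) {
  int64_t total = 1;
  for (int64_t d : out) total *= d;
  std::vector<T> result(static_cast<std::size_t>(total));
  Shape idx(out.size(), 0);  // current output multi-index
  for (int64_t flat = 0; flat < total; ++flat) {
    // Follow axes back to a row-major input offset; size-1 input dims read index 0.
    int64_t offset = 0;
    for (std::size_t i = 0; i < in.size(); ++i)
      offset = offset * in[i] +
               (in[i] == 1 ? 0 : idx[static_cast<std::size_t>(axes[i])]);
    result[static_cast<std::size_t>(flat)] = data[static_cast<std::size_t>(offset)];
    // Advance the output multi-index in row-major order (last axis fastest).
    for (std::size_t k = out.size(); k-- > 0;) {
      if (++idx[k] < out[k]) break;
      idx[k] = 0;
    }
  }
  return result;
}
```

For the Nx example, `broadcast_in_dim_ref<int64_t>({1, 2, 3}, {3}, {2, 3, 2}, {1})` returns `1 1 2 2 3 3 1 1 2 2 3 3`, the row-major flattening of the output shown above.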

The custom broadcasting rules come in handy in some of the autograd operations. We can mimic the behaviour using broadcast + transpose, but I wonder if there is a preferred way to achieve this (and whether we might be missing some optimizations by going the broadcast + transpose route).

Thank you!

The only cost I can think of is the overhead of the extra operations (e.g. a transpose [or unsqueeze], then the normal broadcast, then a transpose back).
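One way to keep those extra operations cheap is to permute the input once so that its axes appear in the same order as their targets, insert size-1 axes, and then expand; with this ordering no transpose back is needed. The sketch below only computes the arguments, in plain C++; `BroadcastPlan` and `plan_broadcast_in_dim` are hypothetical helpers, not libtorch API. The intended libtorch call would be `t.permute(plan.permutation).reshape(plan.view_shape).expand(out_shape)`. As far as I know, `permute`, a `reshape` that only inserts size-1 dimensions, and `expand` all return views without copying data, so the overhead should be mostly bookkeeping, but that is worth confirming with a profiler.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper (not part of libtorch): arguments for emulating
// BroadcastInDim as t.permute(permutation).reshape(view_shape).expand(out).
struct BroadcastPlan {
  Shape permutation;  // input axes ordered by the output axis they map to
  Shape view_shape;   // output rank; size 1 wherever no input axis lands
};

BroadcastPlan plan_broadcast_in_dim(const Shape& in, const Shape& out, const Shape& axes) {
  if (axes.size() != in.size())
    throw std::invalid_argument("need exactly one target axis per input axis");
  BroadcastPlan plan{Shape(in.size()), Shape(out.size(), 1)};
  std::vector<bool> used(out.size(), false);
  for (std::size_t i = 0; i < in.size(); ++i) {
    const int64_t a = axes[i];
    if (a < 0 || a >= static_cast<int64_t>(out.size()))
      throw std::invalid_argument("target axis out of range");
    const auto ua = static_cast<std::size_t>(a);
    if (used[ua]) throw std::invalid_argument("target axes must be distinct");
    if (in[i] != 1 && in[i] != out[ua])
      throw std::invalid_argument("input dim must be 1 or equal to the target dim");
    used[ua] = true;
    plan.view_shape[ua] = in[i];  // this input dim lands on output axis a
  }
  // Sort input axes by target axis; this is the identity when axes are increasing,
  // in which case the permute is a no-op and only reshape + expand remain.
  std::iota(plan.permutation.begin(), plan.permutation.end(), int64_t{0});
  std::sort(plan.permutation.begin(), plan.permutation.end(),
            [&](int64_t x, int64_t y) {
              return axes[static_cast<std::size_t>(x)] < axes[static_cast<std::size_t>(y)];
            });
  return plan;
}
```

For the Nx example, the plan is `permutation = {0}` and `view_shape = {1, 3, 1}`, i.e. `t.reshape({1, 3, 1}).expand({2, 3, 2})`, which matches `t.unsqueeze(0).unsqueeze(2).expand({2, 3, 2})`. A non-identity permutation only appears when `axes` is not increasing.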
