About the batch dimension in backpropagation of matmul

How does backpropagation work in the batched matmul case? Assume we have a linear layer, where the shape of x is [b, m, k] and the shape of W is [k, n]:

$$
y = xW
$$

then:

$$
\frac{\partial L}{\partial W} = x^T\frac{\partial L}{\partial y}
$$
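
Since $W$ is shared by every batch element, the per-sample contributions add up over the batch dimension. Writing $x_i$ and $(\partial L/\partial y)_i$ for the $i$-th batch slice, the formula above unrolls to:

$$
\frac{\partial L}{\partial W} = \sum_{i=1}^{b} x_i^T \left(\frac{\partial L}{\partial y}\right)_i
$$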

So what happens in the backward pass for W?

  1. We save the full $x^T$ tensor [b,k,m], execute the batched matmul [b,k,m] x [b,m,n], and finally reduce (sum over the batch dimension) the result to [k,n] to get the gradient.
  2. We directly save a pre-reduced (averaged) $x^T$ tensor [k,m], and execute a single matmul [k,m] x [m,n] to get the gradient.

We do (1); the second one doesn't compute the same quantity, since summing over the batch before the matmul is not equivalent to summing after it. It would be nice to reduce ahead of time to save memory, though.
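
A minimal sketch of the difference, assuming PyTorch (the exact reduction used in option (2) is my interpretation of "averaged $x^T$"): option (1) matches autograd's `W.grad`, option (2) does not.

```python
import torch

b, m, k, n = 4, 3, 5, 2
x = torch.randn(b, m, k, requires_grad=True)
W = torch.randn(k, n, requires_grad=True)

y = x @ W                      # forward: [b, m, k] @ [k, n] -> [b, m, n]
g = torch.randn(b, m, n)       # stand-in for the upstream gradient dL/dy
y.backward(g)                  # autograd fills W.grad with the reference gradient

with torch.no_grad():
    # Option (1): batched matmul [b, k, m] x [b, m, n], then reduce over the batch.
    grad1 = (x.transpose(1, 2) @ g).sum(dim=0)       # [k, n]

    # Option (2): reduce x (and, here, g) over the batch first, then one matmul.
    grad2 = x.mean(dim=0).T @ g.sum(dim=0)           # [k, n], but a different quantity

print(torch.allclose(grad1, W.grad))   # True
print(torch.allclose(grad2, W.grad))   # False in general
```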