How does backprop work in the batched matmul case? Assume we have a linear layer where the shape of x is [b,m,k] and the shape of W is [k,n]:
$$
y = xW
$$
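For concreteness, a minimal sketch of the forward shapes (PyTorch is my assumption here; the question names no framework):

```python
import torch

b, m, k, n = 4, 8, 16, 32
x = torch.randn(b, m, k)   # batched input, [b,m,k]
W = torch.randn(k, n)      # shared weight, no batch dimension

# matmul broadcasts W across the batch: [b,m,k] @ [k,n] -> [b,m,n]
y = x @ W
print(y.shape)             # torch.Size([4, 8, 32])
```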
then:
$$
\frac{\partial L}{\partial W} = x^T\frac{\partial L}{\partial y}
$$
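Written out per sample, the batched version of this formula needs a reduction over the batch dimension, since the same W is shared by every sample:

$$
\frac{\partial L}{\partial W} = \sum_{i=1}^{b} x_i^T \frac{\partial L}{\partial y_i}
$$

A sketch of this reduction (again assuming PyTorch; `grad_y` stands for $\frac{\partial L}{\partial y}$):

```python
import torch

b, m, k, n = 4, 8, 16, 32
x = torch.randn(b, m, k)
grad_y = torch.randn(b, m, n)   # upstream gradient dL/dy, [b,m,n]

# per-sample x_i^T @ grad_y_i, summed over the batch:
# [b,k,m] @ [b,m,n] -> [b,k,n] -> [k,n]
grad_W = (x.transpose(1, 2) @ grad_y).sum(dim=0)

# same contraction in one einsum
grad_W_einsum = torch.einsum('bmk,bmn->kn', x, grad_y)
print(torch.allclose(grad_W, grad_W_einsum))   # True
```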
So what happens in the backward pass for W?
- We save the full $x^T$ tensor of shape [b,k,m], execute the batched matmul [b,k,m] x [b,m,n], and finally reduce the result over the batch dimension to [k,n] to get the gradient.
- We directly save the averaged $x^T$ tensor of shape [k,m], and execute the plain matmul [k,m] x [m,n] to get the gradient (a comparison sketch of the two follows this list).
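A small check of the two candidates against what autograd actually produces (assuming PyTorch; the exact reduction in the "averaged" variant is my guess, since the question leaves the [m,n] factor unspecified):

```python
import torch

b, m, k, n = 4, 8, 16, 32
x = torch.randn(b, m, k)
W = torch.randn(k, n, requires_grad=True)

y = x @ W                       # [b,m,n]
grad_y = torch.randn(b, m, n)   # arbitrary upstream gradient
y.backward(grad_y)

# Candidate 1: batched matmul [b,k,m] @ [b,m,n], then reduce over the batch
cand1 = (x.transpose(1, 2) @ grad_y).sum(dim=0)

# Candidate 2 (hypothetical reading): reduce x and grad_y first, then one
# plain matmul [k,m] @ [m,n]; the mean/sum placement is an assumption
cand2 = x.mean(dim=0).T @ grad_y.sum(dim=0)

print(torch.allclose(W.grad, cand1))   # True
print(torch.allclose(W.grad, cand2))   # generally False
```

At least in PyTorch, it is the first behavior: the full $x$ is saved for backward, and the batch reduction happens after the batched product, because a reduction of per-sample products is not, in general, equal to a product of reduced factors.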