Yes, that’s a good point. The toy example is only meant to show how the syntax works, since the existing examples mostly cover simpler functions such as ReLU and Exp.
BTW, I think the original formula is `args[0].t() @ grad_output`, so that the resulting shape is correct: (k × m) @ (m × n) → (k × n).
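For reference, here is a minimal sketch of a custom `torch.autograd.Function` for Y = A @ B, with A of shape m × k and B of shape k × n. It is not the toy example from this thread. The class name `MatMul` and the concrete sizes are hypothetical, and it assumes `args[0]` is the left operand A, so `A.t() @ grad_output` is the gradient with respect to B. `torch.autograd.gradcheck` then compares the hand-written backward against finite differences.

```python
import torch


class MatMul(torch.autograd.Function):
    """Hypothetical custom Function computing Y = A @ B."""

    @staticmethod
    def forward(ctx, a, b):
        # a: (m, k), b: (k, n) -> output: (m, n)
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output has the same shape as the output: (m, n)
        a, b = ctx.saved_tensors
        grad_a = grad_output @ b.t()  # (m, n) @ (n, k) -> (m, k), same shape as a
        grad_b = a.t() @ grad_output  # (k, m) @ (m, n) -> (k, n), same shape as b
        return grad_a, grad_b


m, k, n = 3, 4, 5
# gradcheck needs double precision to keep finite-difference error small.
a = torch.randn(m, k, dtype=torch.double, requires_grad=True)
b = torch.randn(k, n, dtype=torch.double, requires_grad=True)

# Returns True if the analytic and numerical gradients agree (raises otherwise).
print(torch.autograd.gradcheck(MatMul.apply, (a, b)))
```

If the transpose were placed on the wrong side (e.g. `grad_output @ a.t()`), the shapes would not line up with B and the backward would fail, which is exactly the shape argument above.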