Zeroing one input's gradient for Matrix Multiply

Yes, that’s a good point. The toy example is just meant to show how the syntax works, since the existing examples typically cover simpler functions such as ReLU and Exp.

BTW, I think the original formula should be args[0].t() @ grad_output, so that the resulting shape is correct: (k × m) @ (m × n) → (k × n), which matches the shape of args[1].
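For reference, here is a minimal sketch of what such a custom Function might look like, assuming the usual torch.autograd.Function API. The class name MatMulZeroGradA is hypothetical, and the choice of which input's gradient gets zeroed is an assumption based on the thread title:

```python
import torch


class MatMulZeroGradA(torch.autograd.Function):
    """Matrix multiply that returns a zero gradient for its first input.

    Shapes (assumed): a is (m, k), b is (k, n), output is (m, n).
    """

    @staticmethod
    def forward(ctx, a, b):
        # Save both inputs so backward can compute the gradient for b.
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # Deliberately zero the gradient for the first input.
        # (Returning None instead would tell autograd there is no
        # gradient at all for this input.)
        grad_a = torch.zeros_like(a)
        # Gradient w.r.t. the second input, per the formula above:
        # a.t() @ grad_output has shape (k, m) @ (m, n) -> (k, n),
        # matching b.
        grad_b = a.t() @ grad_output
        return grad_a, grad_b


# Usage sketch:
a = torch.randn(3, 4, requires_grad=True)  # m=3, k=4
b = torch.randn(4, 5, requires_grad=True)  # k=4, n=5
out = MatMulZeroGradA.apply(a, b)          # shape (3, 5)
out.sum().backward()
# a.grad is all zeros; b.grad equals a.t() @ torch.ones(3, 5)
```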