Zeroing one input's gradient for Matrix Multiply

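In the simplest case, you can just do something like x.detach() @ weights. The detached x is cut out of the autograd graph, so no gradient flows back to x, while weights still gets its gradient as usual.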
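A minimal sketch of this in PyTorch, assuming both tensors are leaf tensors created with requires_grad=True (the shapes and the sum() loss are arbitrary illustrative choices). It also shows a different technique, a gradient hook via Tensor.register_hook, for the case where you want x's gradient to exist but be all zeros rather than None:

```python
import torch

torch.manual_seed(0)

# Arbitrary illustrative shapes.
x = torch.randn(4, 3, requires_grad=True)
weights = torch.randn(3, 2, requires_grad=True)

# Option 1: detach x so the matmul only records a gradient path through weights.
out = x.detach() @ weights
out.sum().backward()
print(x.grad)        # None: nothing flowed back to x
print(weights.grad)  # equals x.T @ ones(4, 2) for this sum() loss

# Option 2 (a different technique): keep x2 in the graph, but register a hook
# that replaces its incoming gradient with zeros, so x2.grad exists and is zero.
x2 = torch.randn(4, 3, requires_grad=True)
w2 = torch.randn(3, 2, requires_grad=True)
x2.register_hook(lambda grad: torch.zeros_like(grad))
(x2 @ w2).sum().backward()
print(x2.grad)       # all zeros, shape (4, 3)
```

Detaching is usually the cheaper choice, since autograd never computes the gradient for x at all; the hook version still computes it and then throws it away.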
