PyTorch equivalent of tf.matrix_set_diag?

Is there a PyTorch operation similar to tf.matrix_set_diag, which returns a batched matrix tensor whose diagonals are replaced by new batched diagonal values? Specifically, I need to efficiently zero out the diagonal of each (N, N) block in a (B, H, N, N) tensor.
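A minimal sketch of three ways to do this, assuming a reasonably recent PyTorch (`torch.diagonal_scatter` was added in 1.11). The helper names `zero_diagonal_` and `set_diagonal` are hypothetical, introduced only for illustration. The key point is that `Tensor.diagonal(dim1=-2, dim2=-1)` returns a (B, H, N) view of the diagonals of the last two dimensions, so writing into that view modifies the original tensor without any Python loop.

```python
import torch

def zero_diagonal_(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero the diagonal of every (N, N) block of x in place."""
    # diagonal() over the last two dims is a (..., N) view into x's storage,
    # so zero_() on it writes directly into x.
    x.diagonal(dim1=-2, dim2=-1).zero_()
    return x

def set_diagonal(x: torch.Tensor, diag: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper, closest analogue of tf.matrix_set_diag:
    returns a copy of x whose last-two-dim diagonals are replaced by diag (shape (..., N))."""
    # Assumes PyTorch >= 1.11, where torch.diagonal_scatter is available.
    return torch.diagonal_scatter(x, diag, dim1=-2, dim2=-1)

B, H, N = 2, 3, 4
x = torch.randn(B, H, N, N)

# Out-of-place: scatter a batch of zero diagonals into a copy of x.
y = set_diagonal(x, torch.zeros(B, H, N, dtype=x.dtype))

# Out-of-place alternative that also works on older versions:
# an (N, N) boolean identity mask broadcasts over the (B, H) batch dims.
mask = torch.eye(N, dtype=torch.bool, device=x.device)
z = x.masked_fill(mask, 0.0)

# In-place: cheapest option, no extra (B, H, N, N) allocation.
zero_diagonal_(x)
```

The in-place version is the most memory-efficient, but autograd rejects in-place operations on a leaf tensor that requires grad; in that case the `masked_fill` or `diagonal_scatter` variants are the safer choice.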