Hi,
What is the grad input mask that is passed to the layer norm kernel call (pytorch/layer_norm_kernel.cu at master · pytorch/pytorch · GitHub)? I notice that it invokes a number of kernels, such as LayerNormBackward, whenever it is set. Can someone tell me what it actually means and when it is used or not used?
Thanks.