Illegal memory access when I use GroupNorm

I don’t understand why this small piece of code produces an illegal memory access:

import torch
import torch.nn as nn

class PhyCell_Cell(nn.Module):
    def __init__(self, ch_in, hidden_dim, ks=3, bias=True, group=True):
        super().__init__()  # must run before any submodule is assigned
        padding = ks // 2
        self.f = nn.Sequential(
                 nn.GroupNorm(4, ch_in) if group else nn.BatchNorm2d(ch_in),   
                 nn.Conv2d(ch_in, hidden_dim, ks, padding=padding),
                 nn.Conv2d(hidden_dim, ch_in, kernel_size=(1,1)))

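        # NOTE: this line was cut off in the post after `2*ch_in,`; the remaining arguments are an assumption
        self.convgate = nn.Conv2d(2*ch_in, ch_in, ks, padding=padding, bias=bias)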

    def forward(self, x, hidden=None): 
        "x ~ [batch_size, ch_in, height, width]"
        if hidden is None: hidden = self.init_hidden(x)
        hidden_tilde = hidden + self.f(hidden)
        combined = torch.cat([x, hidden_tilde], dim=1)  # concatenate along the channel dimension
        combined_conv = self.convgate(combined)
        K = torch.sigmoid(combined_conv)
        next_hidden = hidden_tilde + K * (x - hidden_tilde)
        return next_hidden
    def init_hidden(self, x):
        bs, ch, h, w = x.shape
        # one_param is assumed to be fastai's helper that returns the module's first parameter
        return one_param(self).new_zeros(bs, ch, h, w)

Small reproducible example on Colab:
If I replace GroupNorm with BatchNorm (which is what I am doing for now), it works.
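
For anyone who wants to narrow this down locally, here is a minimal sketch that runs only the norm + conv stack from `self.f` with both normalization layers, on CPU and (if available) on CUDA. The helper name `run_norm_conv` and the input size are my own assumptions, not part of the original code. CUDA kernels run asynchronously, so the sketch synchronizes after the backward pass to make any error show up at that point; setting the environment variable `CUDA_LAUNCH_BLOCKING=1` before starting Python should also make the stack trace point at the failing kernel.

```python
import torch
import torch.nn as nn


def run_norm_conv(group: bool, device: str, ch_in: int = 8, hidden_dim: int = 16) -> torch.Size:
    """Run a norm + conv stack like `self.f` and return the output shape."""
    norm = nn.GroupNorm(4, ch_in) if group else nn.BatchNorm2d(ch_in)
    f = nn.Sequential(
        norm,
        nn.Conv2d(ch_in, hidden_dim, 3, padding=1),
        nn.Conv2d(hidden_dim, ch_in, kernel_size=(1, 1)),
    ).to(device)
    # Zero input, like the initial hidden state; the size is a hypothetical choice
    x = torch.zeros(2, ch_in, 16, 16, device=device)
    out = f(x)
    out.sum().backward()  # exercise the backward kernels as well
    if device == "cuda":
        # Kernels are launched asynchronously; synchronizing surfaces errors here
        torch.cuda.synchronize()
    return out.shape


devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    for group in (True, False):
        name = "GroupNorm" if group else "BatchNorm2d"
        print(device, name, run_norm_conv(group, device))
```

If the CPU runs pass and only the CUDA GroupNorm run fails, that points at the GPU kernel rather than at the model code.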

Thanks for reporting this issue and providing the code snippet!
We were able to reproduce it, and Xiao fixed it in this PR.
The fix should therefore be available in the next nightly build.
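
To check which build you are actually running after updating, something like the sketch below may help. The helper name `describe_build` is hypothetical, and treating a `.dev` substring in the version string as a sign of a nightly build is an assumption based on the usual naming of nightly wheels.

```python
import torch


def describe_build() -> dict:
    """Collect version details that are useful to include in a bug report."""
    return {
        "torch": str(torch.__version__),
        # Assumption: nightly wheels carry a ".dev" marker in their version string
        "is_dev_build": ".dev" in torch.__version__,
        "cuda_runtime": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


print(describe_build())
```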

Some older versions of PyTorch seem to work.