I don't understand why this small piece of code triggers an "illegal memory access" error:
class PhyCell_Cell(nn.Module):
    """Single PhyCell recurrent cell operating on 4-D feature maps.

    ``self.f`` (a normalization layer, a ``ks x ks`` conv into ``hidden_dim``
    channels, then a 1x1 conv back to ``ch_in`` channels) predicts an update
    that is added residually to the previous hidden state.  The result is
    then blended with the input ``x`` through a sigmoid gate computed by
    ``self.convgate`` from the channel-wise concatenation of both tensors.

    Args:
        ch_in: Number of channels of the input ``x`` and of the hidden state.
        hidden_dim: Number of channels of the intermediate conv in ``self.f``.
        ks: Kernel size of the intermediate conv.  Must be odd: with
            ``padding=ks // 2`` only an odd kernel preserves the spatial
            size needed for the residual addition in ``forward``.
        bias: Whether ``self.convgate`` has a bias term.  The convs inside
            ``self.f`` always keep PyTorch's default bias.
        group: If True normalize with ``nn.GroupNorm``, otherwise with
            ``nn.BatchNorm2d``.
        num_groups: Number of groups for ``nn.GroupNorm`` (default 4, the
            previously hard-coded value).  Ignored when ``group`` is False.

    Raises:
        ValueError: If ``ks`` is even, or if ``group`` is True and ``ch_in``
            is not divisible by ``num_groups``.
    """

    def __init__(self, ch_in, hidden_dim, ks=3, bias=True, group=True, num_groups=4):
        super().__init__()
        # Fail early with a clear message instead of a shape mismatch (or an
        # obscure kernel error) deep inside forward().
        if ks % 2 == 0:
            raise ValueError(f"ks must be odd to preserve spatial size, got {ks}")
        if group and ch_in % num_groups != 0:
            raise ValueError(
                f"ch_in ({ch_in}) must be divisible by num_groups ({num_groups}) "
                "when group=True"
            )
        padding = ks // 2
        norm = nn.GroupNorm(num_groups, ch_in) if group else nn.BatchNorm2d(ch_in)
        self.f = nn.Sequential(
            norm,
            nn.Conv2d(ch_in, hidden_dim, ks, padding=padding),
            nn.Conv2d(hidden_dim, ch_in, kernel_size=(1, 1)),
        )
        # Gate input is cat([x, hidden_tilde]) -> 2 * ch_in channels.
        self.convgate = nn.Conv2d(
            2 * ch_in,
            ch_in,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=bias,
        )

    def forward(self, x, hidden=None):
        """Advance the cell by one step.

        Args:
            x: Input tensor of shape ``[batch_size, ch_in, height, width]``.
            hidden: Previous hidden state with the same shape as ``x``;
                an all-zero state is used when None.

        Returns:
            The next hidden state, with the same shape as ``x``.
        """
        if hidden is None:
            hidden = self.init_hidden(x)
        hidden_tilde = hidden + self.f(hidden)
        combined = torch.cat([x, hidden_tilde], dim=1)
        K = torch.sigmoid(self.convgate(combined))
        # Convex blend: K -> 1 favors the input, K -> 0 keeps hidden_tilde.
        return hidden_tilde + K * (x - hidden_tilde)

    def init_hidden(self, x):
        """Return an all-zero hidden state with ``x``'s shape, dtype and device.

        ``x`` is used as the template (rather than a module parameter)
        because the hidden state is concatenated with and subtracted from
        ``x`` in ``forward``, so it must match ``x`` exactly.
        """
        bs, ch, h, w = x.shape  # enforce a 4-D input
        return x.new_zeros(bs, ch, h, w)
Small reproducible example on Colab:
https://colab.research.google.com/drive/1YCEbnDHLWnZ49kGfg4FX0EBQ2RTbGaat?usp=sharing
If I replace GroupNorm with BatchNorm (which is what I am currently doing), it works.