I would have assumed you might be running into this issue.
However, I don’t see this behavior in your code.
Could you replace the in-place skip connections h += poolX with their out-of-place versions h = h + poolX and check whether this solves the issue?
I guess h is needed for the gradient calculation in some layers, which will break if you modify it in-place.
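
To illustrate the suspected failure mode, here is a minimal sketch that is not taken from your model. The function name run_skip and the use of torch.sigmoid are assumptions chosen only because sigmoid stores its output for the backward pass, so modifying that output in-place should invalidate it for autograd.

```python
import torch

def run_skip(inplace: bool) -> str:
    """Hypothetical toy example, not the original model."""
    x = torch.randn(4, requires_grad=True)
    # sigmoid saves its output h, since the backward pass reuses it
    h = torch.sigmoid(x)
    pool = torch.ones(4)  # stand-in for poolX
    if inplace:
        h += pool      # overwrites the tensor autograd saved
    else:
        h = h + pool   # creates a new tensor, saved h stays intact
    try:
        h.sum().backward()
        return "ok"
    except RuntimeError:
        # autograd detects that a saved tensor was modified in-place
        return "error"

print("in-place:    ", run_skip(True))
print("out-of-place:", run_skip(False))
```

If the in-place variant fails and the out-of-place one succeeds here, the same change in your forward method is worth trying. torch.autograd.set_detect_anomaly(True) can also help to locate the forward operation whose saved tensor was modified.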