Are you using view() or reshape() to get the output into the required shape? If so, you might want to look at this post. In a nutshell, using view() or reshape() "carelessly" can scramble your output, which will very likely keep your network from learning properly.
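
Here is a small sketch of what I mean. I don't know your actual shapes, so the (seq_len, batch, features) layout and the variable names below are just assumptions for illustration. The point is that view() only reinterprets the underlying memory, while permute() actually swaps the axes:

```python
import torch

# Hypothetical output of shape (seq_len, batch, features); the sizes are made up
seq_len, batch, features = 3, 2, 4
out = torch.arange(seq_len * batch * features).reshape(seq_len, batch, features)

# Careless: view() gives the requested shape but mixes rows from different samples
wrong = out.view(batch, seq_len, features)

# Safer: permute() reorders the axes, so each sample keeps its own time steps
right = out.permute(1, 0, 2)

same_shape = wrong.shape == right.shape
same_values = torch.equal(wrong, right)
print(same_shape, same_values)  # prints: True False
```

Both tensors have the shape you asked for, so nothing errors out, but only `right[0]` contains the time steps of sample 0. That is why this kind of bug tends to show up as a network that simply doesn't learn.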