I had to drop the reassignment to avoid the error in my further experiments. I am fairly sure the error happens at the following line:
```python
self.W_hat.div_(torch.norm(self.W_hat, dim=5, keepdim=True))
```
I will try to write a small snippet for reproduction in the coming weeks; I only have one machine with a GPU and it is busy at the moment. Sharing `a` and `x` would not help you, since there are several steps before I use `w_theta_sin`; I need to simplify it first to make it easier for you to trace.
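In the meantime, here is a rough sketch of the pattern I believe triggers it, assuming `W_hat` is an `nn.Parameter` with `requires_grad=True` and the error is the usual in-place-on-leaf autograd `RuntimeError` (the module and dimensions here are illustrative, not my real code):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Hypothetical stand-in; the real W_hat has more dimensions,
    but the in-place normalization pattern is the same."""
    def __init__(self):
        super().__init__()
        self.W_hat = nn.Parameter(torch.randn(4, 4))

    def renorm(self):
        # In-place division on a leaf parameter that requires grad.
        self.W_hat.div_(torch.norm(self.W_hat, dim=1, keepdim=True))

m = Toy()
try:
    m.renorm()
except RuntimeError as e:
    # Expected (if my assumption is right):
    # "a leaf Variable that requires grad is being used in an in-place operation."
    print(e)

# Wrapping the update in no_grad avoids the error, though I am not sure
# it is the behaviour I actually want during training:
with torch.no_grad():
    m.W_hat.div_(torch.norm(m.W_hat, dim=1, keepdim=True))
```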
I think this issue is also related, but I am on PyTorch 1.1.0: