Yeah, that’s it! Thanks a lot for checking it out.
I experimented with different settings of cudnn and found that calling
torch.backends.cudnn.deterministic = True
was sufficient to solve the issue.
Some additional information on runtime per batch for future readers (settings ii and iii both solve the issue):
i: default settings (i.e. non-deterministic)
------> 0.51s
ii: torch.backends.cudnn.enabled = False
------> 0.14s
iii: torch.backends.cudnn.deterministic=True
------> 0.002s
Note the speed-up for this model. The reasons why deterministic algorithms can be faster in some cases were discussed here.
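For anyone reproducing these numbers, a minimal timing sketch is below. The `time_per_batch` helper and its parameters are my own illustration (not from the original benchmark); the cuDNN flags are the real `torch.backends.cudnn` settings discussed above. The `try`/`except` guard is only there so the sketch also runs without PyTorch installed.

```python
import time

def time_per_batch(step, n_batches=10, warmup=2):
    """Average wall-clock time per call to `step` (a zero-arg callable).

    A few warm-up iterations are run first, since cuDNN selects kernels
    lazily and the first calls are not representative.
    """
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(n_batches):
        step()
    return (time.perf_counter() - start) / n_batches

# Applying it to the settings above (hypothetical model/batch):
try:
    import torch
    torch.backends.cudnn.deterministic = True   # setting (iii)
    # torch.backends.cudnn.enabled = False      # setting (ii), alternative
    # When timing CUDA code, call torch.cuda.synchronize() inside `step`
    # before returning, so asynchronous kernels are included in the timing.
except ImportError:
    pass
```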