Blown up gradients and loss

I am not sure about your code, but the first thing I would do is move optimizer.zero_grad() to after the point where the data is moved to the device.
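A minimal sketch of that ordering inside one training step. The toy model, data, loss and learning rate are hypothetical placeholders, not taken from your code. The requirement PyTorch itself enforces is only that the gradients are cleared before loss.backward() on each iteration. This sketch simply puts zero_grad() right after the .to(device) calls:

```python
import torch
import torch.nn as nn

# Hypothetical toy setup, only to illustrate the order of calls.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

def train_step(inputs, targets):
    # Move the batch to the device first...
    inputs, targets = inputs.to(device), targets.to(device)
    # ...then clear the gradients accumulated in the previous step.
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
losses = [train_step(x, y) for _ in range(5)]
print(losses)
```

If the printed losses grow or turn into nan/inf here instead of shrinking, the problem is more likely the learning rate or the data than the call order.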

Although it might have nothing to do with your problem, try using lr_scheduler to decay the learning rate, something like:

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_milestones, gamma=lr_gamma)

Then replace the if (epoch==170) block with a single call, made once per epoch after the optimizer has been stepped:
scheduler.step()
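A sketch of the whole epoch loop with MultiStepLR. The values are assumptions: lr_milestones=[170] mirrors the original if (epoch==170) check, and gamma=0.1 (divide the learning rate by 10) is a hypothetical choice. The per-batch training is reduced to a bare optimizer.step() placeholder:

```python
import torch
import torch.nn as nn

# Hypothetical values: milestone at epoch 170, learning rate divided by 10 there.
lr_milestones = [170]
lr_gamma = 0.1
num_epochs = 200

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_milestones, gamma=lr_gamma)

lr_history = []
for epoch in range(num_epochs):
    # ... the per-batch loop (zero_grad, forward, backward, optimizer.step()) goes here ...
    optimizer.step()  # placeholder so the optimizer is stepped before the scheduler
    # Record the learning rate that was used during this epoch.
    lr_history.append(optimizer.param_groups[0]["lr"])
    # Advance the schedule once per epoch, after the optimizer updates.
    scheduler.step()

print(lr_history[169], lr_history[170])
```

Since PyTorch 1.1, scheduler.step() is expected to come after optimizer.step(). The milestones list can hold several epochs (e.g. [100, 150]) if you want more than one decay step.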