I have two tensors, a_angle and b_angle, each containing values from 0.0 to 1.0 that represent angles from 0 to 360 degrees.

I want to prevent the distance between two angles from becoming larger than the actual angular distance because of the discontinuity at 0 degrees.
To this end, when the shorter way between the two angles crosses 0, I subtract 1 (360 degrees) from the larger angle.
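To make the rule concrete, here is a minimal worked example for a single pair of plain Python floats. shift_across_zero is a hypothetical helper written only for illustration; it assumes both values lie in [0, 1], in which case b - a > 0.5 already implies b > 0.5 > a, so the difference test alone is enough.

```python
def shift_across_zero(a, b):
    """Hypothetical helper: apply the shift rule to one pair of angles.

    a and b are fractions of a full turn, assumed to lie in [0, 1].
    """
    if b - a > 0.5:    # the short way from a to b crosses 0: move b down one turn
        b -= 1.0
    elif a - b > 0.5:  # the short way from b to a crosses 0: move a down one turn
        a -= 1.0
    return a, b


a, b = 0.25, 0.875          # 90 and 315 degrees
print(abs(b - a) * 360)     # naive difference: 225.0 degrees
a2, b2 = shift_across_zero(a, b)
print(abs(b2 - a2) * 360)   # after shifting b to -0.125: 135.0 degrees
```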
I have the following code, but it is very slow because it loops over the elements one at a time in Python.
Is there any way to do this more efficiently? Perhaps I should use boolean masks or torch.where, but I'm not sure how to apply them in my use case.

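import torch

for i in range(len(a_angle)):
    # b is more than half a turn above a: move b down by one full turn
    if b_angle[i] > 0.5 and 0.5 > a_angle[i] and b_angle[i] - a_angle[i] > 0.5:
        b_angle[i] = torch.add(b_angle[i], -1)
    # a is more than half a turn above b: move a down by one full turn
    elif a_angle[i] > 0.5 and 0.5 > b_angle[i] and a_angle[i] - b_angle[i] > 0.5:
        a_angle[i] = torch.add(a_angle[i], -1)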
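A possible vectorized approach, sketched under the assumption that a_angle and b_angle are 1-D float tensors of the same shape: build each branch condition of the loop as a boolean mask, then let torch.where choose the shifted value element-wise. The function name unwrap_angles is hypothetical, and this version returns new tensors rather than modifying the inputs in place.

```python
import torch


def unwrap_angles(a_angle: torch.Tensor, b_angle: torch.Tensor):
    """Hypothetical vectorized equivalent of the loop above.

    Angles are fractions of a full turn in [0, 1]. Where the two angles
    differ by more than half a turn, 1 is subtracted from the larger one.
    """
    # One boolean mask per branch of the loop; both are computed from the
    # original inputs, just as each loop iteration only touches index i.
    b_wraps = (b_angle > 0.5) & (a_angle < 0.5) & (b_angle - a_angle > 0.5)
    a_wraps = (a_angle > 0.5) & (b_angle < 0.5) & (a_angle - b_angle > 0.5)
    # torch.where takes the shifted value only where the mask is True.
    new_a = torch.where(a_wraps, a_angle - 1, a_angle)
    new_b = torch.where(b_wraps, b_angle - 1, b_angle)
    return new_a, new_b


a = torch.tensor([0.25, 0.875, 0.25, 0.5])
b = torch.tensor([0.875, 0.25, 0.5, 0.5])
new_a, new_b = unwrap_angles(a, b)
print(new_a)  # only element 1 is shifted (0.875 -> -0.125)
print(new_b)  # only element 0 is shifted (0.875 -> -0.125)
```

If modifying the tensors in place is preferred, the same masks can be used with index assignment, e.g. b_angle[b_wraps] -= 1 and a_angle[a_wraps] -= 1, as long as both masks are computed before either tensor is changed. With inputs in [0, 1], b_angle - a_angle > 0.5 already implies b_angle > 0.5 and a_angle < 0.5, so each mask could be reduced to the difference test; the sketch keeps all three conditions to mirror the loop. If only the distance itself is needed, torch.remainder(b_angle - a_angle, 1.0) followed by torch.minimum(d, 1 - d) gives the shortest angular distance directly, without shifting either angle.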