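Your code looks correct, but you might want to divide the accumulated loss by the number of accumulation steps. Otherwise the accumulated gradient is the sum over the accumulation steps rather than the average, so its magnitude grows with the number of steps. Also, here is a nice overview of different approaches in case you want to trade compute for memory.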
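To illustrate why the division matters, here is a minimal, framework-agnostic sketch in plain Python. It assumes a toy one-parameter linear model with a mean-squared-error loss and equal-sized micro-batches; the helper names (`mse_grad`, `accumulated_grad`) and the data are made up for this example and are not taken from your code. Because differentiation is linear, dividing each micro-batch gradient by the number of steps has the same effect as dividing the loss before the backward pass.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, accumulation_steps, scale=True):
    """Sum gradients over equal-sized micro-batches (hypothetical helper).

    With scale=True each micro-batch gradient is divided by
    accumulation_steps, which is equivalent to dividing the loss
    by accumulation_steps before differentiating.
    """
    micro = len(xs) // accumulation_steps
    total = 0.0
    for step in range(accumulation_steps):
        lo, hi = step * micro, (step + 1) * micro
        g = mse_grad(w, xs[lo:hi], ys[lo:hi])
        if scale:
            g /= accumulation_steps
        total += g
    return total


# Toy data, chosen only for illustration.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

full = mse_grad(w, xs, ys)  # gradient on the full batch
scaled = accumulated_grad(w, xs, ys, accumulation_steps=2)
unscaled = accumulated_grad(w, xs, ys, accumulation_steps=2, scale=False)

print(full, scaled, unscaled)
```

With the scaling, the accumulated gradient matches the full-batch gradient; without it, the result is larger by a factor of `accumulation_steps`, which acts like an unintended increase in the learning rate.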