Changing where optim.step() writes to

Hello!

I am working on a niche project that requires storing the network weights for each of k iterations. It is easy enough to read the weights at a given iteration in PyTorch; however, there is an issue. It seems that optim.step() updates the parameter tensors in place, so an expensive copy has to be performed at each of the k iterations to preserve a snapshot. This drastically increases the runtime of training.
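For context, this is roughly what I am doing at the moment (the model, data, and hyperparameters are just placeholders for my real setup):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# stand-ins for my real network and data
model = torch.nn.Linear(128, 10).to(device)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

k = 100
history = []  # one snapshot of every parameter per iteration

for i in range(k):
    x = torch.randn(32, 128, device=device)
    loss = model(x).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
    # the expensive part: clones every parameter tensor, every iteration
    history.append([p.detach().clone() for p in model.parameters()])
```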

To cut down on this, I would like to pre-allocate a block of memory on the device of size k*|W| and have optim.step() update one chunk of it at each iteration. In other words, I would like to stride the output pointer of optim.step() along this block, so that each step writes its result into the next chunk of size |W| instead of overwriting the previous weights.
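To make the intent concrete, here is a rough sketch of the write pattern I am after, with a hand-rolled SGD-style update standing in for optim.step() (the single-parameter model, buffer layout, and hyperparameters are just illustrative assumptions):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# toy single-parameter "network"; my real model has many parameters
w = torch.nn.Parameter(torch.randn(128, 10, device=device))
lr, k = 1e-2, 100

# pre-allocated block of size (k+1) * |W|; row i holds the weights at iteration i
history = torch.empty(k + 1, w.numel(), device=device)
history[0].copy_(w.detach().view(-1))
w.data = history[0].view_as(w)  # the parameter now lives inside the block

for i in range(k):
    x = torch.randn(32, 128, device=device)
    loss = (x @ w).sum()
    loss.backward()

    with torch.no_grad():
        # "strided" update: read the old weights from row i and write the new
        # ones into row i+1, instead of updating in place like optim.step()
        torch.add(history[i], w.grad.view(-1), alpha=-lr, out=history[i + 1])

    w.grad = None
    w.data = history[i + 1].view_as(w)  # the next forward pass reads the new row
```

This works for a hand-rolled update, but it bypasses optim.step() entirely; I would rather keep using the built-in optimizers and only redirect where they write.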

Is something like this possible in PyTorch?

Thanks for taking the time to think about this!