Sorry for the late reply - been trying to wrestle this into submission without much success. GPU memory is absolutely getting destroyed, and unfortunately there doesn't seem to be a straightforward fix. Not sure how PyTorch handles memory allocation for the dynamic graph, to be honest - I've read a good amount on the forum about the CUDA caching allocator's behavior not being ideal, though.
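For anyone else hitting this: the caching-allocator knobs I keep seeing recommended on the forum go through the `PYTORCH_CUDA_ALLOC_CONF` environment variable. Rough sketch below - the specific value is just an illustration, not a recommendation, and `train.py` is a stand-in for whatever script is OOMing:

```shell
# PYTORCH_CUDA_ALLOC_CONF tunes PyTorch's CUDA caching allocator.
# max_split_size_mb caps the size of cached blocks the allocator will split,
# which can help with fragmentation; 128 here is purely illustrative.
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
echo "$PYTORCH_CUDA_ALLOC_CONF"

# train.py is a placeholder for the actual training script:
# python train.py
```

Whether this actually helps depends on the workload - worth checking `torch.cuda.memory_allocated()` vs `torch.cuda.memory_reserved()` before and after to tell real usage apart from cached-but-unreleased blocks.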
Will test the following flags mentioned in this thread today, and report back: