I assume you’ve implemented the VectorQuantizerEMA module?
Are you seeing the OOM directly in the first iteration, or does the memory usage grow over several iterations until it runs out?
In the former case, you might need to reduce the batch size or, if that's not possible, you could try to use torch.utils.checkpoint to trade compute for memory.
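As a rough sketch of what the checkpointing option could look like: the Block and Model classes below are made-up placeholders, not your actual model, and I'm assuming a PyTorch version recent enough to accept the use_reentrant argument. Activations inside the checkpointed blocks are not stored during the forward pass and are recomputed during backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical sub-network standing in for an expensive part of the model."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class Model(nn.Module):
    """Hypothetical model that optionally checkpoints its blocks."""

    def __init__(self, dim=16, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.block1 = Block(dim)
        self.block2 = Block(dim)

    def forward(self, x):
        if self.use_checkpoint:
            # Intermediate activations of each block are dropped after the
            # forward pass and recomputed when backward reaches the block.
            x = checkpoint(self.block1, x, use_reentrant=False)
            x = checkpoint(self.block2, x, use_reentrant=False)
        else:
            x = self.block1(x)
            x = self.block2(x)
        return x


model = Model()
out = model(torch.randn(8, 16))
out.sum().backward()  # gradients still reach all parameters
```

The savings depend on how much of the model sits inside the checkpointed calls, so it usually makes sense to wrap the largest blocks first.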
If the memory usage grows over time, one possible cause is that the computation graph is being kept alive across iterations. embed, cluster_size, and ema_embed are created as buffers, which registers the tensors without making them trainable (their requires_grad attribute is False).
However, in the forward method you are reassigning some values to these buffers.
Could you check whether their grad_fn is suddenly pointing to a valid function and, if so, use detach() on the assignments:
self.cluster_size = (
    self.decay * self.cluster_size
    + (1 - self.decay) * ref_counts
).detach()  # cut the link to the computation graph
# same for every other buffer assignment
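
To see the effect in isolation, here is a small self-contained sketch. ToyEMA, its detach_update flag, and the counts tensor are invented for illustration and are not taken from your VectorQuantizerEMA code. I'm assuming the value assigned to the buffer is computed from something that tracks gradients, which is what would attach a grad_fn to the buffer.

```python
import torch
import torch.nn as nn


class ToyEMA(nn.Module):
    """Hypothetical module with a single EMA buffer, for illustration only."""

    def __init__(self, num_codes=4, decay=0.99, detach_update=True):
        super().__init__()
        self.decay = decay
        self.detach_update = detach_update
        self.register_buffer("cluster_size", torch.zeros(num_codes))

    def forward(self, ref_counts):
        new_value = self.decay * self.cluster_size + (1 - self.decay) * ref_counts
        if self.detach_update:
            # Store only the values, not the history that produced them.
            new_value = new_value.detach()
        # Assigning a tensor to a registered buffer name keeps it a buffer.
        self.cluster_size = new_value
        return self.cluster_size


# Assumption: the counts depend on a tensor that requires gradients.
counts = torch.ones(4, requires_grad=True)

leaky = ToyEMA(detach_update=False)
fixed = ToyEMA(detach_update=True)
for _ in range(3):
    leaky(counts)
    fixed(counts)

print(leaky.cluster_size.grad_fn)  # a backward node: the graph is kept alive
print(fixed.cluster_size.grad_fn)  # None
```

Without the detach(), each iteration's buffer references the previous iteration's graph, so the chain keeps growing and none of it can be freed.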