I would recommend checking the activation shapes and comparing them to the reported memory usage, as the high usage might be expected.
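To get the activation shapes you could print `output.shape` inside `forward` or register a forward hook via `module.register_forward_hook`. The small sketch below then turns a shape into bytes so you can compare it against the reported usage. The helper names and the example shape are hypothetical, and it assumes dense tensors (float32 by default, as in PyTorch).

```python
from math import prod

# Bytes per element for a few common dtypes (float32 is PyTorch's default).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2, "float64": 8}


def tensor_bytes(shape, dtype="float32"):
    """Memory in bytes needed to store a dense tensor of the given shape."""
    return prod(shape) * BYTES_PER_ELEMENT[dtype]


def to_gib(num_bytes):
    """Convert bytes to GiB (1024**3 bytes)."""
    return num_bytes / 1024**3


# Hypothetical activation: batch 64, 256 channels, 128x128 spatial size.
activation_shape = (64, 256, 128, 128)
print(f"{to_gib(tensor_bytes(activation_shape)):.2f} GiB")  # 1.00 GiB
```

A single activation of this size already needs 1 GiB, and every such output kept for the backward pass adds to the total.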
Generally, the memory usage comes from the parameters, the forward activations stored for the gradient computation, the gradients, and the optimizer states (if applicable, as this depends on the optimizer used).
This post gives you a simplified example and shows that the stored forward activations can take much more memory than, e.g., the parameters.
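As a rough back-of-the-envelope check of these components, here is a sketch that estimates each one for a single hypothetical 3x3 conv layer. It assumes float32 everywhere, one gradient per parameter, and 2 optimizer states per parameter as in Adam (`exp_avg`, `exp_avg_sq`). It also treats only the layer's output as the stored activation and ignores temporary buffers, the CUDA context, and allocator fragmentation.

```python
from math import prod


def training_memory_bytes(num_params, activation_shapes,
                          optimizer_states_per_param, bytes_per_element=4):
    """Rough training memory estimate, split into its main components.

    optimizer_states_per_param: 0 for plain SGD, 1 for SGD with momentum,
    2 for Adam/AdamW (without amsgrad).
    """
    params = num_params * bytes_per_element
    grads = num_params * bytes_per_element  # one gradient per parameter
    optimizer = num_params * optimizer_states_per_param * bytes_per_element
    activations = sum(prod(s) for s in activation_shapes) * bytes_per_element
    return {"parameters": params, "gradients": grads,
            "optimizer": optimizer, "activations": activations}


# Hypothetical 3x3 conv, 256 -> 256 channels (weight + bias),
# applied to a batch of 64 inputs at 128x128 resolution.
num_params = 256 * 256 * 3 * 3 + 256
activation_shapes = [(64, 256, 128, 128)]  # output kept for the backward pass

breakdown = training_memory_bytes(num_params, activation_shapes,
                                  optimizer_states_per_param=2)
for name, num_bytes in breakdown.items():
    print(f"{name:>12}: {num_bytes / 1024**2:10.2f} MiB")
```

Here the parameters need only a few MiB, while the single stored activation needs 1 GiB. The batch size and spatial resolution, not the parameter count, dominate the total.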
Based on the error message, I would guess that self.sa6 tries to allocate the 32GB of memory, as the corresponding print statement is missing.
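If `sa6` is a self-attention block (an assumption based only on its name), an allocation this large would not be surprising. Such blocks often materialize a dense attention map of shape `[batch, heads, N, N]` with `N = height * width`, so memory grows quadratically in the number of spatial positions. The sketch below uses a hypothetical helper and configuration to show how quickly this reaches tens of GiB.

```python
def attention_scores_bytes(batch, heads, height, width, bytes_per_element=4):
    """Bytes for a dense [batch, heads, N, N] attention map with N = height * width."""
    n = height * width
    return batch * heads * n * n * bytes_per_element


# Hypothetical configuration: batch 32, a single head, float32.
print(attention_scores_bytes(32, 1, 128, 128) / 1024**3)  # 32.0 (GiB)
print(attention_scores_bytes(32, 1, 64, 64) / 1024**3)    # 2.0 (GiB)
```

Halving the feature map's height and width reduces this tensor by 16x. In practice several tensors of this size (e.g. the scores and the softmax output) may be kept for the backward pass.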