Hi everyone,
I would like to start a discussion on the potential need for in-place real-domain FFT operators
in PyTorch, primarily motivated by memory-efficient training scenarios.
Motivation
In large-scale model training and parameter-efficient fine-tuning, FFT-based components
(e.g., circulant or Fourier-based layers) are increasingly used.
However, the existing APIs such as torch.fft.rfft/irfft are out-of-place: each call
allocates a fresh output buffer, which can introduce noticeable peak-memory overhead
during training, especially for bf16 workloads and large models.
An in-place real-domain FFT could reduce peak memory usage by reusing the input's
storage instead of making these extra allocations.
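One design wrinkle worth flagging up front: a real FFT's output is slightly larger than
its input, so a true in-place variant cannot simply overwrite the input buffer as-is.
A minimal sketch below illustrates this with NumPy's rfft (used here only as a stand-in
for torch.fft.rfft; this is not a proposal for the PyTorch API): a length-n real input
produces n//2 + 1 complex outputs, i.e. two extra real scalars, which is why libraries
such as FFTW require the input buffer to be padded for in-place real-to-complex
transforms.

```python
import numpy as np

n = 8
x = np.arange(n, dtype=np.float64)

# Out-of-place real FFT: a new complex buffer is allocated for the result.
spec = np.fft.rfft(x)
assert spec.shape == (n // 2 + 1,)

# Compare storage requirements in the same precision:
in_bytes = x.nbytes        # n real scalars
out_bytes = spec.nbytes    # (n//2 + 1) complex scalars = n + 2 real scalars

# The output needs exactly two extra reals' worth of storage, so an
# in-place variant would need the input padded by two elements
# (this is the convention FFTW uses for in-place r2c transforms).
assert out_bytes == in_bytes + 2 * x.itemsize
```

Any in-place API proposal would likely need to take a position on this padding
convention (or restrict itself to irfft-then-rfft round trips over a preallocated
complex buffer), which seems worth discussing alongside the memory-savings question.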
Scope of discussion
At this stage, the goal is not to propose a finalized API, but to understand:
- Whether this is a sufficiently common use case in the community
- Whether such functionality would be better suited for PyTorch core or as an extension
- Any high-level concerns regarding safety, maintainability, or API design
A related feature request has been opened on GitHub for maintainers' feedback:
In-place real-domain FFT operators for memory-efficient training (bf16 support) · Issue #171022 · pytorch/pytorch
Any thoughts or use cases from the community would be very helpful.
Thanks!