Pitch-shift an STFT?

I have the result of a torch.stft() call.
I want to pitch-shift the audio it represents.
The end result should be the STFT of the pitch-shifted audio.
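For concreteness, the input looks something like this (n_fft, hop_length, and the signal are placeholder values, not my real setup):

import torch

audio = torch.randn(22050)  # stand-in for one second of real audio
sgram = torch.stft(audio, n_fft=1024, hop_length=256,
                   window=torch.hann_window(1024), return_complex=True)
# sgram is complex-valued with shape (n_fft // 2 + 1, n_frames) = (513, n_frames)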

I don’t want to time-stretch, istft(), resample(), then stft() (sketched below for reference), because it seems like that would be slow.
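Here is roughly what that pipeline would look like, assuming torchaudio’s phase_vocoder and resample helpers and mirroring what I believe torchaudio’s own pitch_shift does internally; the parameters (n_fft, hop, sr) are placeholders that would have to match the original stft():

import math

import torch
import torchaudio.functional as AF

def pitch_shift_baseline(sgram: torch.Tensor, semitones: float,
                         n_fft: int = 1024, hop: int = 256, sr: int = 22050) -> torch.Tensor:
    # Time-stretch in the STFT domain, resynthesize, resample, re-analyze.
    rate = 2.0 ** (-semitones / 12)  # stretch rate; < 1 (longer) for upward shifts
    phase_advance = torch.linspace(0, math.pi * hop, sgram.shape[0],
                                   device=sgram.device)[:, None]
    stretched = AF.phase_vocoder(sgram, rate, phase_advance)
    window = torch.hann_window(n_fft, device=sgram.device)
    audio = torch.istft(stretched, n_fft=n_fft, hop_length=hop, window=window)
    # Resampling undoes the stretch, shifting the pitch instead.
    resampled = AF.resample(audio, int(sr / rate), sr)
    return torch.stft(resampled, n_fft=n_fft, hop_length=hop, window=window,
                      return_complex=True)

Note that phase_vocoder takes rate as a plain float here, so semitones can’t be a tensor in this version.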

Instead, I wrote some code that is supposed to build a new spectrogram in which the n-th frequency bin is the (n * scaling_factor)-th frequency bin of the original spectrogram, linearly interpolated for fractional indices:

import torch

def interpolate(frequencies: torch.Tensor, sgram: torch.Tensor) -> torch.Tensor:
    # Linear interpolation between adjacent frequency bins; indices are
    # clamped so that start + 1 never runs past the last bin (otherwise
    # downward shifts index out of bounds).
    start = frequencies.long().clamp(0, sgram.shape[0] - 2)
    frac = (frequencies - start).clamp(0, 1)[:, None]
    return sgram[start, :] * (1 - frac) + sgram[start + 1, :] * frac

def pitch_shift_spectrogram(sgram: torch.Tensor, semitones: torch.Tensor) -> torch.Tensor:
    # Output bin n reads from input bin n * scaling_factor.
    scaling_factor = 2 ** (-semitones / 12)
    frequencies = torch.arange(sgram.shape[0], device=sgram.device)
    shifted_frequencies = frequencies * scaling_factor
    return interpolate(shifted_frequencies, sgram)

The semitones parameter is a tensor because the pitch is something the network is going to learn, so gradients need to flow through it.
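At least mechanically, a gradient does come out the other end; this is the sanity check I ran (dummy input, arbitrary loss):

import torch

sgram = torch.stft(torch.randn(22050), n_fft=1024, hop_length=256,
                   window=torch.hann_window(1024), return_complex=True)
semitones = torch.tensor(3.0, requires_grad=True)
shifted = pitch_shift_spectrogram(sgram, semitones)
shifted.abs().sum().backward()  # arbitrary scalar loss, just to exercise backward()
print(semitones.grad)  # nonzero for me, but I don’t know if the gradient is meaningful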

This code sort of works: it does return a new spectrogram, and audio resynthesized from that spectrogram kind of sounds like a pitch-shifted version of the original, but it doesn’t sound right.
First, there is phasiness/distortion/artifacts (was I supposed to do something with the phases?).
Second, the pitches in the bass in particular sound wrong.

What’s the correct code for this? Is it actually faster than the istft()-and-resample method? And will autograd still backprop through it?