I’m quite familiar with CNNs, but not in combination with Transformers.
I’d like to include a Transformer as the central element of a UNet-like architecture, similar to:
- DARE-Net: Speech Dereverberation And Room Impulse Response Estimation
- TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
- Hybrid Transformers for Music Source Separation
My understanding is that after a few encoding Convs I should flatten the output of the last one, pass it through a Transformer, then reshape the Transformer’s output back into an NCHW tensor and feed it to the decoding Convs.
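To make this concrete, here is a minimal sketch of what I’m attempting (the shapes and hyperparameters are just placeholders I picked for illustration; I’m not sure the src/tgt usage is correct, which is part of my question):

```python
import torch
import torch.nn as nn

# Placeholder shapes: batch N, channels C, spatial H x W (my own example values)
N, C, H, W = 2, 64, 8, 8
x = torch.randn(N, C, H, W)  # output of the last encoding Conv

# Flatten spatial dims into a sequence: (N, C, H, W) -> (S, N, E)
# with S = H*W (sequence length) and E = C (embedding size)
seq = x.flatten(2).permute(2, 0, 1)  # shape (H*W, N, C)

# d_model must match E; nhead must divide d_model
transformer = nn.Transformer(d_model=C, nhead=8,
                             num_encoder_layers=2, num_decoder_layers=2)

# nn.Transformer's forward wants both src and tgt; here I pass the same
# tensor twice, but I'm not sure that's the intended usage
out = transformer(seq, seq)  # shape (S, N, E)

# Reshape back to NCHW for the decoding Convs
y = out.permute(1, 2, 0).reshape(N, C, H, W)
```

Is this roughly the right way to wire it up?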
However, reading the documentation for nn.Transformer, I struggle to understand two points:
- a Transformer takes an (S, N, E) tensor (sequence length, batch size, embedding size). Am I right that S should be a flattened version of my 2D spatial dimensions (H*W) and E should be the hidden channels (C)?
- why should I provide both a source (src) and a target (tgt) tensor to the Transformer, considering I just have one input tensor and expect one output tensor? What is the meaning of src and tgt here?
Any clarification would be appreciated!