How to use nn.Transformer in the context of vision?

I’m quite familiar with CNNs, but not with combining them with Transformers.

I’d like to include a Transformer as the central element of a UNet-like architecture, similar to:

  • DARE-Net: Speech Dereverberation And Room Impulse Response Estimation
  • TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
  • Hybrid Transformers for Music Source Separation

My understanding is that after a few Convs I should flatten the output of the last encoding Conv, pass it through a Transformer, then reshape the Transformer’s output back to an NCHW tensor and feed it to the decoding Convs.
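
Concretely, I picture the reshaping working roughly like this (a minimal sketch with placeholder sizes; I’ve left out the Transformer call itself, since that’s where my questions start):

```python
import torch

# Placeholder sizes for the output of the last encoding Conv
N, C, H, W = 2, 64, 16, 16
x = torch.randn(N, C, H, W)

# NCHW -> (S, N, E): the H*W spatial positions become the sequence axis,
# and the channels become the embedding dimension
seq = x.flatten(2).permute(2, 0, 1)   # (H*W, N, C)

# ... Transformer would go here ...

# (S, N, E) -> NCHW, to be fed to the decoding Convs
y = seq.permute(1, 2, 0).reshape(N, C, H, W)
assert y.shape == x.shape
```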

However, reading the documentation for nn.Transformer, I struggle to understand two points:

  • a Transformer takes an (S, N, E) tensor. Am I right that S should be the flattened version of my 2D spatial dimensions (H*W) and E should be the hidden channels (C), as in the sketch above?
  • why should I provide both a source (src) and a target (tgt) tensor to the Transformer, given that I just have one input tensor and expect one output tensor? What is the meaning of src and tgt here? (See the sketch after this list.)

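For concreteness, this is the call I’m puzzling over (again a minimal sketch with placeholder sizes; passing the same tensor as both arguments is just my current guess):

```python
import torch
import torch.nn as nn

S, N, E = 256, 2, 64                 # flattened H*W, batch size, channels
seq = torch.randn(S, N, E)

transformer = nn.Transformer(d_model=E, nhead=8)

# nn.Transformer.forward requires both src and tgt; with only a single
# input tensor available, passing it twice is my guess -- is that right?
out = transformer(seq, seq)          # (S, N, E)
```
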
Any clarification would be appreciated!
