Multi-modal fusion Transformer


I would like to implement the “Multimodal Fusion” block from the following approach:

It is about image captioning for a 3D scene, and training is performed with next token prediction.
As you can see, the Multimodal Fusion block takes as input visual embeddings (blue) and textual embeddings (green). The textual embeddings are then updated by the Multimodal Fusion module, which takes the visual embeddings into account when computing attention.

I’m very new to multimodal attention, so my question is: how do I implement this module, taking into account that the number of input tokens may change?

Does the code exist? If so, where can I find it?

Thank you so much in advance for your help!

Essentially it’s a PyTorch transformer encoder module.
As you will find in the docs, it accepts a mask as a keyword argument.
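For reference, here is a minimal sketch of such an encoder (the layer sizes are placeholder values I chose, not from the paper):

```python
import torch
import torch.nn as nn

# A stack of self-attention layers; d_model must match your embedding size.
layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=4)

x = torch.randn(2, 100, 256)  # (batch, tokens, d_model)
out = encoder(x)              # same shape as the input
print(out.shape)              # torch.Size([2, 100, 256])
```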

Let’s say you pass 100 elements:
40 will be video, 60 text.
You will need to call forward(cat(video, text), mask),
and this will produce a sequence of 100 elements.
You pick the last 60 (actually you can pick any 60, as the network will learn to position them properly).
If your signal lengths differ across samples, you need to add padding.
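Putting those steps together, a hedged sketch (all names and sizes are mine, assuming precomputed embeddings of matching dimension): concatenate the 40 video tokens with the 60 text tokens, run the joint sequence through the encoder, and slice off the last 60 outputs as the fused textual embeddings. The `src_key_padding_mask` is how you tell attention to ignore padded positions when lengths vary across the batch.

```python
import torch
import torch.nn as nn

d_model = 256
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
fusion = nn.TransformerEncoder(layer, num_layers=2)

video = torch.randn(2, 40, d_model)  # visual embeddings (blue)
text = torch.randn(2, 60, d_model)   # textual embeddings (green)

# One joint sequence of 100 tokens per sample.
tokens = torch.cat([video, text], dim=1)  # (2, 100, d_model)

# True marks padded positions that attention should ignore;
# here nothing is padded, so the mask is all False.
pad_mask = torch.zeros(2, 100, dtype=torch.bool)

fused = fusion(tokens, src_key_padding_mask=pad_mask)
text_fused = fused[:, -60:, :]  # updated textual embeddings
print(text_fused.shape)         # torch.Size([2, 60, 256])
```

Since the text tokens attend to the video tokens inside the encoder, the sliced outputs already carry visual information; with real padding you would build `pad_mask` from the per-sample lengths instead of all-False.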

Sorry for not being more detailed 🙂 but I’m short on time. Hope it helps!