Attn feature aggregation

I have 5 images: a source image and 4 similar images. I'm using faiss to retrieve the 4 similar images for the source image. I want the source image to attend to the other 4 images to get an enhanced representation. My plan is to use the source image as the query and concatenate the other 4 images across the channel dimension to form the key and value matrices.

Once I have the attention output, how can I use the attention weights to enhance the source image representation? The final representation needs to be either [B, C, H, W] or [B, sequence length, embed dimension], because I'm using it as a hidden state for a diffusion model. Any suggestions?
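A minimal single-head sketch of this setup, written in NumPy so the shapes are easy to follow. A few assumptions: each image is already a feature map of shape [B, C, H, W] from some encoder; `Wq`, `Wk`, `Wv`, `Wo` are random placeholders for what would be learned projections; and `cross_attend`, `to_tokens`, `to_map` and the `gate` parameter are hypothetical names chosen for illustration. One deliberate difference from the plan in the post: the 4 neighbors are concatenated along the token (sequence) axis rather than the channel axis. Attention compares query tokens against key tokens of the same embedding size, so stacking neighbors as extra tokens gives the source more positions to attend to. Stacking on channels would instead change C.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def to_tokens(feat):
    # [B, C, H, W] -> [B, H*W, C]  (one token per spatial position)
    B, C, H, W = feat.shape
    return feat.reshape(B, C, H * W).transpose(0, 2, 1)

def to_map(tokens, H, W):
    # [B, H*W, C] -> [B, C, H, W]
    B, N, C = tokens.shape
    return tokens.transpose(0, 2, 1).reshape(B, C, H, W)

def cross_attend(src, neighbors, Wq, Wk, Wv, Wo, gate=1.0):
    """Hypothetical helper: source tokens attend over all neighbor tokens.

    src:       [B, C, H, W]     source features (queries)
    neighbors: [B, K, C, H, W]  K retrieved images (keys/values)
    W*:        [C, C]           placeholder projections (learned in practice)
    gate:      scale on the attention branch; 0 returns the source unchanged
    """
    B, K, C, H, W = neighbors.shape
    src_tok = to_tokens(src)                                   # [B, N, C], N = H*W
    # Concatenate the K neighbors along the token axis: [B, K*N, C]
    kv = neighbors.transpose(0, 1, 3, 4, 2).reshape(B, K * H * W, C)
    q, k, v = src_tok @ Wq, kv @ Wk, kv @ Wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(C)             # [B, N, K*N]
    attn = softmax(scores, axis=-1)                            # each row sums to 1
    out = attn @ v                                             # [B, N, C]
    # Residual connection keeps the source's own content in the result
    enhanced = src_tok + gate * (out @ Wo)                     # [B, N, C]
    # Attention mass each neighbor received, averaged over query tokens: [B, K]
    per_neighbor = attn.reshape(B, H * W, K, H * W).sum(axis=-1).mean(axis=1)
    return to_map(enhanced, H, W), enhanced, attn, per_neighbor

# Usage with random data
rng = np.random.default_rng(0)
B, K, C, H, W = 2, 4, 8, 4, 4
src = rng.standard_normal((B, C, H, W))
nbrs = rng.standard_normal((B, K, C, H, W))
Wq, Wk, Wv, Wo = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(4))
fmap, seq, attn, per_nbr = cross_attend(src, nbrs, Wq, Wk, Wv, Wo)
print(fmap.shape, seq.shape, attn.shape, per_nbr.shape)
# (2, 8, 4, 4) (2, 16, 8) (2, 16, 64) (2, 4)
```

The attention weights already do the enhancement. Each output token is a weighted average of neighbor values, and the residual add folds that back into the source. So the output itself is usually what you pass on, in either the [B, C, H, W] or the [B, N, C] form. The reduced `per_neighbor` weights are mainly useful for inspecting or down-weighting retrieved images that contribute little. Starting `gate` at 0 (as a learnable scalar) means the model begins from the unmodified source features. In PyTorch, `torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)` performs the multi-head version of the same query/key/value step and also returns the attention weights.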