Memory-conscious resized matrix multiplication

Hello, first time posting here.

I have become interested in transformers and their application to vision. That said, computing attention with torch.matmul over flattened vectors containing all of the pixels in an image is virtually unusable for images much larger than a thumbnail, since the resulting weight matrix has (Height * Width)^2 entries.
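To put rough numbers on that (assuming float32, so 4 bytes per weight):

n = 64 * 64               # pixels in the toy example below
print(n * n * 4 / 2**20)  # ~64 MiB for a single full-resolution weight matrix
n = 256 * 256             # a still fairly modest 256 * 256 image
print(n * n * 4 / 2**30)  # ~16 GiB for a single weight matrix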

For this reason I have been looking at scaled-down versions of transformers, but something doesn't seem right to me.

Generally, if I want to scale the attention down (to, say, size [Batch, (0.25 * Height) * (0.25 * Width), (0.25 * Height) * (0.25 * Width)], i.e. [1, 256, 256] for a 64 * 64 image) but still multiply it by my values at full scale (of size [Batch, Height * Width], so [1, 64 * 64] = [1, 4096] in our example), I would need to scale the attention weights back up to [1, 4096, 4096] and then perform the matrix multiplication with torch.matmul:

# weights.shape = [1, 256, 256] downscaled, [1, 4096, 4096] once scaled back up
# flat_v.shape = [1, 4096]
# the matmul needs the full-scale weights, with flat_v as a column vector
attn = torch.matmul(weights, flat_v.unsqueeze(-1))
# attn.shape = [1, 4096, 1], which is what we want in either case
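Concretely, the upscale-then-multiply version I am trying to avoid looks something like this (a rough sketch only: I am assuming a 4x downscale, row-major flattening of both pixels and tokens, and random tensors just to show the shapes):

import torch

B = 1
weights_small = torch.rand(B, 256, 256)             # attention over the 16 * 16 token grid
flat_v = torch.rand(B, 64 * 64)                     # values over the 64 * 64 pixel grid

# repeat each token-to-token weight over its 4 * 4 query block and 4 * 4 key block
# (nearest-neighbour style, no interpolation)
w = weights_small.view(B, 16, 16, 16, 16)           # (B, qH, qW, kH, kW)
for d in (1, 2, 3, 4):
    w = w.repeat_interleave(4, dim=d)
weights = w.reshape(B, 64 * 64, 64 * 64)            # [1, 4096, 4096] <- the memory problem

attn = torch.matmul(weights, flat_v.unsqueeze(-1))  # [1, 4096, 1]

Even in this toy example the upscaled weights take ~64 MiB, and they grow with the fourth power of the image side length.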

This seems very poorly optimized for memory, and I am sure there must be a smarter way to approach this problem in PyTorch. Is there some combined operation that would let me apply the downscaled weights to the full-scale values without materializing the full weight matrix and overwhelming my memory? I don't need to interpolate values when scaling for this implementation (all values in the region of interest can be the same).
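To make the question more concrete, here is the kind of reformulation I am imagining (same shapes and row-major flattening assumption as in the sketch above): since the upscaled weights are just repeats, I think I can sum-pool the values down to the token grid, do the matmul at the small size, and repeat the result back up.

import torch
import torch.nn.functional as F

B = 1
weights_small = torch.rand(B, 256, 256)
flat_v = torch.rand(B, 64 * 64)

# sum-pool the values over each 4 * 4 pixel block (avg pool * kernel area = sum)
v_small = F.avg_pool2d(flat_v.view(B, 1, 64, 64), kernel_size=4) * 16  # [1, 1, 16, 16]

# multiply at the small scale only: [1, 256, 256] @ [1, 256, 1] -> [1, 256, 1]
small_out = torch.matmul(weights_small, v_small.reshape(B, 256, 1))

# repeat the result back up to pixel resolution (no interpolation needed)
attn = F.interpolate(small_out.view(B, 1, 16, 16), scale_factor=4, mode="nearest")
attn = attn.reshape(B, 64 * 64, 1)                                     # [1, 4096, 1]

Unless I have messed up the algebra, this should give the same [1, 4096, 1] result as the upscale-then-multiply version while never materializing anything larger than [1, 256, 256], but I would appreciate confirmation, or a pointer to a more idiomatic built-in way of doing this.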