Masked average pooling

Say I wanted to replace global average pooling (i.e. the layer at the end of a ResNet) with a layer that, for each feature map, averages only the top n activations and ignores the rest (where n < 7x7, or whatever the spatial dimensions of the final conv output are). What would be the fastest way to implement that?

I'm not sure how to do it. One idea is to flatten each feature map, apply topk along the flattened axis, and then take the mean of the selected values. But I don't think topk works that way… any thoughts?
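
For concreteness, here is a minimal sketch of that flatten-then-topk idea, assuming PyTorch and an input of shape (batch, channels, H, W). The function name `topk_average_pool` is only an illustrative name, and this has not been benchmarked, so it may not be the fastest option.

```python
import torch


def topk_average_pool(x: torch.Tensor, n: int) -> torch.Tensor:
    """Average the n largest activations of each feature map.

    x: tensor of shape (batch, channels, H, W), assumed layout.
    Returns a tensor of shape (batch, channels).
    """
    b, c, h, w = x.shape
    # Flatten each feature map into a vector of H*W activations.
    flat = x.reshape(b, c, h * w)
    # Keep the n largest values along the flattened spatial axis.
    top_vals, _ = flat.topk(n, dim=-1)
    # Average only those n values; the rest are ignored.
    return top_vals.mean(dim=-1)


# Toy input: 2 samples, 3 channels, 4x4 feature maps.
x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
pooled = topk_average_pool(x, n=4)
```

With n equal to H*W this reduces to ordinary global average pooling, which gives a quick sanity check.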