Computing recall on very large datasets

Hello,

I am training a model with a triplet loss to generate image and point-cloud descriptors. To evaluate its performance, I plot the training loss, recall@6 on the training split, and recall@6 on the validation split. By recall@6 I mean: given an anchor image, I retrieve the 6 closest point clouds in the database and check whether the ground-truth point cloud is among them.
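For reference, recall@k as described above can be sketched as follows. This is a minimal sketch, assuming Euclidean distance between descriptors and exactly one ground-truth point cloud per anchor image; `recall_at_k` and `gt` are hypothetical names, not part of PyTorch.

```python
import torch

def recall_at_k(queries: torch.Tensor, database: torch.Tensor,
                gt: torch.Tensor, k: int = 6) -> float:
    """Fraction of queries whose ground-truth database entry is among
    their k nearest database entries (Euclidean distance).

    queries:  (Q, D) anchor image descriptors
    database: (N, D) point-cloud descriptors
    gt:       (Q,) index into `database` of each query's true match
    """
    dists = torch.cdist(queries, database)              # (Q, N) pairwise L2 distances
    knn = dists.topk(k, dim=1, largest=False).indices   # (Q, k) indices of the k closest
    hits = (knn == gt.unsqueeze(1)).any(dim=1)          # (Q,) True if gt was retrieved
    return hits.float().mean().item()
```

Note that this builds the full (Q, N) distance matrix at once, which is exactly what stops fitting on the GPU once N and Q get large.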

The validation split is relatively small, so I can compute recall after each epoch, but the training split is too large to fit into GPU memory. Is there a way to compute this recall over the whole training split in PyTorch without running out of memory? Doing the processing on the CPU is too slow.
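One common way around the memory limit, sketched here under assumptions: the descriptors for the whole training split are first computed batch by batch under `torch.no_grad()` and kept on the CPU, so the part that does not fit on the GPU is the full queries × database distance matrix. The hypothetical helper below processes the queries in chunks and, for each chunk, streams over the database in chunks, keeping only the running k smallest distances. Peak GPU memory is then on the order of `q_chunk × db_chunk` distances instead of Q × N, and the result is exact rather than an approximation.

```python
import torch

@torch.no_grad()
def chunked_recall_at_k(queries, database, gt, k=6,
                        q_chunk=1024, db_chunk=65536, device=None):
    """Exact recall@k computed chunk by chunk (hypothetical helper).

    queries:  (Q, D) anchor image descriptors (may live on the CPU)
    database: (N, D) point-cloud descriptors (may live on the CPU)
    gt:       (Q,) index into `database` of each query's true match
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    hits = 0
    for q_start in range(0, queries.shape[0], q_chunk):
        q = queries[q_start:q_start + q_chunk].to(device)
        g = gt[q_start:q_start + q_chunk].to(device)
        # Running k smallest distances and their global database indices.
        best_d = torch.empty((q.shape[0], 0), dtype=q.dtype, device=device)
        best_i = torch.empty((q.shape[0], 0), dtype=torch.long, device=device)
        for d_start in range(0, database.shape[0], db_chunk):
            db = database[d_start:d_start + db_chunk].to(device)
            d = torch.cdist(q, db)  # (q_chunk, db_chunk) distances for this block only
            d_top, i_top = d.topk(min(k, d.shape[1]), dim=1, largest=False)
            # Merge this block's candidates with the running best k.
            cand_d = torch.cat([best_d, d_top], dim=1)
            cand_i = torch.cat([best_i, i_top + d_start], dim=1)  # local -> global index
            best_d, pos = cand_d.topk(min(k, cand_d.shape[1]), dim=1, largest=False)
            best_i = cand_i.gather(1, pos)
        hits += (best_i == g.unsqueeze(1)).any(dim=1).sum().item()
    return hits / queries.shape[0]
```

A call such as `chunked_recall_at_k(train_img_desc, train_pc_desc, train_gt)` (hypothetical tensor names) should match the all-at-once computation up to tie-breaking; `q_chunk` and `db_chunk` trade GPU memory for speed.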

I thought that computing recall@6 for each batch and then averaging over the batches would approximate the recall@6 for that epoch, but the results differ considerably from those obtained by computing recall over the whole dataset after each epoch.
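One likely cause of the discrepancy, assuming the per-batch recall searches only among the point clouds of the same batch: a database of B entries instead of N gives each query far fewer competitors, so its rank can only stay the same or improve. Ignoring ties, and with equal-size batches, the mean per-batch recall@6 can therefore never be lower than the full recall@6, and it is trivially 1 when B ≤ 6. The sketch below makes the comparison explicit; it hypothetically pairs query i with database row i.

```python
import torch

def recall_hits(q, db, gt, k):
    """Boolean hit per query: is gt among the k nearest rows of db?"""
    k = min(k, db.shape[0])  # a batch may hold fewer than k point clouds
    knn = torch.cdist(q, db).topk(k, dim=1, largest=False).indices
    return (knn == gt.unsqueeze(1)).any(dim=1)

def per_batch_vs_full(queries, database, k=6, batch_size=32):
    """Return (mean in-batch recall@k, recall@k over the full database).
    Assumes query i matches database row i (hypothetical pairing)."""
    n = queries.shape[0]
    full = recall_hits(queries, database, torch.arange(n), k).float().mean().item()
    batch_scores = []
    for start in range(0, n, batch_size):
        q = queries[start:start + batch_size]
        db = database[start:start + batch_size]   # only this batch's point clouds
        local_gt = torch.arange(q.shape[0])       # ground truth as batch-local indices
        batch_scores.append(recall_hits(q, db, local_gt, k).float().mean().item())
    return sum(batch_scores) / len(batch_scores), full
```

If this is the cause, the per-batch number mostly reflects the batch size rather than retrieval quality over the whole training split.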

To check whether the model is overfitting, would it be a good idea to compute the loss on the validation split? Or is that not advisable when using a triplet loss?
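For the validation loss, a minimal sketch, assuming a validation loader that yields fixed (anchor image, positive point cloud, negative point cloud) triplets and two hypothetical encoders `image_encoder` and `pc_encoder`. It uses `nn.TripletMarginLoss` under `torch.no_grad()` and averages over all triplets. If negatives are instead mined inside each batch, the value also depends on how batches are composed, so keeping the validation triplets fixed makes epochs comparable.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def validation_triplet_loss(image_encoder, pc_encoder, loader,
                            margin=1.0, device="cpu"):
    """Mean triplet loss over a validation loader, without gradient tracking.
    `loader` is assumed to yield (anchor_images, positive_pcs, negative_pcs)."""
    image_encoder.eval()
    pc_encoder.eval()
    # Sum per batch, then divide by the triplet count for an exact mean.
    criterion = nn.TripletMarginLoss(margin=margin, reduction="sum")
    total, count = 0.0, 0
    for anchor, pos, neg in loader:
        a = image_encoder(anchor.to(device))
        p = pc_encoder(pos.to(device))
        n = pc_encoder(neg.to(device))
        total += criterion(a, p, n).item()
        count += a.shape[0]
    return total / count
```

Both encoders need to be switched back with `.train()` before the next training epoch.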

Thanks!