Inference breaks in PyTorch Lightning model

Hello,
I’m trying to run inference with a ResNet50 model for an image regression task I’m working on. I’ve changed the last layer of the model to output a single value, and the first layer to accept greyscale images (250x250). I’m using WebDataset to load the data. However, I can’t get inference to produce good results when I run it on a shard from the validation set. I know the model itself is working: if I put a breakpoint in the validation step, compare the labels to the model output and calculate the MSE manually, I get a good match to my validation epoch. I’m using the following code:

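    def inference_data_set(self, batch):
        # Labels are not needed for inference, so discard them
        imgs, _ = batch
        imgs = imgs.to(self.device)
        self.eval()  # use running BatchNorm statistics and disable dropout
        with torch.no_grad():
            output = self(imgs)

        # Drop only the trailing size-1 dimension so a batch of one stays 1-D
        output = output.squeeze(-1)
        # Rescale the network output to a number of ions
        output = output * self.max_ions
        return output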

max_ions converts the output of the network into the number of ions I’m expecting in the image. However, when I use inference_data_set on one of the validation shards, I get nonsense values far outside the expected range. Is there anything I’m obviously doing wrong here?
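For reference, the manual MSE check I mentioned above looks roughly like the sketch below. This is a simplified stand-in, not my actual code: the small linear model replaces my modified ResNet50, the images and labels are random tensors, and the names `predict_ions`, `manual_mse` and `labels_norm` are made up for this example. It also assumes the labels are stored normalized by max_ions, so both predictions and labels are rescaled to ion counts before comparing.

```python
import torch
from torch import nn


def predict_ions(model, imgs, max_ions):
    """Run the model in eval mode and rescale its output to ion counts."""
    model.eval()
    with torch.no_grad():
        out = model(imgs)
    # squeeze(-1) removes only the trailing dimension, so a batch of one stays 1-D
    return out.squeeze(-1) * max_ions


def manual_mse(model, imgs, labels_norm, max_ions):
    """Mean squared error between predictions and labels, both in ion units."""
    preds = predict_ions(model, imgs, max_ions)
    # Assumption: labels are normalized to the same scale as the network output
    targets = labels_norm.float() * max_ions
    return torch.mean((preds - targets) ** 2).item()


torch.manual_seed(0)
max_ions = 100.0  # hypothetical scale factor
# Stand-in for the real network: one greyscale 250x250 image in, one value out
model = nn.Sequential(nn.Flatten(), nn.Linear(250 * 250, 1))
imgs = torch.rand(4, 1, 250, 250)  # random stand-in batch
labels_norm = torch.rand(4)        # random stand-in labels in [0, 1)

mse = manual_mse(model, imgs, labels_norm, max_ions)
print(f"manual MSE in ion units: {mse:.3f}")
```

Running the same two functions on a batch taken from the validation loop and on a batch taken from the shard I pass to inference_data_set should show whether the two paths see the same inputs and the same scaling.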