PyTorch: separate an image into its red and green channels

I have a set of images that contain green and red cells. I would like to count the green cells and the red cells. The current code counts both types of cells together; I would like to count them separately. Any advice would be appreciated. Thanks.

class CellsDataset(Dataset):
    """Dataset of the ``.png`` files found directly inside ``root_dir``.

    Each item is built by loading the file with ``loader``, then optionally
    keeping a single colour channel, then applying ``transform``.

    Args:
        root_dir: Directory scanned (non-recursively) for ``.png`` files.
        transform: Optional callable applied to each sample as the last step.
        return_filenames: If true, items are ``(sample, path)`` tuples;
            otherwise just ``sample``.
        loader: Callable ``path -> image`` used to read each file (for
            example ``PIL.Image.open`` or an image reader returning an
            array). Indexing raises ``ValueError`` if it is not provided.
        channel: Optional channel to keep: ``"red"``, ``"green"``,
            ``"blue"`` (case-insensitive) or an integer index. Selection is
            ``sample[..., channel]``, so when this is set the loader must
            return a channel-last array-like of shape (H, W, C). Names
            assume RGB channel order.

    Raises:
        ValueError: From ``__init__`` if ``channel`` is an unknown name.
    """

    # Channel positions for RGB-ordered images. NOTE: some readers (e.g.
    # OpenCV) return BGR; pass an integer index in that case.
    _CHANNEL_INDEX = {"red": 0, "green": 1, "blue": 2}

    def __init__(self, root_dir, transform=None, return_filenames=False,
                 *, loader=None, channel=None):
        self.root = root_dir
        self.transform = transform
        self.return_filenames = return_filenames
        self.loader = loader
        self.channel = self._resolve_channel(channel)
        # os.listdir order is arbitrary; sort so that a given index always
        # refers to the same file across runs.
        candidates = (os.path.join(self.root, filename)
                      for filename in sorted(os.listdir(self.root)))
        self.files = [path for path in candidates
                      if os.path.isfile(path) and os.path.splitext(path)[1] == '.png']

    @classmethod
    def _resolve_channel(cls, channel):
        """Return ``channel`` as an integer index, or ``None`` if unset."""
        if channel is None or not isinstance(channel, str):
            return channel
        try:
            return cls._CHANNEL_INDEX[channel.lower()]
        except KeyError:
            raise ValueError(
                f"unknown channel {channel!r}; expected one of "
                f"{sorted(cls._CHANNEL_INDEX)} or an integer index"
            ) from None

    def __len__(self):
        """Return the number of ``.png`` files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load file ``idx``; return the sample, plus its path if requested."""
        path = self.files[idx]
        if self.loader is None:
            raise ValueError(
                "CellsDataset has no loader; pass loader=<callable path -> image>"
            )
        sample = self.loader(path)
        if self.channel is not None:
            # Keep one colour plane, e.g. only the red or only the green cells.
            sample = sample[..., self.channel]
        if self.transform is not None:
            sample = self.transform(sample)

        if self.return_filenames:
            return sample, path
        # Previously this return was unreachable, so the sample was lost.
        return sample

How is the current code counting these cells?
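Based on the code snippet you shared, it seems you are only loading the data; nothing in this `Dataset` counts cells. To count red and green cells separately, split each RGB image into its red channel (index 0) and green channel (index 1), then run your counting step on each channel independently.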