Pytorch v2 custom transforms

I am trying to get my own custom pad and resize operator for use inside a dataloader for a resnet50 classifier.

The problem is that I cannot get this transform structure right, so that an image can be processed in a transform pipeline. The online help for v2 custom transforms is not help as there are conflicting examples that show using a “forward” def with the Torch.nn.Module, and a “transform” def with the v2.Transform module.

class SmartResizePad(v2.Transform):
#
    def __init__(self, target_size)->None:
        super().__init__()
        self.target_size = target_size

    def SmartResizePad(self, img:torch.Tensor):
        # Determine size
        if isinstance(img, Image.Image):
            img = v2.ToImage()(img)
        print(f'TS:{type(self.target_size), self.target_size} {type(img), img}')
        print(type(img), img)
        _, h, w = img.shape
#
        if w <= self.target_size and h <= self.target_size:
            pad_w = self.target_size - w
            pad_h = self.target_size - h
            padding = [pad_w // 2, pad_h // 2, pad_w - (pad_w // 2), pad_h - (pad_h // 2)]
     #
            # Use functional pad for direct application
            #PFC return v2F.pad(img, padding=padding, fill=0, padding_mode='constant')
            return self._call_kernel(F.pad, img, padding=padding, fill=0, padding_mode=self.padding_mode)  # type: ignore[arg-type]
        else:
            return F.resize(img, [self.target_size, self.target_size], antialias=True)
    #
    #def _transform(self,img):
    def transform(self, img: torch.Tensor) -> Any:
        # 1. Ensure input is a tensor for consistent handling
        # V2 flattens your input and passes each piece here
        if isinstance(img, torch.Tensor):
            return (self.SmartResizePad(self,img))
        print(f'T not an image')

When I call it using

# — 2. Data Loading —

ResizePad = SmartResizePad(255) # only resize if >256, pad smaller.
data_transforms_v2 = {
        'train': v2.Compose([
            #v2.RandomInvert(p=1.0),  # invert to get the energy into the object
            ## PFC MOVE TO AFTER RESIZE
            
            v2.RandomAffine(degrees=[-5, 5], translate=[0, 0.1], scale=[1, 1.1],
                      shear=[0, 5], fill=0, center=None),
            #v2.Resize(size=256, max_size=257),
            ResizePad.SmartResizePad(256), ## PFC try to stop smaller onjects being resized
            v2.RandomInvert(p=1.0),  # invert to get the energy into the object
            v2.CenterCrop(224),
            v2.RandomHorizontalFlip(),
            v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
            #v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
            v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': v2.Compose([ ## not enough test data so transform
            v2.RandomAffine(degrees=[-5, 5], translate=[0, 0.1], scale=[1, 1.1],
                      shear=[0, 5], fill=0, center=None),
            #v2.Resize(256),
            ResizePad.SmartResizePad(256), ## PFC try to stop smaller onjects being resized
            v2.RandomInvert(p=1.0),  # invert to get the energy into the object
            v2.CenterCrop(224),
            v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
            v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms_v2[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=25,
                                                 shuffle=True, num_workers=4)
                  for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

The error is here when I run it:

python train_resnet50_mac.py
Using device: cpu
TS:(<class 'int'>, 255) (<class 'int'>, 256)
<class 'int'> 256
Traceback (most recent call last):
  File "train_resnet50_mac.py", line 246, in <module>
    ResizePad.SmartResizePad(256), ## PFC try to stop smaller onjects being resized
  File "train_resnet50_mac.py", line 56, in SmartResizePad
    _, h, w = img.shape
AttributeError: 'int' object has no attribute 'shape'

This suggests the image is seen as a set of INTS not as a whole image.

I thought “def SmartResizePad(self, img:torch.Tensor)” would fix the issue.

But it does not. Any pointers to a solution most appreciated.

Phil

Solved! It turned out that I had not go the correct header for def transform()

class SmartResizePad(v2.Transform):

**def** \__init_\_(self, target_size: int = 256):

    super().\__init_\_()

    self.target_size = target_size

**def** transform(self, inpt: Any, params: dict\[str, Any\]) -> Any:

    \# Get dimensions (works for PIL or Tensor)

    \_, h, w = v2.functional.get_dimensions(inpt)

    **works.**