TorchScript compatibility with control flow

I'm writing a deep learning model in PyTorch. In the constructor, I define a transformation pipeline (self.transform) that is applied to the input before the forward pass.

from typing import List

import torch

# LongSideScale, CenterCrop, Div255 and Normalize are the transform modules used
# below; they are imported/defined elsewhere (LongSideScale wraps the
# long_side_scale function shown further down).


class CustomModel(torch.nn.Module):

    def __init__(self, model, model_name: str, model_params: dict):
        super().__init__()
        self.model_name = model_name
        self.model_params = model_params
        self.model = model

        # Generate the sequential transformations needed for each clip
        self.transform = torch.nn.Sequential(
            LongSideScale(size=self.model_params["side_size"]),
            CenterCrop(size=(self.model_params["crop_size"], self.model_params["crop_size"])),
            Div255(),
            Normalize(self.model_params["mean"], self.model_params["std"]),
        )

    def forward(self, batch_video: List[torch.Tensor]) -> torch.Tensor:
        # Use a new name instead of re-binding batch_video (List[Tensor]) to a Tensor;
        # clearer, and it plays more nicely with TorchScript's static typing
        transformed_batch = self.apply_batch_transform(batch_video)
        logits = self.model(transformed_batch)
        probability = torch.sigmoid(logits)
        return probability

    def apply_batch_transform(self, batch_video: List[torch.Tensor]) -> torch.Tensor:
        # Transform each clip independently, then stack into a single batch tensor
        transformed_clips = [self.transform(clip) for clip in batch_video]
        return torch.stack(transformed_clips)
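
For reference, model_params just carries the values the transform pipeline reads; it looks roughly like this (the actual numbers here are placeholders, not the ones I use):

model_params = {
    "side_size": 256,    # target length of the longer spatial side
    "crop_size": 224,    # square center-crop size
    "mean": [0.45, 0.45, 0.45],
    "std": [0.225, 0.225, 0.225],
}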

The problem is that the input clips can come in either of two layouts (toy shapes in the snippet below):

  • a sequence of N RGB frames: x.shape = (3, N, H, W)
  • the same frames stacked along the channel dimension: x.shape = (3*N, H, W)
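
For example, with eight 200x300 frames the two layouts would be (dummy data, sizes chosen arbitrarily):

x_4d = torch.rand(3, 8, 200, 300)     # N = 8 RGB frames with an explicit temporal dimension
x_3d = torch.rand(3 * 8, 200, 300)    # the same 8 frames stacked along the channel dimension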

This is accounted for in the implementation of the LongSideScale part of the transformation:

import math

import torch


def long_side_scale(
    x: torch.Tensor,
    size: int,
    interpolation: str = "bilinear",
    backend: str = "pytorch",
) -> torch.Tensor:
    """
    Determines the longer spatial dim of the video (i.e. width or height) and scales
    it to the given size. To maintain the aspect ratio, the shorter side is scaled
    accordingly.
    """

    # Check whether the tensor is 3D (C, H, W) or 4D (C, T, H, W)
    if len(x.shape) == 3:
        c, h, w = x.shape
        x = x.unsqueeze(1)  # Add a dummy temporal dimension (T=1)
        unpack = True
    elif len(x.shape) == 4:
        c, t, h, w = x.shape
        unpack = False
    else:
        raise ValueError(f"Input tensor must be 3D or 4D, but got shape {x.shape}")

    is_uint8 = x.dtype == torch.uint8
    # x is 4D at this point, so re-read the (possibly unsqueezed) shape
    c, t, h, w = x.shape
    # Scale the longer spatial side to `size`, keeping the aspect ratio
    if w > h:
        new_h = int(math.floor((float(h) / w) * size))
        new_w = size
    else:
        new_h = size
        new_w = int(math.floor((float(w) / h) * size))

    if backend == "pytorch":
        # Interpolation runs in float32; uint8 inputs are converted and cast back afterwards
        if is_uint8:
            x = x.to(torch.float32)
        x = torch.nn.functional.interpolate(x, size=(new_h, new_w), mode=interpolation, align_corners=False)
        if is_uint8:
            x = x.to(torch.uint8)

        # If we added a temporal dimension, squeeze it back out
        if unpack:
            x = x.squeeze(1)
        return x
    else:
        raise NotImplementedError(f"{backend} backend not supported.")
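
As a quick eager-mode sanity check with the toy shapes from above (dummy data, not the real clips), both layouts keep their rank and get the longer side scaled down to size:

out_4d = long_side_scale(torch.rand(3, 8, 200, 300), size=256)
print(out_4d.shape)   # torch.Size([3, 8, 170, 256]): 300 -> 256, 200 -> 170 to keep the aspect ratio

out_3d = long_side_scale(torch.rand(3 * 8, 200, 300), size=256)
print(out_3d.shape)   # torch.Size([24, 170, 256]): still 3D thanks to the unsqueeze/squeeze round-trip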

Finally, this model will be converted to TorchScript, with something like:

model = CustomModel(model=model, model_name=model_name, model_params=model_params)
scripted = torch.jit.script(model)
scripted.save(scripted_path)

I was worried this might not be the best solution, since TorchScript can have limitations with dynamic control flow, such as if conditions that depend on tensor shapes or dtypes, when converting the model to a static graph (I could be wrong, though!).
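
To make that concrete, the branching I mean is just a shape-based if, as in this toy snippet, which (if I'm reading the scripted code correctly) does survive scripting as real control flow:

@torch.jit.script
def add_temporal_dim_if_needed(x: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for the 3D-vs-4D branch in long_side_scale
    if x.dim() == 3:
        return x.unsqueeze(1)
    return x

print(add_temporal_dim_if_needed.code)  # the if shows up in the generated TorchScript code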
I tested the code and everything seems to work fine: the model is saved and I can run inference for both input scenarios. I was just wondering whether this might become an issue, and whether there's a better way to tackle this situation.