I’m writing a deep learning model in PyTorch. In the constructor, I’m defining a transformation (self.transform) that will be applied to the input before the forward pass.
from typing import List

import torch

# LongSideScale, CenterCrop, Div255 and Normalize are defined/imported elsewhere.

class CustomModel(torch.nn.Module):
    def __init__(self, model, model_name: str, model_params: dict):
        super().__init__()
        self.model_name = model_name
        self.model_params = model_params
        self.model = model
        # Generate the sequential transformations needed for each clip
        self.transform = torch.nn.Sequential(
            LongSideScale(size=self.model_params["side_size"]),
            CenterCrop(size=(self.model_params["crop_size"], self.model_params["crop_size"])),
            Div255(),
            Normalize(self.model_params["mean"], self.model_params["std"]),
        )

    def forward(self, batch_video: List[torch.Tensor]) -> torch.Tensor:
        # Transform every clip, stack into a batch, then run the wrapped model
        batch_video = self.apply_batch_transform(batch_video)
        logits = self.model(batch_video)
        probability = torch.sigmoid(logits)
        return probability

    def apply_batch_transform(self, batch_video: List[torch.Tensor]) -> torch.Tensor:
        transformed_batch_video = [self.transform(clip) for clip in batch_video]
        transformed_batch_video = torch.stack(transformed_batch_video)
        return transformed_batch_video
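For reference, model_params carries the pre-processing settings used above; a placeholder example (the values below are made up, only the keys match the code):

model_params = {
    "side_size": 256,            # target size for the longer spatial side (LongSideScale)
    "crop_size": 224,            # square center-crop size
    "mean": [0.45, 0.45, 0.45],  # per-channel normalization mean
    "std": [0.225, 0.225, 0.225],
}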
The problem is that the input can arrive in either of two layouts (sketched below):
- sequence of N RGB frames: x.shape = (3, N, w, h)
- frames can be stacked: x.shape = (3*N, w, h)
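Concretely, the two cases look like this (the dimensions are made up, just for illustration):

clip_4d = torch.randn(3, 8, 240, 320)    # sequence of N = 8 RGB frames
clip_3d = torch.randn(3 * 8, 240, 320)   # the same 8 frames stacked along the first dim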
This is accounted for in the implementation of the LongSideScale part of the transformation:
import math

def long_side_scale(
    x: torch.Tensor,
    size: int,
    interpolation: str = "bilinear",
    backend: str = "pytorch",
) -> torch.Tensor:
    """
    Determines the longer spatial dim of the video (i.e. width or height) and scales
    it to the given size. To maintain aspect ratio, the shorter side is then scaled
    accordingly.
    """
    # Check if tensor is 3D (C, H, W)
    if len(x.shape) == 3:
        c, h, w = x.shape
        x = x.unsqueeze(1)  # Add a dummy temporal dimension (T=1)
        unpack = True
    elif len(x.shape) == 4:
        c, t, h, w = x.shape
        unpack = False
    else:
        raise ValueError(f"Input tensor must be 3D or 4D, but got shape {x.shape}")

    is_uint8 = x.dtype == torch.uint8
    c, t, h, w = x.shape
    if w > h:
        new_h = int(math.floor((float(h) / w) * size))
        new_w = size
    else:
        new_h = size
        new_w = int(math.floor((float(w) / h) * size))

    if backend == "pytorch":
        # interpolate does not support uint8, so temporarily promote to float32
        if is_uint8:
            x = x.to(torch.float32)
        x = torch.nn.functional.interpolate(x, size=(new_h, new_w), mode=interpolation, align_corners=False)
        if is_uint8:
            x = x.to(torch.uint8)
        # If we added a temporal dimension, squeeze it back
        if unpack:
            x = x.squeeze(1)
        return x
    else:
        raise NotImplementedError(f"{backend} backend not supported.")
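Reusing the placeholder tensors from the sketch above, a quick sanity check of the scaling for both shapes:

out_4d = long_side_scale(clip_4d, size=256)   # torch.Size([3, 8, 192, 256])
out_3d = long_side_scale(clip_3d, size=256)   # torch.Size([24, 192, 256])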
Finally, this model will be converted to TorchScript with something like:
model = CustomModel(model=model, model_name=model_name, model_params=model_params)
scripted = torch.jit.script(model)
scripted.save(scripted_path)
I was worried this might not be the best solution, since TorchScript may have limitations with dynamic control flow, such as if-conditions based on tensor properties or data types, when converting to a static graph (I could be wrong, though!).
I tested the code and everything seems to be working fine: the model is saved and I can perform inference for both input scenarios. I was just wondering whether this might become an issue, and if there’s a better way to tackle this situation.
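For context, my check was roughly along these lines (placeholder shapes, not my real data pipeline):

loaded = torch.jit.load(scripted_path)

# forward expects List[torch.Tensor], so each scenario is a batch with one clip
batch_4d = [torch.randn(3, 8, 240, 320)]     # (3, N, ...) layout
batch_3d = [torch.randn(3 * 8, 240, 320)]    # (3*N, ...) stacked layout

print(loaded(batch_4d).shape)
print(loaded(batch_3d).shape)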