Affine_grid and grid_sample with scale factor biggern than 1 makes image smaller

hi guys,
this is probably a stupid question, but I cannot figure out following problem:
I’m currently trying to understand why my written method returns a smaller image when the scale factor is bigger than 1. Usually a larger scale factor makes the output bigger. But with grid_sample it the way around. I’m not sure if I understand the affine_grid function correctly.

import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from PIL import Image

def transform_2D(img, angle: float, translate_x: float = 0, translate_y: float = 0, scale_x: float = 1, scale_y: float = 1):
    img_tensor = TF.to_tensor(img).unsqueeze(0)
    _, channels, height, width = img_tensor.size()

    angle = torch.tensor(angle)
    # in pixels
    tx = 2*(translate_x/width)
    ty = 2*(translate_y/height)
    
    theta = torch.deg2rad(angle)   # Rotation angle in radians
    sx = torch.tensor(scale_x)     # Scaling factor along x-axis
    sy = torch.tensor(scale_y)     # Scaling factor along y-axis
    dx = torch.tensor(tx)          # Translation along x-axis
    dy = torch.tensor(ty)          # Translation along y-axis
    
    # Define affine matrix
    affine_matrix = torch.tensor([
            [sx * torch.cos(theta), -torch.sin(theta), dx],
            [torch.sin(theta),  sy * torch.cos(theta), dy]])

    # Generate grid
    grid = F.affine_grid(affine_matrix.unsqueeze(0), img_tensor.size(), align_corners=True)
    # Apply affine transformation using grid_sample with border control
    transformed_image = F.grid_sample(img_tensor, grid, padding_mode='border')
    return transformed_image

usage:

transformed_image = transform_2D(img=image,
                                 angle=0,
                                 translate_x=0,
                                 translate_y=0,
                                 scale_x=1,
                                 scale_y=1
                                )
pil_image = TF.to_pil_image(transformed_image.squeeze(0))
display(pil_image)