How to pass two arguments in transforms.Lambda(<func_name>)

I have a function like this:
def max_abs_normalize(img, max_abs_val):
    return img / max_abs_val

I am creating a transform object like this:
self.transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(max_abs_normalize)
])

def __getitem__(self, index):
    path = self.paths[index]
    #img = Image.open(path)
    img = np.load(path)
    img = img.f.arr_0
    # Add a leading channel dimension (this reshape assumes a square 2-D array)
    img = img.reshape(1,img.shape[0],img.shape[0])
    img = torch.from_numpy(img)
    return self.transform(img)

How can I pass both parameters to my max_abs_normalize function through transforms.Lambda?

This is just a workaround, but you could do something like this.

# Define your own Lambda implementation
import torch
import torchvision
from torchvision import transforms


class MyLambda(torchvision.transforms.Lambda):
    def __init__(self, lambd, max_abs_val):
        super().__init__(lambd)
        self.max_abs_val = max_abs_val

    def __call__(self, img):
        return self.lambd(img, self.max_abs_val)

def max_abs_normalize(img, max_abs_val):
    return (img)/max_abs_val

You can then initialize the transform with some value:

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    MyLambda(max_abs_normalize, SOME_INITIAL_VALUE)
])

If you want to change the value later, you can access it like this:

transform.transforms[-1].max_abs_val = NEW_VALUE
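
Another option that avoids subclassing: transforms.Lambda calls its function with a single argument, so you can bind the second argument ahead of time with functools.partial. This is only a minimal sketch. The value 255.0 is a placeholder and normalize_fixed is a hypothetical name, and it is demonstrated on a plain float so it runs without torchvision installed:

```python
from functools import partial


def max_abs_normalize(img, max_abs_val):
    # Works for anything that supports division (float, NumPy array, tensor)
    return img / max_abs_val


# Bind max_abs_val up front; the resulting callable takes only `img`.
# 255.0 is a placeholder value chosen for illustration.
normalize_fixed = partial(max_abs_normalize, max_abs_val=255.0)

# The bound callable now has the one-argument shape that Lambda expects.
result = normalize_fixed(510.0)
```

You would then pass it as transforms.Lambda(normalize_fixed) in the Compose list. Writing transforms.Lambda(lambda img: max_abs_normalize(img, 255.0)) should behave the same way, although an inline lambda cannot be pickled, which may matter if the dataset is used with DataLoader worker processes on platforms that spawn them. Unlike MyLambda, the bound value is fixed, so changing it means building a new partial.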

Thank you. It worked!
