torch.fx.proxy.Proxy error

Hello. When I run the following code:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor

def n_alphabet(n, r, c):
    return nn.Sequential(
        nn.Linear(n, r*c), 
        nn.Unflatten(1, [1, r, c]), 
        nn.Conv2d(1, 8, 3, padding=1), 
        nn.Tanh(), 
        nn.Conv2d(8, 1, 3, padding=1),
        transforms.RandomAffine(degrees=30, translate=(0.3, 0.3), scale=(0.75, 1.25)), 
        nn.Conv2d(1, 8, 3, padding=1), 
        nn.Tanh(), 
        nn.Flatten(), 
        nn.Linear(8*r*c, n), 
        nn.Softmax(dim=1)
    )

x = n_alphabet(4, 4, 4)

net = create_feature_extractor(x, return_nodes={'4':'alphabet', '10':'latent'})

I get the following error:

TypeError: Unexpected type <class 'torch.fx.proxy.Proxy'>

I know it has to do with the RandomAffine term in the sequence. I don’t understand why though. Why is this error occurring, and what can I do to correct it?

Thanks
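On why the error occurs: create_feature_extractor builds its graph with torch.fx symbolic tracing. The tracer does not pass real tensors through the model. Instead it calls the forward() of every non-leaf module with torch.fx.Proxy objects, and by default only modules from torch.nn are treated as leaves. RandomAffine lives in torchvision.transforms, so the tracer steps into its forward(). There, torchvision asks for the image dimensions through a helper that checks isinstance(img, torch.Tensor). A Proxy is not a Tensor, so the helper falls through to the PIL-image code path, which raises "TypeError: Unexpected type <class 'torch.fx.proxy.Proxy'>". The minimal sketch below reproduces the same mechanism with plain torch.fx. TypeCheckingAugment is a hypothetical stand-in for RandomAffine, not a real torchvision class.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class TypeCheckingAugment(nn.Module):
    """Hypothetical stand-in for a transform that dispatches on the input type."""

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            return img.flip(-1)  # tensor code path
        # Anything that is not a Tensor is assumed to be a PIL image; other types are rejected.
        raise TypeError(f"Unexpected type {type(img)}")


model = nn.Sequential(nn.Linear(4, 16), nn.Unflatten(1, [1, 4, 4]), TypeCheckingAugment())

# Eager execution: the input is a real tensor, so the tensor branch runs.
eager_out = model(torch.randn(2, 4))
print(eager_out.shape)  # torch.Size([2, 1, 4, 4])

# Symbolic tracing: TypeCheckingAugment is not a torch.nn built-in, so the tracer
# calls its forward() with a torch.fx.Proxy, and the isinstance check fails.
trace_error = None
try:
    fx.symbolic_trace(model)
except TypeError as err:
    trace_error = err
print(trace_error)  # Unexpected type <class 'torch.fx.proxy.Proxy'>
```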

I think that when using nn.Sequential, all of its submodules must be of type nn.Module, so using an object of a different type throws this error. I think you’ll have to write your model out explicitly as an nn.Module subclass.