How to get the original class name of a RecursiveScriptModule?

Hi All,

I just have a quick question regarding RecursiveScriptModule. I’ve built a custom optimizer, and within it I cycle through all layers via for module in net.modules() and get the name of each module via module.__class__.__name__. I need this because my optimizer handles different layers differently, and I determine the layer type from module.__class__.__name__.

However, if I jit my network, this fails because the class name of every module becomes RecursiveScriptModule. I did notice that printing a module shows RecursiveScriptModule(original_name=Linear). So, I was wondering how I can get the original_name value here?

Thanks in advance! :slight_smile:

module.original_name would return a string containing the original module name, e.g.:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
s = torch.jit.script(lin)  # s is a RecursiveScriptModule
print(s.original_name)
# prints: Linear

so you could try to use this attribute instead.
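For the optimizer use case from the question, here is a minimal sketch of a lookup that works for both eager and scripted networks. The helper name layer_type_name and the small Sequential network are made up for illustration; the only assumption about PyTorch is that scripted modules expose original_name, as shown above, while eager modules do not.

```python
import torch
import torch.nn as nn


def layer_type_name(module: nn.Module) -> str:
    """Return the layer type name for eager and scripted modules alike.

    Hypothetical helper: scripted modules expose `original_name`, while
    eager modules do not, so fall back to the Python class name.
    """
    return getattr(module, "original_name", type(module).__name__)


# Toy network, used only to illustrate the lookup.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
scripted = torch.jit.script(net)

# Collect the type name of every module, as the optimizer loop would.
eager_names = [layer_type_name(m) for m in net.modules()]
scripted_names = [layer_type_name(m) for m in scripted.modules()]

for name in scripted_names:
    if name == "Linear":
        pass  # layer-specific handling would go here
```

With this fallback the same optimizer code can be used whether or not the network has been passed through torch.jit.script.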

Thank you! Just what I wanted! :slight_smile: