I am working on a PyTorch-to-TensorRT project. Currently the major problem is that I can't get the correct weight for an op. A traced resnet18 model produces the following inputs:
node: %input.1 : Float(1, 3, 224, 224), %702 : Tensor, %703 : Tensor, %704 : Tensor, %705 : Tensor, %706 : Tensor, %707 : Tensor, %708 : Tensor, %709 : Tensor, %710 : Tensor, %711 : Tensor, %712 : Tensor, %713 : Tensor, %714 : Tensor, %715 : Tensor, %716 : Tensor, %717 : Tensor, %718 : Tensor, %719 : Tensor, %720 : Tensor, %721 : Tensor, %722 : Tensor, %723 : Tensor, %724 : Tensor, %725 : Tensor, %726 : Tensor, %727 : Tensor, %728 : Tensor, %729 : Tensor, %730 : Tensor, %731 : Tensor, %732 : Tensor, %733 : Tensor, %734 : Tensor, %735 : Tensor, %736 : Tensor, %737 : Tensor, %738 : Tensor, %739 : Tensor, %740 : Tensor, %741 : Tensor, %742 : Tensor, %743 : Tensor, %744 : Tensor, %745 : Tensor, %746 : Tensor, %747 : Tensor, %748 : Tensor, %749 : Tensor, %750 : Tensor, %751 : Tensor, %752 : Tensor, %753 : Tensor, %754 : Tensor, %755 : Tensor, %756 : Tensor, %757 : Tensor, %758 : Tensor, %759 : Tensor, %760 : Tensor, %761 : Tensor, %762 : Tensor, %763 : Tensor, %764 : Tensor, %765 : Tensor, %766 : Tensor, %767 : Tensor, %768 : Tensor, %769 : Tensor, %770 : Tensor, %771 : Tensor, %772 : Tensor, %773 : Tensor, %774 : Tensor, %775 : Tensor, %776 : Tensor, %777 : Tensor, %778 : Tensor, %779 : Tensor, %780 : Tensor, %781 : Tensor, %782 : Tensor, %783 : Tensor, %784 : Tensor, %785 : Tensor, %786 : Tensor, %787 : Tensor, %788 : Tensor, %789 : Tensor, %790 : Tensor, %791 : Tensor, %792 : Tensor, %793 : Tensor, %794 : Tensor, %795 : Tensor, %796 : Tensor, %797 : Tensor, %798 : Tensor, %799 : Tensor, %800 : Tensor, %801 : Tensor, %802 : Tensor, %803 : Tensor = prim::Param()
I can identify the actual input nodes, but for parameter nodes the only information I can get is the “index” of the slot. I don’t know how to get the weight tensor that corresponds to a parameter node.
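To make this concrete, here is a minimal sketch that parses a printed `prim::Param()` line like the one above into (slot index, value name, type) tuples. `parse_param_line` is a hypothetical helper, not part of any PyTorch API, and it works on the printed text only. It shows that, apart from the forward input with a concrete shape, each parameter slot exposes only a position, a numeric name, and the generic type `Tensor`.

```python
import re
from typing import List, Tuple

# One "%name : Type" entry. The type may carry a parenthesized shape such as
# "Float(1, 3, 224, 224)", so the commas inside it must not split entries.
_ENTRY = re.compile(r"%([\w.]+)\s*:\s*(\w+(?:\([^)]*\))?)")


def parse_param_line(line: str) -> List[Tuple[int, str, str]]:
    """Return (slot_index, value_name, type_str) for each graph input
    listed on a printed 'prim::Param()' line."""
    # Only the left-hand side of "= prim::Param()" lists the inputs.
    lhs = line.split("= prim::Param()")[0]
    return [(i, m.group(1), m.group(2)) for i, m in enumerate(_ENTRY.finditer(lhs))]


if __name__ == "__main__":
    sample = "node: %input.1 : Float(1, 3, 224, 224), %702 : Tensor, %703 : Tensor = prim::Param()"
    for slot, name, ty in parse_param_line(sample):
        print(slot, name, ty)
```

Applied to the full resnet18 line, this yields 103 entries: `%input.1` at slot 0 followed by the 102 untyped slots `%702` through `%803`.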
A workaround is to use torch.jit._unique_state_dict, remove all untracked variables to get a list of params, and then assign them to the param nodes in reversed order. But this doesn’t work for models with unused modules, such as torchvision.models.inception_v3 (it has an aux output).
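The workaround above can be sketched as a positional pairing with an explicit length check. This rests on the same undocumented assumption the post relies on: after the `forward()` arguments, the remaining graph inputs are the filtered state_dict entries in order. The names `map_param_slots` and `default_skip` are hypothetical, and treating “untracked variables” as BatchNorm's `num_batches_tracked` buffers is also an assumption. It is at least consistent with resnet18: its state_dict has 122 entries, and dropping its 20 `num_batches_tracked` buffers leaves 102, matching the 102 parameter slots in the dump.

```python
from typing import Callable, Dict, List, Sequence


def default_skip(key: str) -> bool:
    # Hypothetical stand-in for "remove all untracked variables":
    # drop BatchNorm's num_batches_tracked counters.
    return key.endswith("num_batches_tracked")


def map_param_slots(input_names: Sequence[str],
                    state_dict_keys: Sequence[str],
                    num_user_inputs: int,
                    skip: Callable[[str], bool] = default_skip) -> Dict[str, str]:
    """Pair each parameter slot of a traced graph with a state_dict key.

    Assumes (undocumented) that the inputs after the forward() arguments
    are exactly the filtered state_dict entries, in the same order.
    """
    keys: List[str] = [k for k in state_dict_keys if not skip(k)]
    slots: List[str] = list(input_names[num_user_inputs:])
    if len(slots) != len(keys):
        # If some modules never run during tracing (e.g. an unused aux
        # head), their entries may have no matching slot. Any positional
        # pairing would then attach the wrong tensors, so refuse instead.
        raise ValueError(f"{len(slots)} parameter slots vs {len(keys)} "
                         "state_dict entries; positional mapping is unsafe")
    # With equal lengths, pairing from the front or from the back
    # ("reversed order") gives the same result.
    return dict(zip(slots, keys))
```

In PyTorch, `input_names` could come from `[v.debugName() for v in graph.inputs()]` and `state_dict_keys` from `torch.jit._unique_state_dict(model).keys()`. The explicit length check means a model like inception_v3 fails loudly rather than silently producing a wrong mapping, but it does not fix the underlying problem.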
Thanks in advance!
Sorry, I don’t quite understand what you are asking. What exactly do you mean by “the weight of an op” and “the correct input node”? If you can provide more context, it will be easier for us to answer your exact question.
torch.jit.trace creates a graph, and graph.inputs() returns both the input nodes of net.forward and the parameter nodes. The problem is that there is no way to get the weight tensor that corresponds to a parameter node.
I currently use this code to get a mapping from parameter nodes to weight tensors, but this mapping isn’t guaranteed by the PyTorch docs.