LibTorch C++: what is the best approach to access the submodule list?

For a PyTorch snippet that gets the list of submodules of a module and checks each module's type:

   for m in self.modules():
     if isinstance(m, nn.BatchNorm2d):
       pass  # do_something

Is this a good C++ implementation of it? My concern is that I am using the shared_ptr directly, which LibTorch tries to hide behind ModuleHolder.

  for (auto m : children()) {
    if (dynamic_cast<torch::nn::BatchNorm2dImpl *>(m.get())) {
      //do_something
    }
  }

The submodules are already registered via register_module() in the constructor.


I found the proper API: as()

for (auto item : named_children()) {
  if (item.value()->as<torch::nn::BatchNorm2d>()) {
    //do_something
  }
}
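
One thing worth checking when porting the Python loop: self.modules() walks the whole module tree, while children() and named_children() return only the direct children. The sketch below illustrates that difference, along with the kind of checked downcast that as() performs. It uses small stand-in types (Module, Conv, BatchNorm, Block, Net, descendants(), count_batch_norm() are all hypothetical names made up for this illustration, not the LibTorch classes), so it compiles without LibTorch.

```cpp
#include <memory>
#include <vector>

// Stand-in types for illustration only; these are NOT the LibTorch classes.
struct Module {
  virtual ~Module() = default;

  // Register a direct child (loosely analogous to register_module()).
  template <typename T>
  std::shared_ptr<T> add(std::shared_ptr<T> child) {
    children_.push_back(child);
    return child;
  }

  // Direct children only.
  const std::vector<std::shared_ptr<Module>>& children() const {
    return children_;
  }

  // All descendants, depth-first, excluding this module itself.
  std::vector<std::shared_ptr<Module>> descendants() const {
    std::vector<std::shared_ptr<Module>> out;
    for (const auto& c : children_) {
      out.push_back(c);
      auto nested = c->descendants();
      out.insert(out.end(), nested.begin(), nested.end());
    }
    return out;
  }

  // Checked downcast: returns nullptr when the dynamic type does not match.
  template <typename T>
  T* as() noexcept {
    return dynamic_cast<T*>(this);
  }

 private:
  std::vector<std::shared_ptr<Module>> children_;
};

struct Conv : Module {};
struct BatchNorm : Module {};

// A nested block that holds its own BatchNorm.
struct Block : Module {
  Block() {
    add(std::make_shared<Conv>());
    add(std::make_shared<BatchNorm>());
  }
};

// Top-level network: one direct BatchNorm plus one nested inside Block.
struct Net : Module {
  Net() {
    add(std::make_shared<BatchNorm>());
    add(std::make_shared<Block>());
  }
};

// Count how many modules in the list are BatchNorm instances.
int count_batch_norm(const std::vector<std::shared_ptr<Module>>& mods) {
  int n = 0;
  for (const auto& m : mods) {
    if (m->as<BatchNorm>() != nullptr) {
      ++n;
    }
  }
  return n;
}
```

With these stand-ins, count_batch_norm(net.children()) sees only the top-level BatchNorm, while count_batch_norm(net.descendants()) also finds the one nested inside Block. In LibTorch itself, the recursive counterpart of children() should be modules(), which, as far as I know, includes the module itself by default, so it is the closer match to the Python loop when BatchNorm layers can sit inside nested submodules.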