How to pass a loss function from Python code to a C++ extension?

I am fairly new to PyTorch and only very recently started using the PyTorch C++ API, so forgive me if my question is trivial. This is also my first question on the PyTorch forum, so please let me know if I should reformat it.

I have C++ code like this:

```cpp
#include <torch/extension.h>
#include <iostream>

namespace py = pybind11;

torch::Tensor calculateLoss(
    torch::Tensor pred,
    torch::Tensor target,
    <pytorch_loss_fn> fn  // placeholder: which C++ type should go here?
){
    return fn(pred, target);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("calculateLoss", &calculateLoss, "Calculate Loss In C++");
}
```

I want to pass the loss function in from the Python code, something like this:

```python
import torch
import dummy_cpp

from torch.nn import BCELoss

y_p = torch.tensor([0.25, 0.75])
y = torch.tensor([0.0, 1.0])

print(dummy_cpp.calculateLoss(y_p, y, BCELoss()))
```

As with tensors, is there a C++ class that implicitly converts a PyTorch loss function object (such as `BCELoss()`) into a C++ torch loss function object? I tried using `nn::AnyModule` as the parameter type, but that doesn't work. Is there any way to achieve this? I would also be fine with a solution that performs a one-time, constant-time conversion to create a copy of the torch loss function in C++, i.e. something like this:

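```python
# Python side (processPython is a placeholder name)
print(dummy_cpp.calculateLoss(y_p, y, processPython(BCELoss())))
```

```cpp
// C++ side, inside calculateLoss (processCpp is a placeholder name)
auto fn = processCpp(<processed Object/Tensor from python>);
return fn(pred, target);
```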
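A minimal sketch of how that `processPython`/`processCpp` split might work, assuming the only thing that crosses from Python to C++ is a plain identifier such as the class name `"BCELoss"` (obtainable on the Python side with `type(loss).__name__`). C++ resolves the name once to a native callable, and the loop afterwards never calls back into Python. `resolveLoss` and the registry are hypothetical names, constructor options such as `reduction` or `weight` are not forwarded, and the sketch uses `std::vector<double>` rather than `torch::Tensor` so it compiles without libtorch; in a real extension the map values would presumably wrap something like the C++ frontend's `torch::nn::BCELoss` instead.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Native loss signature: (predictions, targets) -> scalar loss.
using LossFn = std::function<double(const std::vector<double>&, const std::vector<double>&)>;

// Throws if the inputs cannot be compared element by element.
void checkShapes(const std::vector<double>& pred, const std::vector<double>& target) {
    if (pred.empty() || pred.size() != target.size()) {
        throw std::invalid_argument("pred and target must be non-empty and equally sized");
    }
}

// Mean binary cross-entropy; logs are clamped at -100, as BCELoss documents.
double bceLoss(const std::vector<double>& pred, const std::vector<double>& target) {
    checkShapes(pred, target);
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double logP = std::max(std::log(pred[i]), -100.0);
        const double logNotP = std::max(std::log(1.0 - pred[i]), -100.0);
        sum += target[i] * logP + (1.0 - target[i]) * logNotP;
    }
    return -sum / static_cast<double>(pred.size());
}

// Mean squared error, matching MSELoss with its default 'mean' reduction.
double mseLoss(const std::vector<double>& pred, const std::vector<double>& target) {
    checkShapes(pred, target);
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double diff = pred[i] - target[i];
        sum += diff * diff;
    }
    return sum / static_cast<double>(pred.size());
}

// Hypothetical "processCpp": map a name sent from Python to a native loss.
// Called once before the training loop; the returned callable is pure C++.
LossFn resolveLoss(const std::string& name) {
    static const std::map<std::string, LossFn> registry = {
        {"BCELoss", bceLoss},
        {"MSELoss", mseLoss},
    };
    const auto it = registry.find(name);
    if (it == registry.end()) {
        throw std::invalid_argument("unsupported loss: " + name);
    }
    return it->second;
}
```

With the values from the question, `resolveLoss("BCELoss")(pred, target)` gives about 0.2877, i.e. -ln(0.75), which is the mean binary cross-entropy for those inputs. The trade-off of this design is that only losses registered on the C++ side can be used.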

Some related questions:

  1. I don't want to use the default pybind11 mechanism for passing a Python function to C++, because the repeated wrapping of C++ objects into Python objects and back hurts performance a lot, and since I will be calling the loss function many times inside a loop, I can't afford any per-call overhead. Is this reasoning also valid for a torch C++ extension? I ask because PyTorch tensors are implicitly converted to C++ tensors, so maybe the overhead doesn't exist there?

  2. I would also like to know whether activation functions and custom loss functions (subclasses of `torch.autograd.Function`) can be passed to a C++ extension in a similar manner.
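On the second question, one hedged observation: a Python subclass of `torch.autograd.Function` is Python code, so using it from C++ without calling back into Python would mean rewriting its forward and backward natively (the C++ frontend has a `torch::autograd::Function` template for writing such functions on tensors). The sketch below only illustrates that forward/backward pairing, assuming a name string is the only thing passed from Python; `Activation`, `resolveActivation`, and the choice of sigmoid and ReLU are illustrative, and it works on scalars rather than tensors so it compiles without libtorch.

```cpp
#include <cmath>
#include <map>
#include <stdexcept>
#include <string>

// A native element-wise activation paired with its hand-written derivative,
// i.e. the forward/backward pair a custom autograd Function has to supply.
struct Activation {
    double (*forward)(double);
    double (*backward)(double);  // derivative of forward, evaluated at x
};

double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
double sigmoidGrad(double x) {
    const double s = sigmoid(x);
    return s * (1.0 - s);
}
double relu(double x) { return x > 0.0 ? x : 0.0; }
double reluGrad(double x) { return x > 0.0 ? 1.0 : 0.0; }

// Hypothetical lookup from a name sent by Python (e.g. "Sigmoid") to native code.
Activation resolveActivation(const std::string& name) {
    static const std::map<std::string, Activation> registry = {
        {"Sigmoid", {sigmoid, sigmoidGrad}},
        {"ReLU", {relu, reluGrad}},
    };
    const auto it = registry.find(name);
    if (it == registry.end()) {
        throw std::invalid_argument("unsupported activation: " + name);
    }
    return it->second;
}
```

For built-in activations this is mostly a lookup, since the C++ frontend already provides modules such as `torch::nn::ReLU` and `torch::nn::Sigmoid`; the hand-written backward only matters for custom functions.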