Creating a .pt model for sample production that can be used by a C++ program

Dear all,

I am trying to obtain a trained model flow_model.pt (produced from a modified version of https://github.com/bayesiains/nflows/blob/master/examples/conditional_moons.ipynb ) which can be later read in C++ using a command like:

module = torch::jit::load(“flow_model.pt”);

or something similar.

The important snippets of my initial PyTorch program so far are:

class model_from_flow(torch.nn.Module):
def init(self, flow):
super().init()
self.flow = flow

def forward(self, x):
    self.flow.eval()
    output_dist = self.flow.sample(num_samples=1, context=x)
    return output_dist

x, y = datasets.make_moons(128, noise=.1)
plt.scatter(x[:, 0], x[:, 1], c=y);

num_layers = 5
base_dist = ConditionalDiagonalNormal(shape=[2],
context_encoder=nn.Linear(1, 4))

transforms =
for _ in range(num_layers):
transforms.append(ReversePermutation(features=2))
transforms.append(MaskedAffineAutoregressiveTransform(features=2,
hidden_features=4,
context_features=1))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist)
flow.eval()
optimizer = optim.Adam(flow.parameters())

num_iter = 1000
for i in range(num_iter):
x, y = datasets.make_moons(128, noise=.1)
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)
optimizer.zero_grad()
loss = -flow.log_prob(inputs=x, context=y).mean()
loss.backward()
optimizer.step()

if (i + 1) % 500 == 0:
    fig, ax = plt.subplots(1, 2)
    xline = torch.linspace(-1.5, 2.5, 100)
    yline = torch.linspace(-.75, 1.25, 100)
    xgrid, ygrid = torch.meshgrid(xline, yline, indexing='ij')
    #xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

    with torch.no_grad():
        zgrid0 = flow.sample(5000, context=torch.zeros(1, 1)).exp().reshape(100, 100)
        zgrid1 = flow.sample(5000, context=torch.ones(1, 1)).exp().reshape(100, 100)

Save the model

wrapped_model = model_from_flow(flow)
wrapped_model.eval()
with torch.no_grad():
fake_input = torch.zeros(128, 1)
traced_model = torch.jit.script(wrapped_model)
traced_model.eval()
traced_model.save(“flow_model.pt”)

I get an error like this:

    raise RuntimeError("Forward method cannot be called for a Distribution object.")

I have tried using torch.jit.trace too. Can anyone point me in the right direction? What is the proper syntax to save the model?

Thank you in advance.

Best