Dear all,
I am trying to obtain a trained model flow_model.pt (produced from a modified version of https://github.com/bayesiains/nflows/blob/master/examples/conditional_moons.ipynb ) which can be later read in C++ using a command like:
module = torch::jit::load(“flow_model.pt”);
or something similar.
The important snippets of my initial PyTorch program so far are:
class model_from_flow(torch.nn.Module):
def init(self, flow):
super().init()
self.flow = flow
def forward(self, x):
self.flow.eval()
output_dist = self.flow.sample(num_samples=1, context=x)
return output_dist
x, y = datasets.make_moons(128, noise=.1)
plt.scatter(x[:, 0], x[:, 1], c=y);
num_layers = 5
base_dist = ConditionalDiagonalNormal(shape=[2],
context_encoder=nn.Linear(1, 4))
transforms =
for _ in range(num_layers):
transforms.append(ReversePermutation(features=2))
transforms.append(MaskedAffineAutoregressiveTransform(features=2,
hidden_features=4,
context_features=1))
transform = CompositeTransform(transforms)
flow = Flow(transform, base_dist)
flow.eval()
optimizer = optim.Adam(flow.parameters())
num_iter = 1000
for i in range(num_iter):
x, y = datasets.make_moons(128, noise=.1)
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)
optimizer.zero_grad()
loss = -flow.log_prob(inputs=x, context=y).mean()
loss.backward()
optimizer.step()
if (i + 1) % 500 == 0:
fig, ax = plt.subplots(1, 2)
xline = torch.linspace(-1.5, 2.5, 100)
yline = torch.linspace(-.75, 1.25, 100)
xgrid, ygrid = torch.meshgrid(xline, yline, indexing='ij')
#xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)
with torch.no_grad():
zgrid0 = flow.sample(5000, context=torch.zeros(1, 1)).exp().reshape(100, 100)
zgrid1 = flow.sample(5000, context=torch.ones(1, 1)).exp().reshape(100, 100)
Save the model
wrapped_model = model_from_flow(flow)
wrapped_model.eval()
with torch.no_grad():
fake_input = torch.zeros(128, 1)
traced_model = torch.jit.script(wrapped_model)
traced_model.eval()
traced_model.save(“flow_model.pt”)
I get an error like this:
raise RuntimeError("Forward method cannot be called for a Distribution object.")
I have tried using torch.jit.trace too. Can anyone point me in the right direction? What is the proper syntax to save the model?
Thank you in advance.
Best