Looking for insights on debugging issues loading state dict of a net trained on GPU on CPU

Hi All,

I’m at a loss for the time being on how to proceed with debugging problematic loading of a GPU-trained model on CPU. The model is a generative flow that I sample time series information from. I can provide more detail as necessary, but to start simply, here are the steps I’m following, based on the documentation for loading a GPU-trained net on CPU.

device = torch.device('cpu')
net = SDEFlow(device, obs_model, state_dim_SCON, t, dt_flow, n, num_layers = 5) #instantiates model, leaving arguments for future explanation if necessary
net.load_state_dict(torch.load(net_state_dict_save_string, map_location = 'cpu')) #map_location is an argument of torch.load, not of load_state_dict

net.eval()
x, _ = net(batch_size) #Forward method of net outputs two things, and x is the matrix time series.

#various plotting code from here

and then I end up with a problematic figure that doesn’t align with my GPU results:

Re-loading the model on the device on which the training occurred

device = torch.device('cuda')
net = SDEFlow(device, obs_model, state_dim_SCON, t, dt_flow, n, num_layers = 5)
net.load_state_dict(torch.load(net_state_dict_save_string, map_location = 'cuda'))

net.eval()
x, _ = net(batch_size)

#various plotting code from here.

I’m able to get back the results that I expect:

I don’t think the bug is in my plotting code, since I can already tell things are wrong via print output of x on the CPU, so I didn’t include my plotting code for now either to avoid extraneous information.

Am I missing something obviously wrong with respect to the order that I’m doing things? I made sure that the flow model being instantiated was the exact same, so it appears that something is not being converted correctly in the saving of the state dict? In the training, I’m saving the net state dict along the lines of torch.save(net.state_dict(), net_state_dict_file_string). Evaluation batch_size is equivalent on both GPU and CPU.

Happy to provide more info. I appreciate your time and patience in reading through this question.

Do I understand the issue correctly that your model is performing as expected if you are training and restoring the model on the GPU, but fails if you are training on the GPU and restoring on the CPU?
If so, which PyTorch version are you using? And if not the latest one, could you update it to the latest stable or nightly release?
A while ago (maybe ~6 months) I’ve debugged an issue where the compression during torch.save created invalid data when e.g. nn.Embedding layers were stored from the GPU (it’s already fixed). If you could resave the model, could you disable the compression via torch.save(..., _use_new_zipfile_serialization=False) and check if reloading this state_dict would work?
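
The suggestion above can be tried on a toy model first. The following is a minimal sketch, not the thread's actual code: the stand-in model and the file name are assumptions made for illustration. It saves a state_dict with the legacy (non-zipfile) format, reloads it onto the CPU, and checks that every tensor survived the round trip unchanged.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model; the thread's SDEFlow is not needed to try the flag.
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'net_state_dict.pt')  # hypothetical file name

    # Save with the legacy (non-zipfile) serialization format.
    torch.save(model.state_dict(), path, _use_new_zipfile_serialization=False)

    # Load onto the CPU regardless of the device the tensors were saved from.
    reloaded = torch.load(path, map_location='cpu')

# Check that every tensor survived the round trip bit for bit.
round_trip_ok = all(
    torch.equal(tensor.cpu(), reloaded[name])
    for name, tensor in model.state_dict().items()
)
print(round_trip_ok)
```

If the same check fails when the file is written on the GPU machine and read on the CPU machine, serialization is a plausible suspect; if it passes, the problem is more likely elsewhere.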

Thanks for the unbelievably fast follow-up, @ptrblck. That is absolutely the issue. On GPU, the PyTorch version I am using is 1.10.0, and I am locked to that version due to not having admin privileges. On CPU, the PyTorch version I am using is 1.7.1, and I will upgrade promptly.

I will give this a shot and report back.

By the way, does the addition of _use_new_zipfile_serialization=False apply to both the net and net.state_dict() file, or did the compression only affect the state_dict saving?

You shouldn’t save the net directly, as it would depend on the actual source files and could easily break.
The recommended way is to save the state_dict and create a new model object before loading the state_dict back.
Are you saving both, the actual model object as well as the state_dict?

Yes, I’m saving both. Since I haven’t changed my model and its dependencies for a while, I’ve tended to load the net directly on GPU (though of course I did make sure to load the .state_dict to a new model object on GPU this time).

Try to avoid saving/loading the model directly (just save/load the state_dict and create a new model instance), as it could depend on the actually installed PyTorch version and could even be related to the issue you are seeing.
I’m personally not using the workflow of saving the models directly, as it breaks too easily.
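
For reference, the state_dict workflow described above might look like the sketch below. `TinyNet` is a hypothetical stand-in for the real model class, and an in-memory buffer stands in for a file on disk. Note that `map_location` is an argument of `torch.load`, not of `load_state_dict`.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model class."""

    def __init__(self, hidden=8):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(3, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x):
        return self.body(x)


trained = TinyNet()

# Save only the state_dict (an in-memory buffer stands in for a file path).
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Re-create the model from its source code, then load the weights into it.
device = torch.device('cpu')
restored = TinyNet().to(device)
restored.load_state_dict(torch.load(buffer, map_location=device))
restored.eval()

# The restored model should reproduce the original model's outputs.
x = torch.randn(5, 3)
with torch.no_grad():
    same_output = torch.allclose(trained.eval()(x), restored(x))
print(same_output)
```

Because only tensors are stored, the saved file does not depend on the pickled class definition, which is what makes this route less fragile than saving the model object itself.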

Following up on this, I’ve updated my CPU PyTorch version to 1.10.2

>>> import torch
>>> print(torch.__version__)
1.10.2

As I’ve mentioned, the version I have access to on my GPU computing cluster is locked.

After adding the _use_new_zipfile_serialization=False argument and re-loading the state_dict to the model, the CPU output is better, but still off.

Meanwhile, instantiating a new model and re-loading the state_dict on the GPU gets back what I expect:

I’ve made sure that the model module codes on CPU and GPU are the exact same! It’s even more odd to me now that things weirdly improved after two things changed (CPU PyTorch version and saving argument addition), but aren’t correct yet.

Does this point to my needing to find some way of getting PyTorch updated past 1.10.0 on the GPU cluster I’m using?

Thanks again, @ptrblck.

EDIT: I was able to get the PyTorch install updated to 1.10.2, but still using cudatoolkit = 10.2. Not able to use cudatoolkit = 11.3 on my institution’s GPU cluster. Will re-run my trainings and report back.

After updating PyTorch on both my CPU and GPU cluster to 1.10.2 (with GPU PyTorch being installed with cudatoolkit = 10.2 since cudatoolkit = 11.3 does not work on my institution’s cluster), the net.eval discrepancy between CPU and GPU persists, and it seems that the earlier pseudo-improvement of the CPU plots has not been observed in my latest trial.

The CPU plot:

The GPU plot:

The model code and the loading steps are the exact same. The only thing different between the CPU and GPU evaluation plotting codes at this point is the map_location for the torch.load(...) calls. The CPU results at least are consistently wrong for me for multiple devices I’ve tested that are now using the same PyTorch version (1.10.2). The shared torch.load leading up to the plotting is along the lines of

obs_model = torch.load(obs_model_save_string, map_location = active_device)
net = SDEFlow(active_device, obs_model, state_dim_SCON, t, dt_flow, n, NUM_LAYERS = num_layers, REVERSE = reverse, BASE_STATE = base_state)
net.load_state_dict(torch.load(net_state_dict_save_string, map_location = active_device))
p_theta = torch.load(p_theta_save_string, map_location = active_device)
q_theta = torch.load(q_theta_save_string, map_location = active_device)
SBM_SDE_instance = torch.load(SBM_SDE_instance_save_string, map_location = active_device)

net.eval()
x, _ = net(eval_batch_size)

plot_func(...)

I made sure I was still saving the state_dict with _use_new_zipfile_serialization = False.
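
One way to narrow this kind of discrepancy down is to confirm that the loaded weights really are identical on both machines before looking at the forward pass. The sketch below is only an illustration under assumptions: `state_dict_fingerprint` is a hypothetical helper, and the small network stands in for the real model. Running it on both machines after `load_state_dict` and comparing the printed lines shows whether any entry differs.

```python
import torch
from torch import nn


def state_dict_fingerprint(state_dict):
    """Map each entry name to (shape, sum in float64) for cross-machine comparison."""
    fingerprint = {}
    for name, tensor in state_dict.items():
        # Move to CPU and upcast so the summary is computed the same way everywhere.
        values = tensor.detach().to('cpu', dtype=torch.float64)
        fingerprint[name] = (tuple(tensor.shape), values.sum().item())
    return fingerprint


# Hypothetical stand-in network; in practice this would be the loaded model.
net = nn.Sequential(nn.Conv1d(1, 4, 3), nn.BatchNorm1d(4))

fp = state_dict_fingerprint(net.state_dict())
for name, (shape, total) in fp.items():
    print(f'{name}: shape={shape}, sum={total:.10f}')
```

If the fingerprints agree, the file was loaded correctly, and the difference has to come from something outside the state_dict or from the computation itself.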

Could you remove the direct loading of the obs_model and use the recommended workflow of creating the model instance and loading its state_dict afterwards?
If you get stuck, feel free to provide a minimal, executable code snippet to reproduce the issue.

obs_model is actually not a trained model with a state_dict and is a D.normal.Normal distribution object, but I created it anew separately on the CPU anyway and got the same result:

Because of a .csv of data the code depends on, I’m unfortunately unable to presently share a minimal independent code snippet to reproduce this issue. Is there a way I could upload a .csv file?

Edit: I’ll put the .csv up on GitHub and share the link.

@ptrblck, the .csv file can be read from GitHub, but there are still some modules and the net state_dict that need to be downloaded to run the below reduced code version that still demonstrates the CPU and GPU discrepancy issue. If you don’t have time for the additional downloading, I’ll completely understand. The reduced code is as follows:

#Python-related imports
import math
import sys
from datetime import datetime
import os.path

#Torch imports
import torch
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Function

#PyData imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

#Module imports
from obs_and_flow import *
from plotting import *

#PyTorch settings
if torch.cuda.is_available():
    print('CUDA device detected.')
    active_device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    active_device = torch.device('cpu')
    torch.set_default_tensor_type('torch.FloatTensor')    

print(active_device)

torch.set_printoptions(precision = 8)
torch.manual_seed(0)
#IAF SSM time parameters
dt_flow = 1.0 #Increased from 0.1 to reduce memory.
t = 5000 #In hours.
n = int(t / dt_flow) + 1
t_span = np.linspace(0, t, n)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n, 1]).to(active_device) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.

#Training parameters
elbo_iter = 105000
elbo_lr = 1e-2
elbo_lr_decay = 0.7
elbo_lr_decay_step_size = 5000
elbo_warmup_iter = 5000
elbo_warmup_lr = 1e-6
ptrain_iter = 0
ptrain_alg = 'L1'
batch_size = 31
eval_batch_size = 31
obs_error_scale = 0.1
prior_scale_factor = 0.25
num_layers = 5
reverse = False
base_state = False

#Specify desired SBM SDE model type and details.
state_dim_SCON = 3
SBM_SDE_class = 'SCON'
diffusion_type = 'C'
learn_CO2 = False
fix_theta_dict = None

now_string = 'SCON-C_no_CO2_logit_short_2022_01_30_16_14_49_discuss'
outputs_folder = 'training_pt_outputs/'
plots_folder = 'training_plots/'
save_string = '_iter_105000_warmup_5000_t_5000_dt_1.0_batch_31_layers_5_lr_0.01_decay_step_5000_warmup_lr_1e-06_sd_scale_0.25_SCON-C_no_CO2_logit_short_2022_01_30_16_14_49.pt'

net_state_dict_save_string = os.path.join(outputs_folder, 'net_state_dict' + save_string)

obs_times, obs_means, obs_error = csv_to_obs_df('https://raw.githubusercontent.com/wallyxie/varInferenceSoilBiogeoModelSyntheticData/main/pytorch_sbm_sde_vi/generated_data/SCON-C_no_CO2_logit_short_2022_01_20_08_53_sample_y_t_5000_dt_0-01_sd_scale_0-25.csv', state_dim_SCON, t, obs_error_scale)
obs_model = ObsModel(DEVICE = active_device, TIMES = obs_times, DT = dt_flow, MU = obs_means, SCALE = obs_error)
net = SDEFlow(active_device, obs_model, state_dim_SCON, t, dt_flow, n, NUM_LAYERS = num_layers, REVERSE = reverse, BASE_STATE = base_state)
net.load_state_dict(torch.load(net_state_dict_save_string, map_location = active_device))

#Plot training posterior results and ELBO history.
net.eval()
x, _ = net(eval_batch_size)

plot_states_post(x, None, obs_model, None, elbo_iter, elbo_warmup_iter, t, dt_flow, batch_size, eval_batch_size, num_layers, elbo_lr, elbo_lr_decay_step_size, elbo_warmup_lr, prior_scale_factor, plots_folder, now_string, fix_theta_dict, learn_CO2, ymin_list = [0, 0, 0, 0], ymax_list = [70., 8., 11., 0.025])

At this point, the only thing being loaded is the model state_dict after the model is newly instantiated. The state_dict in question can be downloaded at the following link: varInferenceSoilBiogeoModelSyntheticData/net_iter_105000_warmup_5000_t_5000_dt_1.0_batch_31_layers_5_lr_0.01_decay_step_5000_warmup_lr_1e-06_sd_scale_0.25_SCON-C_no_CO2_logit_short_2022_01_30_16_14_49.pt at main · wallyxie/varInferenceSoilBiogeoModelSyntheticData · GitHub

The custom modules that need to be downloaded to the same directory as the script are “obs_and_flow.py” and “plotting.py.”

When launched on a GPU, the script produces the following plot:

When launched on a CPU, the script produces the following plot:

I greatly appreciate your patience, @ptrblck, and again, will completely understand if you don’t have the time to get into downloading of additional custom modules.

Weeks later, I’ve resolved the above issue. I’m following up now with a general description of how the issue was circuitously debugged, in case my situation can provide insights for future folks googling about discrepancies between their GPU and CPU testing results.

So, I had some CUDA and CPU tensor disagreements in my code that I initially bypassed inappropriately and lazily with torch.set_default_tensor_type('torch.cuda.FloatTensor'). It became apparent that forcing the default tensor type to torch.cuda.FloatTensor could be problematic when trying net.cpu() after loading the GPU-trained model returned an error. I removed torch.set_default_tensor_type('torch.cuda.FloatTensor') and added the necessary .to(device) calls to restore my code functionality. Afterward, I noticed that net.cpu() was functional and I finally was able to obtain sample values I expected from my model on CPU with the following pseudocode:

net = model(...)
net.load_state_dict(torch.load('state_dict_path.pt', map_location = device))
net.cpu()

For the longest time, I did not realize the issues covered up by torch.set_default_tensor_type('torch.cuda.FloatTensor') were a possible source of the CPU loading issue. Lesson learned – forcing a default tensor type is a potentially problematic way of facing CUDA and CPU tensor mismatch errors!
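
As a rough illustration of the alternative described above, the sketch below creates tensors on an explicit device instead of changing the global default tensor type. `make_time_grid` is a hypothetical helper written for this example; it is not part of the original code.

```python
import torch

# Pick the device once and pass it around explicitly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def make_time_grid(t, dt, device):
    """Build a (1, n, 1) time grid directly on the requested device."""
    n = int(t / dt) + 1
    return torch.linspace(0.0, t, n, device=device).reshape(1, n, 1)


t_span_tensor = make_time_grid(t=5000, dt=1.0, device=device)

# Tensors derived with *_like factories inherit device and dtype from their source.
ones_like_grid = torch.ones_like(t_span_tensor)

print(t_span_tensor.device, tuple(t_span_tensor.shape))
```

With this pattern, every tensor's placement is visible at the point where it is created, so a CPU and CUDA mismatch surfaces as an explicit error rather than being hidden by a global setting.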

That sounds scary and like heroic debugging!
When I tried to use set_default_tensor_type (maybe in PyTorch 0.4?), I ran into some issues where CPU tensors were expected (but I saw runtime errors). Since this is creating a silent error, its usage seems to be broken.
Would I be able to reproduce it using your model definition and random data?

I did get this issue training the GPU model on other data, so random data should also work for replicating this issue, though the number of necessary pieces could make replicating this a pain.

My model classes:

#Torch-related imports
import torch
from torch.autograd import Function
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim

class LowerBound(Function):
    
    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones(inputs.size()).to(inputs.device) * bound
        b = b.type(inputs.dtype)
        ctx.save_for_backward(inputs, b)
        return torch.max(inputs, b).to(inputs.device)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors

        pass_through_1 = inputs >= b
        pass_through_2 = grad_output < 0

        pass_through = pass_through_1 | pass_through_2
        return pass_through.type(grad_output.dtype) * grad_output, None

class MaskedConv1d(nn.Conv1d):
    
    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv1d, self).__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        self.register_buffer('mask', self.weight.data.clone())
        _, _, kW = self.weight.size() # (out_cha, in_cha, kernel_size)
        self.mask.fill_(1)
        self.mask[:, :, kW // 2 + 1 * (mask_type == 'B'):] = 0 # [1, 0, 0] or [1, 1, 0]

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv1d, self).forward(x)

class ResNetBlock(nn.Module):
    
    def __init__(self, inp_cha, out_cha, stride = 1, first = True, batch_norm = True):
        super().__init__()
        self.conv1 = MaskedConv1d('A' if first else 'B', inp_cha,  out_cha, 3, stride, 1, bias = False)
        self.conv2 = MaskedConv1d('B', out_cha,  out_cha, 3, 1, 1, bias = False)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = MaskedConv1d('A' if first else 'B', inp_cha,  out_cha, 3, stride, 1, bias = False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class ResNetBlockUnMasked(nn.Module):
    
    def __init__(self, inp_cha, out_cha, stride = 1, batch_norm = False):
        super().__init__()
        self.conv1 = nn.Conv1d(inp_cha, out_cha, 3, stride, 1)
        self.conv2 = nn.Conv1d(out_cha, out_cha, 3, 1, 1)
        #in_channels, out_channels, kernel_size, stride=1, padding=0

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = nn.Conv1d(inp_cha,  out_cha, 3, stride, 1, bias=False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class AffineLayer(nn.Module):
    
    def __init__(self, COND_INPUTS, stride, h_cha = 96):
        # COND_INPUTS = COND_INPUTS + obs_dim = 1 + obs_dim = 4 by default (w/o CO2)
        super().__init__()
        self.feature_net = nn.Sequential(ResNetBlockUnMasked(COND_INPUTS, h_cha), ResNetBlockUnMasked(h_cha, COND_INPUTS))
        self.first_block = ResNetBlock(1, h_cha, first = True)
        self.second_block = nn.Sequential(ResNetBlock(h_cha + COND_INPUTS, h_cha, first = False), MaskedConv1d('B', h_cha,  2, 3, stride, 1, bias = False))
        
        self.unpack = True if COND_INPUTS > 1 else False

    def forward(self, x, COND_INPUTS): # x.shape == (batch_size, 1, n * state_dim)
        if self.unpack:
            COND_INPUTS = torch.cat([*COND_INPUTS], 1) # (batch_size, obs_dim + 1, n * state_dim)
        #print(COND_INPUTS[0, :, 0], COND_INPUTS[0, :, 60], COND_INPUTS[0, :, 65])
        COND_INPUTS = self.feature_net(COND_INPUTS) # (batch_size, obs_dim + 1, n * state_dim)
        first_block = self.first_block(x) # (batch_size, h_cha, n * state_dim)
        #print(first_block.shape, COND_INPUTS.shape)
        feature_vec = torch.cat([first_block, COND_INPUTS], 1) # (batch_size, h_cha + obs_dim + 1, n * state_dim)
        output = self.second_block(feature_vec) # (batch_size, 2, n * state_dim)
        mu, sigma = torch.chunk(output, 2, 1) # (batch_size, 1, n * state_dim)
        #print('mu and sigma shapes:', mu.shape, sigma.shape)
        sigma = LowerBound.apply(sigma, 1e-8)
        x = mu + sigma * x # (batch_size, 1, n * state_dim)
        return x, -torch.log(sigma) # each of shape (batch_size, 1, n * state_dim)

class PermutationLayer(nn.Module):
    
    def __init__(self, STATE_DIM, REVERSE = False):
        super().__init__()
        self.state_dim = STATE_DIM
        self.index_1 = torch.randperm(STATE_DIM)
        self.reverse = REVERSE

    def forward(self, x):
        B, S, L = x.shape # (batch_size, 1, state_dim * n)
        x_reshape = x.reshape(B, S, -1, self.state_dim) # (batch_size, 1, n, state_dim)
        if self.reverse:
            x_perm = x_reshape.flip(-2)[:, :, :, self.index_1]
        else:
            x_perm = x_reshape[:, :, :, self.index_1]
        x = x_perm.reshape(B, S, L)
        return x

class SoftplusLayer(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.softplus = nn.Softplus()
    
    def forward(self, x):
        # in.shape == out.shape == (batch_size, 1, n * state_dim)
        y = self.softplus(x)
        return y, -torch.log(-torch.expm1(-y))

class BatchNormLayer(nn.Module):
    
    def __init__(self, num_inputs, momentum = 1e-2, eps = 1e-5, affine = True):
        super(BatchNormLayer, self).__init__()

        self.log_gamma = nn.Parameter(torch.rand(num_inputs)) if affine else torch.zeros(num_inputs)
        self.beta = nn.Parameter(torch.rand(num_inputs)) if affine else torch.zeros(num_inputs)
        self.momentum = momentum
        self.eps = eps

        self.register_buffer('running_mean', torch.zeros(num_inputs))
        self.register_buffer('running_var', torch.ones(num_inputs))

    def forward(self, inputs):
        inputs = inputs.squeeze(1) # (batch_size, n * state_dim)
        #print(self.training)
        if self.training:
            # Compute mean and var across batch
            self.batch_mean = inputs.mean(0)
            self.batch_var = (inputs - self.batch_mean).pow(2).mean(0) + self.eps

            self.running_mean.mul_(self.momentum)
            self.running_var.mul_(self.momentum)

            self.running_mean.add_(self.batch_mean.data * (1 - self.momentum))
            self.running_var.add_(self.batch_var.data * (1 - self.momentum))

            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var
        # mean.shape == var.shape == (n * state_dim, )

        x_hat = (inputs - mean) / var.sqrt() # (batch_size, n * state_dim)
        #print(mean, var)
        #print('x_hat', x_hat)
        y = torch.exp(self.log_gamma) * x_hat + self.beta # (batch_size, n * state_dim)
        ildj = -self.log_gamma + 0.5 * torch.log(var) # (n * state_dim, )

        # y.shape == (batch_size, 1, n * state_dim), ildj.shape == (1, 1, n * state_dim)
        return y[:, None, :], ildj[None, None, :]
    
class SDEFlow(nn.Module):

    def __init__(self, DEVICE, OBS_MODEL, STATE_DIM, T, DT, N,
                 I_S_TENSOR = None, I_D_TENSOR = None, COND_INPUTS = 1, NUM_LAYERS = 5, POSITIVE = True,
                 REVERSE = False, BASE_STATE = False):
        super().__init__()
        self.device = DEVICE
        self.obs_model = OBS_MODEL
        self.state_dim = STATE_DIM
        self.t = T
        self.dt = DT
        self.n = N

        self.base_state = BASE_STATE
        if self.base_state:
            base_loc_SOC, base_loc_DOC, base_loc_MBC = torch.split(nn.Parameter(torch.zeros(1, self.state_dim)), 1, -1)
            base_scale_SOC, base_scale_DOC, base_scale_MBC = torch.split(nn.Parameter(torch.ones(1, self.state_dim)), 1, -1)
            base_loc = torch.cat((base_loc_SOC.expand([1, self.n]), base_loc_DOC.expand([1, self.n]), base_loc_MBC.expand([1, self.n])), 1)
            base_scale = torch.cat((base_scale_SOC.expand([1, self.n]), base_scale_DOC.expand([1, self.n]), base_scale_MBC.expand([1, self.n])), 1)
            self.base_dist = D.normal.Normal(loc = base_loc, scale = base_scale)                
        else:
            self.base_dist = D.normal.Normal(loc = 0., scale = 1.)

        self.cond_inputs = COND_INPUTS  
        if self.cond_inputs == 3:
            self.i_tensor = torch.stack((I_S_TENSOR.reshape(-1), I_D_TENSOR.reshape(-1)))[None, :, :].repeat_interleave(3, -1)

        self.num_layers = NUM_LAYERS
        self.reverse = REVERSE

        self.affine = nn.ModuleList([AffineLayer(COND_INPUTS + self.obs_model.obs_dim, 1) for _ in range(NUM_LAYERS)])
        self.permutation = [PermutationLayer(STATE_DIM, REVERSE = self.reverse) for _ in range(NUM_LAYERS)]
        self.batch_norm = nn.ModuleList([BatchNormLayer(STATE_DIM * N) for _ in range(NUM_LAYERS - 1)])
        self.positive = POSITIVE
        if self.positive:
            self.SP = SoftplusLayer()
        
    def forward(self, BATCH_SIZE, *args, **kwargs):
        if self.base_state:
            eps = self.base_dist.rsample([BATCH_SIZE]).to(self.device)
        else:
            eps = self.base_dist.rsample([BATCH_SIZE, 1, self.state_dim * self.n]).to(self.device)
        log_prob = self.base_dist.log_prob(eps).sum(-1) # (batch_size, 1)
        
        # NOTE: This currently assumes a regular time gap between observations!
        steps_bw_obs = self.obs_model.idx[1] - self.obs_model.idx[0]
        reps = torch.ones(len(self.obs_model.idx), dtype=torch.long).to(self.device) * self.state_dim
        reps[1:] *= steps_bw_obs
        obs_tile = self.obs_model.mu[None, :, :].repeat_interleave(reps, -1).repeat( \
            BATCH_SIZE, 1, 1).to(self.device) # (batch_size, obs_dim, state_dim * n)
        times = torch.arange(0, self.t + self.dt, self.dt, device = eps.device)[None, None, :].repeat( \
            BATCH_SIZE, self.state_dim, 1).transpose(-2, -1).reshape(BATCH_SIZE, 1, -1).to(self.device)
        
        if self.cond_inputs == 3:
            i_tensor = self.i_tensor.repeat(BATCH_SIZE, 1, 1)
            features = (obs_tile, times, i_tensor)
        else:
            features = (obs_tile, times)
        #print(obs_tile)

        ildjs = []
        
        for i in range(self.num_layers):
            eps, cl_ildj = self.affine[i](self.permutation[i](eps), features) # (batch_size, 1, n * state_dim)
            #print('Coupling layer {}'.format(i), eps, cl_ildj)
            if i < (self.num_layers - 1):
                eps, bn_ildj = self.batch_norm[i](eps) # (batch_size, 1, n * state_dim), (1, 1, n * state_dim)
                ildjs.append(bn_ildj)
                #print('BatchNorm layer {}'.format(i), eps, bn_ildj)
            ildjs.append(cl_ildj)
                
        if self.positive:
            eps, sp_ildj = self.SP(eps) # (batch_size, 1, n * state_dim)
            ildjs.append(sp_ildj)
            #print('Softplus layer', eps, sp_ildj)
        
        eps = eps.reshape(BATCH_SIZE, -1, self.state_dim) + 1e-6 # (batch_size, n, state_dim)
        for ildj in ildjs:
            log_prob += ildj.sum(-1) # (batch_size, 1)
    
        #return eps.reshape(BATCH_SIZE, self.state_dim, -1).permute(0, 2, 1) + 1e-6, log_prob
        return eps, log_prob # (batch_size, n, state_dim), (batch_size, 1)

class ObsModel(nn.Module):

    def __init__(self, DEVICE, TIMES, DT, MU, SCALE):
        super().__init__()
        self.times = TIMES # (n_obs, )
        self.dt = DT
        self.idx = self.get_idx(TIMES, DT)        
        self.mu = torch.Tensor(MU).to(DEVICE) # (obs_dim, n_obs)
        self.scale = torch.Tensor(SCALE).to(DEVICE) # (1, obs_dim)
        self.obs_dim = self.mu.shape[0]
        
    def forward(self, x):
        obs_ll = D.normal.Normal(self.mu.permute(1, 0), self.scale).log_prob(x[:, self.idx, :])
        return torch.sum(obs_ll, [-1, -2]).mean()

    def get_idx(self, TIMES, DT):
        return list((TIMES / DT).astype(int))
    
    def plt_dat(self):
        return self.mu, self.times

Before instantiating the net, we need to have an observation and state dimension in mind. We’ll go with 3. So, an observation model needs to be instantiated with:

import numpy as np #Data I was working with was in Numpy array format.

device = torch.device('cpu')
obs_means = np.random.rand(3, 101)
obs_times = np.arange(0, 101)
obs_error = 0.1 * obs_means
dt = 1.
obs_model = ObsModel(DEVICE = device, TIMES = obs_times, DT = dt, MU = obs_means, SCALE = obs_error).to(device)

Then, we can instantiate the net with

t = 100
n = int(t / dt) + 1
net = SDEFlow(device, obs_model, 3, t, dt, n, NUM_LAYERS = 1, REVERSE = True, BASE_STATE = False).to(device)

Then, you can train with something along the lines of

iterations = 1000
training_batch_size = 20
loss_opt = optim.Adam(net.parameters(), lr = 1e-3)
loss = 1e20
for i in range(iterations):
    loss_opt.zero_grad()
    x, log_prob = net(training_batch_size)
    loss = log_prob.mean() - obs_model(x)
    loss.backward()
    loss_opt.step() #Apply the parameter update; without this call the weights never change.

and then the requisite net saving and loading thereafter to plot on CPU.


By the way, I have noticed that I have not managed to get the CPU loading working with net.eval() turned on, so that’s probably a sign that there’s still something wrong.
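
One thing that might be worth ruling out here, as a guess that this thread does not confirm: state that lives outside the state_dict. In the model code above, PermutationLayer stores `index_1` from `torch.randperm` as a plain attribute, and `self.permutation` is a plain Python list, so neither is saved with the state_dict. The sketch below uses two hypothetical toy modules to show the difference between a plain tensor attribute and a registered buffer.

```python
import torch
from torch import nn


class PlainAttr(nn.Module):
    """Hypothetical module holding a tensor as a plain attribute."""

    def __init__(self, dim):
        super().__init__()
        self.index = torch.randperm(dim)  # not saved in state_dict, not moved by .to()


class BufferAttr(nn.Module):
    """Hypothetical module holding the same tensor as a registered buffer."""

    def __init__(self, dim):
        super().__init__()
        self.register_buffer('index', torch.randperm(dim))  # saved and moved with the module


plain_keys = list(PlainAttr(3).state_dict().keys())
buffer_keys = list(BufferAttr(3).state_dict().keys())
print(plain_keys)   # []
print(buffer_keys)  # ['index']

# A registered buffer is restored by load_state_dict; a plain attribute is re-drawn.
source = BufferAttr(10)
target = BufferAttr(10)
target.load_state_dict(source.state_dict())
buffers_match = torch.equal(source.index, target.index)
print(buffers_match)
```

If a freshly constructed model draws such values from a different random number generator on each machine, the loaded weights would be identical while the outputs still differ.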
