I want to trace my own model so that I can use it from C++. But when I call torch.jit.trace on it, I get the following error:
Traceback (most recent call last):
  File "main.py", line 457, in <module>
    traced_script_module = torch.jit.trace(model, (example_input,example_rawdata,example_index,example_mask) )
  File "D:\Anaconda3\envs\GDN\lib\site-packages\torch\jit\__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "D:\Anaconda3\envs\GDN\lib\site-packages\torch\jit\__init__.py", line 1037, in trace_module
    check_tolerance, _force_outplace, True, _module_class)
  File "D:\Anaconda3\envs\GDN\lib\site-packages\torch\autograd\grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "D:\Anaconda3\envs\GDN\lib\site-packages\torch\jit\__init__.py", line 675, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
The full log is too long to post; its last lines are:
First diverging operator:
Node diff:
- %5 : __torch__.model.GG_Nets.OutLayer = prim::GetAttr[name="out_layer"](%self.1)
+ %5 : __torch__.model.GG_Nets.___torch_mangle_50.OutLayer = prim::GetAttr[name="out_layer"](%self.1)
? +++++++++++++++++++
My Code for trace:
# Build the project model, put it in eval mode and cast it with .float().
main = Main(train_config, env_config, debug=False)
model = main.model.eval().float()

# Shapes of the four example tensors, in the order the model's forward takes them:
# (data, raw_data, org_edge_index, mask).
example_shapes = ((1, 6, 850), (1, 6, 1000), (1, 2, 30), (1, 1, 6, 1000))

# Each example is drawn with torch.rand and then moved to the GPU; the
# generator is consumed in order, so the draws happen in the same sequence.
example_input, example_rawdata, example_index, example_mask = (
    torch.rand(*shape).to('cuda') for shape in example_shapes
)

example_inputs = (example_input, example_rawdata, example_index, example_mask)
traced_script_module = torch.jit.trace(model, example_inputs)
My model code:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
from torch_geometric.nn import GCNConv, GATConv, EdgeConv
import math
import torch.nn.functional as F
from pyinform.transferentropy import transfer_entropy
import seaborn as sns
from .graph_layer import GraphLayer
def get_batch_edge_index(org_edge_index, batch_num, node_num):
    """Tile an edge index ``batch_num`` times, offsetting each copy.

    Args:
        org_edge_index: tensor of shape ``(2, edge_num)``.
        batch_num: number of copies to concatenate along the edge axis.
        node_num: offset step; the ``i``-th copy has ``i * node_num`` added
            to every entry.

    Returns:
        A new tensor of shape ``(2, edge_num * batch_num)`` with the same
        dtype and device as ``org_edge_index``. The input is not modified.
    """
    # org_edge_index:(2, edge_num)
    edge_index = org_edge_index.clone().detach()
    edge_num = org_edge_index.shape[1]
    batch_edge_index = edge_index.repeat(1, batch_num).contiguous()

    # One offset per column: copy i occupies columns [i*edge_num, (i+1)*edge_num)
    # and is shifted by i*node_num. Building the whole offset row at once
    # replaces the per-copy Python loop of in-place slice updates.
    offsets = (
        torch.arange(batch_num, device=batch_edge_index.device)
        .mul(node_num)
        .unsqueeze(1)
        .expand(batch_num, edge_num)
        .reshape(batch_num * edge_num)
    )
    batch_edge_index += offsets.to(batch_edge_index.dtype)
    return batch_edge_index
# Discriminator network built from Linear / Conv1d / BatchNorm1d layers and a sigmoid.
# NOTE(review): indentation was lost in this paste and several statements are cut
# off, so the class is not runnable as shown; the comments describe only what is visible.
class Discrimanitor(nn.Module):
def __init__(self):
super(Discrimanitor, self).__init__()
# Linear stack 1000 -> 500 -> 100 -> 10 -> 1, with a BatchNorm1d after each of the first three.
self.lin1 = nn.Linear(1000,500)
self.bn1 = nn.BatchNorm1d(500)
self.lin2 = nn.Linear(500,100)
self.bn2 = nn.BatchNorm1d(100)
self.lin3 = nn.Linear(100,10)
self.bn3 = nn.BatchNorm1d(10)
self.lin4 = nn.Linear(10,1)
self.sig = nn.Sigmoid()
# NOTE(review): every nn.Conv1d(...) call below stops after in_channels; the
# remaining arguments and the closing parenthesis are not visible in this paste.
# NOTE(review): bn1, bn2, bn3 and sig are assigned a second time here, which would
# replace the modules created above. Presumably one of the two variants was
# commented out in the original file; confirm.
self.conv1 = nn.Conv1d(in_channels=1, # batch , 4,402
self.bn1 = nn.BatchNorm1d(num_features=4, ) # batch , 4 ,402
self.conv2 = nn.Conv1d(in_channels=4, # batch , 8 ,101
self.bn2 = nn.BatchNorm1d(num_features=8, )
self.conv3 = nn.Conv1d(in_channels=8, # batch, 16, 50
self.bn3 = nn.BatchNorm1d(num_features=4, )
self.conv4 = nn.Conv1d(in_channels=4, # batch, 16, 50
self.fc = nn.Linear(123,1)
self.bn4 = nn.BatchNorm1d(num_features=1, )
self.sig = nn.Sigmoid()
# Forward pass: x is only used for its shape; x_gen is the tensor fed to the layers.
def forward(self,x,x_gen):
batch_num, node_num, all_feature = x.shape
# Linear path on x_gen. Its result (D_out5) is not used by anything below.
D_out1 = self.lin1(x_gen)
D_out1 = self.bn1(D_out1)
D_out2 = self.lin2(D_out1)
D_out2 = self.bn2(D_out2)
D_out3 = self.lin3(D_out2)
D_out3 = self.bn3(D_out3)
D_out4 = self.lin4(D_out3)
D_out5 = self.sig(D_out4)
# Convolutional path on x_gen: conv1..conv4 with bn1..bn3 in between, then fc and sigmoid.
# NOTE(review): this path reuses self.bn1/bn2/bn3, the same attributes the linear path uses.
D_out = self.conv1(x_gen)
D_out = self.bn1(D_out)
D_out = self.conv2(D_out)
D_out = self.bn2(D_out)
D_out = self.conv3(D_out)
D_out = self.bn3(D_out)
D_out = self.conv4(D_out)
D_out = self.fc(D_out)
D_out = self.sig(D_out)
# The sigmoid output is reshaped to (batch_num, node_num, -1) and returned.
D_out = torch.reshape(D_out,(batch_num,node_num,-1))
return D_out
# Small 1-D convolutional network: two Conv1d + BatchNorm1d stages and a Linear(213, 250),
# each followed by an in-place LeakyReLU.
# NOTE(review): indentation was lost in this paste; the comments describe only what is visible.
class dcnn(nn.Module):
def __init__(self):
super(dcnn, self).__init__()
# NOTE(review): both nn.Conv1d(...) calls below stop after in_channels; the remaining
# arguments and closing parenthesis are not visible, so the block is incomplete.
# The first convolution is stored as `con1` (not `conv1`); forward() uses that name.
self.con1 = nn.Conv1d(in_channels=6,
self.bn1 = nn.BatchNorm1d(num_features=18, )
self.conv2 = nn.Conv1d(in_channels=18, # batch , 8 ,101
self.bn2 = nn.BatchNorm1d(num_features=6, )
self.line = nn.Linear(213,250)
# inplace=True: the activation overwrites its input tensor.
self.relu = nn.LeakyReLU(inplace=True)
# Applies con1 -> bn1 -> relu -> conv2 -> bn2 -> relu -> line -> relu and returns the result.
def forward (self,x):
x = self.con1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.line(x)
x = self.relu(x)
return x
class attention_layer(nn.Module):
    """Dot-product self-attention followed by a linear projection to 125 features.

    ``hidden_dim`` is the input and output size of the query, key and value
    projections. ``forward`` reshapes the attended values to ``(-1, 6, 425)``
    before the final projection, so in practice ``hidden_dim`` has to be 425.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Layers are created in the order q, k, v, lin2.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 125)

    def forward(self, x):
        """Return ``(out, alpha)`` for a 3-D input ``x``.

        ``alpha`` is the softmax (over the last axis) of the query/key
        products; ``out`` is ``lin2`` applied to the alpha-weighted values.
        """
        queries = self.q(x)
        keys_t = self.k(x).permute(0, 2, 1)
        values = self.v(x)

        # Unscaled scores, normalised along dim 2.
        alpha = F.softmax(queries @ keys_t, dim=2)

        attended = (alpha @ values).reshape(-1, 6, 425)
        return self.lin2(attended), alpha
# Generator: adds Gaussian noise to the node embeddings it is given and passes them
# through three ConvTranspose1d layers, a Linear(1000, 1000) and a tanh.
# NOTE(review): indentation was lost in this paste and several statements are cut
# off, so the class is not runnable as shown; the comments describe only what is visible.
class Generator(nn.Module):
def __init__(self, singal_num, node_num):
super(Generator, self).__init__() # input 18,250
self.input_channel = singal_num
#self.mask = mask
# Device chosen once at construction time; forward() moves the embeddings to it.
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#self.embeddings = embedding(torch.arange(node_num))
# NOTE(review): the three nn.ConvTranspose1d(...) calls below are cut off after
# out_channels; the remaining arguments (kernel size etc.) and the closing
# parenthesis are not visible in this paste.
self.deconv1 = nn.ConvTranspose1d(in_channels=self.input_channel,
out_channels=self.input_channel * 4, # 3
# output_padding=1,
self.bn1 = nn.BatchNorm1d(num_features=self.input_channel * 4, )
self.deconv2 = nn.ConvTranspose1d(in_channels=self.input_channel * 4,
out_channels=self.input_channel * 8, # 5
self.bn2 = nn.BatchNorm1d(num_features=self.input_channel * 8, )
self.deconv3 = nn.ConvTranspose1d(in_channels=self.input_channel * 8,
out_channels=self.input_channel, # 9
# self.bn3 = nn.BatchNorm1d(num_features=self.conv_channel_size, )
self.sig = nn.Sigmoid()
self.tanh = nn.Tanh()
self.relu = nn.ReLU()
self.l1 = nn.Linear(1000,1000)
# Forward pass. x is unpacked as (batch_num, node_num, all_feature); embeddings is
# indexed as a 2-D tensor below. Returns the tuple (x_, x_gen, y_, h).
def forward(self, x,mask,embeddings):
batch_num, node_num, all_feature = x.shape
all_embeddings = embeddings
# Add a leading axis of size 1, then repeat it batch_num times and move to self.device.
all_embeddings = all_embeddings.reshape(1,all_embeddings.shape[0],all_embeddings.shape[1])
all_embeddings = all_embeddings.repeat(batch_num, 1,1).to(self.device)
# NOTE(review): fresh random noise is drawn on every call, so two forward passes on
# the same input give different values. This may matter for torch.jit.trace's
# re-run check; confirm.
noise = torch.randn_like(all_embeddings)
all_embeddings = all_embeddings + noise
# deconv1 -> bn1 -> relu -> deconv2 -> bn2 -> relu -> deconv3 -> l1 -> tanh
h = self.deconv1(all_embeddings)
h = self.bn1(h)
h = self.relu(h)
h= self.deconv2(h)
h = self.bn2(h)
h = self.relu(h)
h = self.deconv3(h)
#h = self.relu(h)
h= self.l1(h)
h = self.tanh(h)
#print("hshape", h.shape)
# NOTE(review): `mask1` is not defined anywhere in the visible code (the parameter
# is named `mask`); as written these two lines would raise NameError.
mask_b = torch.ones_like(mask1.cpu())
#mask_b = torch.from_numpy(mask_b)
# Blend: mask1 * x + (1 - mask1) * h.
x_gen = torch.mul(mask1,x).add(torch.mul(mask_b.sub(mask1.cpu()).to(self.device),h))
# Split the blended tensor along the last axis at index 850.
x_ = x_gen[:,:,:850]
y_ = x_gen[:,:,850:]
# NOTE(review): the next two lines apply NumPy functions to `mask` and then overwrite
# x_gen with mask*h + (1 - mask)*h. They look like leftover code that was disabled
# in the original file; confirm before relying on them.
print(((torch.from_numpy(np.ones_like(mask) - mask))*h).shape)
x_gen = mask*h + (torch.from_numpy(np.ones_like(mask) - mask))*h
return x_,x_gen,y_,h
# Output MLP: a ModuleList of nn.Linear layers applied one after another.
# NOTE(review): indentation was lost in this paste, so the nesting of the
# statements below cannot be read off the text; comments describe only what is visible.
class OutLayer(nn.Module):
# in_num: input width of the first layer; inter_num: hidden width;
# out_num: width of the layer appended in the "last layer" branch.
# node_num is not used in the visible code.
def __init__(self, in_num, node_num, layer_num, inter_num=512, out_num=100):
super(OutLayer, self).__init__()
modules = []
for i in range(layer_num):
# last layer, output shape:1
if i == layer_num - 1:
modules.append((nn.Linear(in_num if layer_num == 1 else inter_num, out_num)))
# NOTE(review): no `else:` is visible before the next two lines. If they really run
# on every iteration, a Linear(..., inter_num) is appended after the last layer
# too. Presumably an `else:` line was lost in the paste; confirm against the original.
layer_in_num = in_num if i == 0 else inter_num
modules.append(nn.Linear(layer_in_num, inter_num))
self.mlp = nn.ModuleList(modules)
# Feeds x through every module in self.mlp in order and returns the result.
def forward(self, x):
out = x
# print(out.shape)
for mod in self.mlp:
# BatchNorm1d modules get the last two axes swapped before and after the call.
# The visible __init__ never adds a BatchNorm1d, so this branch is not reached here.
if isinstance(mod, nn.BatchNorm1d):
out1 = out.permute(0, 2, 1)
out2 = mod(out1)
out = out2.permute(0, 2, 1)
# NOTE(review): no `else:` is visible before the next line; if it is unconditional,
# a BatchNorm1d module would be applied twice. Presumably an `else:` was lost; confirm.
out = mod(out)
# print(out.shape)
return out
class GNNLayer(nn.Module):
    """A ``GraphLayer`` followed by batch normalisation and a ReLU.

    The attention weights and edge index returned by the graph layer are
    kept on the instance (``att_weight_1`` / ``edge_index_1``) after each
    forward call and are also returned to the caller.
    """

    def __init__(self, in_channel, out_channel, inter_dim=0, heads=1, node_num=100):
        super().__init__()
        self.gnn = GraphLayer(
            in_channel,
            out_channel,
            inter_dim=inter_dim,
            heads=heads,
            concat=False,
        )
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x, edge_index, embedding=None, node_num=0):
        """Return ``(relu(bn(gnn_out)), att_weight, new_edge_index)``."""
        conv_out, graph_info = self.gnn(x, edge_index, embedding, return_attention_weights=True)
        new_edge_index, att_weight = graph_info

        # Remember the latest attention weights and edge index on the module.
        self.att_weight_1 = att_weight
        self.edge_index_1 = new_edge_index

        activated = self.relu(self.bn(conv_out))
        return activated, self.att_weight_1, self.edge_index_1
# Attention layer over the first non-batch axis of x (length window_size):
# forward() builds pairwise attention scores from a linear embedding of x,
# softmaxes them, mixes x with the weights (plus a residual), then permutes
# and projects with Linear(window_size, 125) and a ReLU.
# _make_attention_input() (at the end of the class) builds the
# (b, K, K, 2*embed_dim) tensor of all pairwise concatenations, K = window_size.
# NOTE(review): indentation was lost in this paste; comments describe only what is visible.
class TemporalAttentionLayer(nn.Module):
# n_features: size of the last axis of x; window_size: number of attended positions;
# alpha: negative slope of the LeakyReLU; embed_dim defaults to n_features when None.
# dropout is stored but not applied anywhere in the visible code.
def __init__(self, n_features, window_size, dropout, alpha, embed_dim=None, use_bias=True):
super(TemporalAttentionLayer, self).__init__()
self.n_features = n_features
self.window_size = window_size
self.dropout = dropout
self.embed_dim = embed_dim if embed_dim is not None else n_features
self.num_nodes = window_size
self.use_bias = use_bias
# Because linear transformation is performed after concatenation in GATv2
lin_input_dim = n_features
a_input_dim = 2 * self.embed_dim
self.lin = nn.Linear(lin_input_dim, self.embed_dim)
# Attention vector `a`, Xavier-initialised.
self.a = nn.Parameter(torch.empty((a_input_dim, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
# NOTE(review): the bias is created with torch.empty and no initialisation of it
# is visible in this class, so its starting values are whatever memory held; confirm.
if self.use_bias:
self.bias = nn.Parameter(torch.empty(window_size, window_size))
self.leakyrelu = nn.LeakyReLU(alpha)
self.relu =nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.lin2 = nn.Linear(window_size,125)
# Returns (h, attention): h is the projected output, attention the softmaxed scores.
def forward(self, x):
# x shape (b, n, k): b - batch size, n - window size, k - number of features
# For temporal attention a node is represented as all feature values at a specific timestamp
# 'Dynamic' GAT attention
# Proposed by Brody et. al., 2021 (https://arxiv.org/pdf/2105.14491.pdf)
# Linear transformation applied after concatenation and attention layer applied after leakyrelu
# Original GAT attention
Wx = self.lin(x) # (b, n, embed_dim), given x is (b, n, k)
a_input = self._make_attention_input(Wx) # (b, n, n, 2*embed_dim)
e = self.leakyrelu(torch.matmul(a_input, self.a)).squeeze(3) # (b, n, n) after squeezing the last axis
if self.use_bias:
e += self.bias # bias is (window_size, window_size), broadcast over the batch
# Attention weights
#attention = e
attention = torch.softmax(e, dim=2)
# NOTE(review): the triple-quoted string opened on the next line has no visible
# closing quotes in this paste; it looks like a disabled thresholding loop whose
# end was lost. The rest of forward() — residual mix `attention @ x + x`,
# permute(0, 2, 1), lin2, relu, return (h, attention) — follows it.
"""for i in range(attention.shape[0]):
for j in range(attention.shape[1]):
for k in range(attention.shape[2]):
if attention[i,j,k]<0.01:
#attention = (self.relu(attention - 0.001) * attention) \
# / (torch.abs(attention - 0.001) + 0.0000001)
#attention = torch.softmax(attention, dim=2)
#attention = attention / attention.norm(p=1, dim=0)
#attention = torch.dropout(attention, self.dropout, train=self.training)
h = torch.matmul(attention, x) + x # (b, n, k)
h = h.permute(0, 2, 1)
h = self.lin2(h)
h = self.relu(h)
return h,attention
def _make_attention_input(self, v):
K = self.num_nodes
# print(K)
# print("v",v.shape)
blocks_repeating = v.repeat_interleave(K, dim=1) # Left-side of the matrix
# print(blocks_repeating.shape)
blocks_alternating = v.repeat(1, K, 1) # Right-side of the matrix
# print(blocks_alternating.shape)
combined = torch.cat((blocks_repeating, blocks_alternating), dim=2)
return combined.view(v.size(0), K, K, 2 * self.embed_dim)
# Top-level model. forward() runs a Generator on raw_data, applies attention over two
# interleaved halves of the generated signal, runs graph layers on a top-k graph learned
# from the node embeddings, reads from a learned memory, and maps the result through OutLayer.
# NOTE(review): indentation was lost in this paste and several statements are cut off
# or missing, so the class is not runnable as shown; comments describe only what is visible.
class GDN(nn.Module):
def __init__(self, edge_index_sets, node_num, dim=64, out_layer_inter_dim=256, input_dim=10, out_layer_num=1,
topk=20, predict_num=100):
super(GDN, self).__init__()
self.predict_num = predict_num
self.edge_index_sets = edge_index_sets
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# `edge_index` is assigned here but not used again in the visible __init__.
edge_index = edge_index_sets[0]
embed_dim = dim
# One learnable embedding vector of size embed_dim per node.
self.embedding = nn.Embedding(node_num, embed_dim)
self.bn_outlayer_in = nn.BatchNorm1d(embed_dim)
edge_set_num = len(edge_index_sets)
# NOTE(review): the list passed to nn.ModuleList below is never closed (`])` is not
# visible); the line that closed it was presumably lost in the paste.
self.gnn_layers = nn.ModuleList([
GNNLayer(250, dim, inter_dim=dim + embed_dim, heads=1) for i in range(edge_set_num)
self.cosine_similarity = nn.CosineSimilarity(dim=2,)
# Learnable memory of mem_num rows, each of width 250*6, initialised to zeros.
self.mem_num =90
init_mem = torch.zeros(self.mem_num,250*6)
self.memory = nn.Parameter(init_mem)
self.node_embedding = None
self.topk = topk
self.learned_graph = None
# NOTE(review): the OutLayer(...) call below is cut off after inter_num; the remaining
# arguments and closing parenthesis are not visible.
self.out_layer = OutLayer(dim * edge_set_num, node_num * 10, out_layer_num, inter_num=out_layer_inter_dim,
# t_gats is a single TemporalAttentionLayer (not a list of layers).
self.t_gats = TemporalAttentionLayer(6, 425, 0.2, 0.1, embed_dim=None, use_bias=False)
self.att = attention_layer(425)
#self.line = nn.Linear(850,250)
#self.cnn = dcnn()
# Plain Python list used by forward() to cache batched edge indices per edge set.
self.cache_edge_index_sets = [None] * edge_set_num
self.cache_embed_index = None
self.dp = nn.Dropout(0.2)
self.generator = Generator(node_num, node_num).to(device)
#self.discriminator = Discrimanitor()
# Re-initialises the embedding weights with Kaiming-uniform.
def init_params(self):
nn.init.kaiming_uniform_(self.embedding.weight, a=math.sqrt(5))
# Forward pass. `data` is unpacked as (batch_num, node_num, all_feature).
# `org_edge_index` is accepted but not read in the visible body; the edge sets come
# from self.edge_index_sets. Returns a 7-tuple (see the return statement).
def forward(self, data, raw_data, org_edge_index,mask):
x = data.clone().detach()
batch_num, node_num, all_feature = x.shape
device = data.device
all_embedding = self.embedding(torch.arange(node_num).to(device))
# NOTE(review): Generator.forward returns (x_, x_gen, y_, h), so `generator` here
# receives its first return value and `x_` its second; confirm the intended order.
generator,x_,y_gen,g_sample = self.generator(raw_data,mask,all_embedding)
#d_out = self.discriminator(raw_data,x)
#x_ = raw_data[:,:,:850]
#y_gen = raw_data[:,:,850:]
edge_index_sets = self.edge_index_sets
# print("edge_index_set",edge_index_sets)
#x = generator.view(-1, all_feature).contiguous()
x__ = generator.reshape(batch_num,all_feature,node_num)
# b and c are 425 evenly spaced positions over 0..848 and 1..849 (step 2), i.e. the
# two interleaved halves of axis 1; each half is reshaped to (batch_num, 6, -1).
b = torch.linspace(0, 848, 425).long()
c = torch.linspace(1, 849, 425).long()
x_ji =x__[:,b,:].reshape(batch_num,6,-1)
x_ou = x__[:,c,:].reshape(batch_num,6,-1)
#x = x.reshape(-1,250)
# NOTE(review): this loop indexes self.t_gats[i], but t_gats is a single module, and
# its results (t_a, t_ini) are overwritten below. It looks like leftover code that
# was disabled in the original file; confirm.
for i in range(x__.shape[2]):
x = x__[:,:,i]
t_a = self.t_gats[i](x.unsqueeze(2))
if i==0:
t_ini = t_a;
# The t_gats results on the next two lines are immediately replaced by the
# attention_layer results on the two lines after them.
t_a,attentionji = self.t_gats(x_ji)
t_a2,attentionou = self.t_gats(x_ou)
t_a, attentionji = self.att(x_ji)
t_a2, attentionou = self.att(x_ou)
# Concatenate the two attention outputs on the last axis and flatten to rows of 250.
tout = torch.cat((t_a,t_a2),dim=2)
#x = tout.reshape(-1, 250)
x = tout.reshape(-1, 250)
#x = self.cnn(x)
#x = x.reshape(-1, 250)
gcn_outs = []
for i, edge_index in enumerate(edge_index_sets):
edge_num = edge_index.shape[1]
# Rebuild the cached batched edge index when it is missing or was built for a
# different batch size. This mutates Python-side state on the module.
cache_edge_index = self.cache_edge_index_sets[i]
if cache_edge_index is None or cache_edge_index.shape[1] != edge_num * batch_num:
self.cache_edge_index_sets[i] = get_batch_edge_index(edge_index, batch_num, node_num).to(device)
batch_edge_index = self.cache_edge_index_sets[i]
all_embeddings = self.embedding(torch.arange(node_num).to(device))
# Detached copy of the embeddings, used only to build the graph.
weights_arr = all_embeddings.detach().clone()
all_embeddings = all_embeddings.repeat(batch_num, 1)
weights = weights_arr.view(node_num, -1)
# Pairwise cosine similarity: dot products divided by the product of the norms.
cos_ji_mat = torch.matmul(weights, weights.T)
normed_mat = torch.matmul(weights.norm(dim=-1).view(-1, 1), weights.norm(dim=-1).view(1, -1))
cos_ji_mat = cos_ji_mat / normed_mat
dim = weights.shape[-1]
topk_num = self.topk
# For every node keep the indices of its topk_num most similar nodes.
topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]
self.learned_graph = topk_indices_ji
# Edge list of the learned graph: row 0 holds the selected neighbours (gated_j),
# row 1 the node they were selected for (gated_i).
gated_i = torch.arange(0, node_num).T.unsqueeze(1).repeat(1, topk_num).flatten().to(device).unsqueeze(0)
gated_j = topk_indices_ji.flatten().unsqueeze(0)
gated_edge_index = torch.cat((gated_j, gated_i), dim=0)
batch_gated_edge_index = get_batch_edge_index(gated_edge_index, batch_num, node_num).to(device)
gcn_out,att,edg=self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num * batch_num,embedding=all_embeddings)
# NOTE(review): nothing visible appends gcn_out to gcn_outs, so this would concatenate
# an empty list; a `gcn_outs.append(gcn_out)` line was presumably lost. Confirm.
x = torch.cat(gcn_outs, dim=1)
x = x.view(batch_num, node_num, -1)
# Residual connection with the attention output.
x = x+tout
# Memory read: cosine similarity between the flattened x and every memory row,
# softmax over the memory rows, then a weighted sum of the memory rows.
ex_mem = self.memory.unsqueeze(0).repeat(batch_num, 1, 1)
ex_z = x.flatten(-2).unsqueeze(1).repeat(1, self.mem_num, 1)
mem_logit = self.cosine_similarity(ex_z, ex_mem)
mem_weight = F.softmax(mem_logit, dim=1)
z_hat = torch.matmul(mem_weight, self.memory)
# Hard-coded shape: 6 rows of width 250 per sample.
z_hat = z_hat.reshape(batch_num,6,250)
indexes = torch.arange(0, node_num).to(device)
z_hat = z_hat + self.embedding(indexes)
#out = torch.mul(x, z_hat)
out = x + z_hat
# BatchNorm1d is applied with the last two axes swapped, then swapped back.
out1 = out.permute(0, 2, 1)
out2 = F.relu(self.bn_outlayer_in(out1))
out3 = out2.permute(0, 2, 1)
out4 = self.dp(out3)
out5 = self.out_layer(out4)
out6 = out5.view(-1, node_num, self.predict_num)
# Returns: prediction, the generator outputs bound to y_gen / x_ / g_sample above,
# the pre-memory features x, the memory parameter itself, and the memory weights.
return out6,y_gen,x_,g_sample,x,self.memory,mem_weight