I decided to make my code faster by removing Python lists and storing all my tensors in a tuple. Before this change my code worked, but the new version raises an in-place operation error. The problem must be somewhere in my forward function, since that is the only code I edited relative to the working version. I have tried various combinations of clone(), but I just cannot find the issue.
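From what I understand, this error means that a tensor autograd saved for the backward pass was later modified in place. A toy snippet of that pattern (not my actual code) fails the same way, and my understanding is that torch.autograd.set_detect_anomaly(True) makes the backward traceback point at the forward op whose saved tensor was touched:

import torch

torch.autograd.set_detect_anomaly(True)  # backward error then names the responsible forward op

a = torch.randn(3, 4, requires_grad=True)
b = torch.exp(a)       # exp() saves its output for the backward pass
b[:, 0] = 0.0          # in-place write into that saved tensor
b.sum().backward()     # fails with the in-place modification RuntimeError

This is the forward function: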
def forward(self):
    x = self.data
    log_det_j = x.new_zeros(x.shape[0])
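    # log_det_j accumulates the per-sample log|det J| of the transform across the coupling layers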
    for i in range(self.coup_layer):
        theta = []
        if self.flip[i] == 1:
            x1, x2 = torch.chunk(x, 2, dim=1)
        else:
            x2, x1 = torch.chunk(x, 2, dim=1)
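        # x1 is fed to the conditioner network below; x2 is the half that gets transformed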
        one = x.new_ones((len(x2[1]), x.shape[0], 2))
        theta_x = torch.zeros((len(x2[1]), len(x2), self.k_bins))
        theta_fx = torch.zeros((len(x2[1]), len(x2), self.k_bins))
        theta_d = torch.zeros((len(x2[1]), len(x2), self.k_bins - 1))
        theta.extend((theta_x, theta_fx, theta_d))
        x1_ = self.layer[i](x1)
        # Partition the data by its theta parameters
        for dim, chunk in enumerate(torch.chunk(x1_, len(x2[1]), dim=1)):
            for j, val in enumerate(torch.chunk(chunk, 3, dim=1)):
                theta[j][dim] = val
        # Convert the parameters into a cumulative sum of softmax parameters
        # theta[0] = x_i, theta[1] = f(x)_i, theta[2] = d_i
        for j in range(len(x2[1])):
            theta[0][j] = torch.cumsum(F.softmax(theta[0][j], dim=1), dim=1) * 2 * self.b - self.b  # x_i
            theta[1][j] = torch.cumsum(F.softmax(theta[1][j], dim=1), dim=1) * 2 * self.b - self.b  # f(x)_i
            theta[2][j] = F.softplus(theta[2][j])  # d_i
        # Add in endpoint
        theta_ = torch.zeros((3, len(x2[1]), len(x2), self.k_bins + 1))
        theta_[0, :, :, 1:self.k_bins + 1] = theta[0]
        theta_[1, :, :, 1:self.k_bins + 1] = theta[1]
        theta_[2, :, :, 1:self.k_bins] = theta[2]
        theta_[2, :, :, (0, self.k_bins)] = one
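        # theta_ packs the bin x-positions, y-positions and derivatives (with endpoints filled in) into one tensor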
        # Search for bin location and update variables correspondingly
        for j in range(len(x2[1])):
            # Contains index of i and i+1
            x_i_id = F.relu(theta_[0, j].clone() - x2[:, j].reshape(x2.size()[0], 1)).clone()
            x_i_id = torch.argmin(x_i_id, dim=1)
            x_i_plus_id = x_i_id + 1
            x_i_ = theta_[0, j].gather(1, x_i_id.view(-1, 1)).clone()
            x_i_plus_ = theta_[0, j].gather(1, x_i_plus_id.view(-1, 1)).clone()
            y_i_ = theta_[1, j].gather(1, x_i_id.view(-1, 1)).clone()
            y_i_plus_ = theta_[1, j].gather(1, x_i_plus_id.view(-1, 1)).clone()
            d_i_ = theta_[2, j].gather(1, x_i_id.view(-1, 1)).clone()
            d_i_plus_ = theta_[2, j].gather(1, x_i_plus_id.view(-1, 1)).clone()
            if j == 0:
                x_i = x_i_
                x_i_plus = x_i_plus_
                y_i = y_i_
                y_i_plus = y_i_plus_
                d_i = d_i_
                d_i_plus = d_i_plus_
            else:
                x_i = torch.cat((x_i, x_i_), dim=1)
                x_i_plus = torch.cat((x_i_plus, x_i_plus_), dim=1)
                y_i = torch.cat((y_i, y_i_), dim=1)
                y_i_plus = torch.cat((y_i_plus, y_i_plus_), dim=1)
                d_i = torch.cat((d_i, d_i_), dim=1)
                d_i_plus = torch.cat((d_i_plus, d_i_plus_), dim=1)
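        # x_i, x_i_plus, y_i, y_i_plus, d_i, d_i_plus now hold the bin edges, heights and derivatives around each element of x2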
        h_i = x_i_plus - x_i
        x_rel = (x2 - x_i) / h_i
        delta = (y_i_plus - y_i) / h_i
        x2 = y_i + ((y_i_plus - y_i) * (delta * x_rel ** 2 + d_i * x_rel * (1 - x_rel))
                    / (delta + (d_i_plus + d_i - 2 * delta) * x_rel * (1 - x_rel)))
        Jac = ((delta ** 2) * (d_i_plus * (x_rel ** 2) + 2 * delta * x_rel * (1 - x_rel) + d_i * ((1 - x_rel) ** 2))
               / (delta + (d_i_plus + d_i - 2 * delta) * x_rel * (1 - x_rel)) ** 2)
        Jac = torch.log(Jac)
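        # Jac is now the elementwise log-derivative of the spline; it is summed over dimensions into log_det_j below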
        if self.flip[i] == 1:
            x = torch.cat((x1, x2), dim=1)
        else:
            x = torch.cat((x2, x1), dim=1)
        log_det_j += Jac.sum(dim=1)
    return x, log_det_j