RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [200, 2]], which is output 0 of AsStridedBackward0, is at version 1401; expected version 1400 instead

Hi there!
When trying to do a backward pass in my model I receive the error shown above (the full traceback is at the end of this post).

Things I already tried:

  • inplace=True for ReLU()
  • clone() on Tensors in GaussianMLP
  • replacing all Python in-place operations

@ptrblck I see that you have helped a lot on other issues like this. Am I missing something here?

z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = CLUE_explainer.optimise(
        min_steps=3, max_steps=35,
        n_early_stop=3)

clue.py

torch.autograd.set_detect_anomaly(True)

class CLUE(BaseNet):
    def __init__(self, VAE, BNN, original_x, uncertainty_weight, aleatoric_weight, epistemic_weight, prior_weight, distance_weight,
                 latent_L2_weight, prediction_similarity_weight,
                 lr, desired_preds=None, cond_mask=None, distance_metric=None, z_init=None, norm_MNIST=False, flatten_BNN=False,
                 regression=False, prob_BNN=True, cuda=True):

        # parts of the constructor removed for brevity
        self.original_x = torch.Tensor(original_x)

        self.prob_BNN = prob_BNN
        self.cuda = cuda
        if self.cuda:
            self.original_x = self.original_x.cuda()
            if self.desired_preds is not None:
                self.desired_preds = self.desired_preds.cuda()
        self.cond_mask = cond_mask

        self.trainable_params = list()

        if self.VAE is None:
            self.trainable_params.append(nn.Parameter(original_x))
        else:
            self.z_dim = VAE.latent_dim
            if z_init is not None:
                self.z_init = torch.Tensor(z_init)
                if cuda:
                    self.z_init = self.z_init.cuda()
                self.z = nn.Parameter(self.z_init)
                self.trainable_params.append(self.z)
            else:
                self.z_init = torch.zeros(self.z_dim).unsqueeze(0).repeat(original_x.shape[0], 1)
                if self.cuda:
                    self.z_init = self.z_init.cuda()
                self.z = nn.Parameter(self.z_init)
                self.trainable_params.append(self.z)

        self.optimizer = Adam(self.trainable_params, lr=lr)

    def randomise_z_init(self, std):
        eps = torch.randn(self.z.shape).type(self.z.type())
        self.z.data = std * eps + self.z_init
        return None

    def pred_dist(self, preds):
        assert self.desired_preds is not None

        if self.regression:
            dist = F.mse_loss(preds, self.desired_preds, reduction='none')
        else:

            if len(self.desired_preds.shape) == 1 or self.desired_preds.shape[1] == 1:
                dist = F.nll_loss(preds, self.desired_preds, reduction='none')
            else:
                dist = -(torch.log(preds) * self.desired_preds).sum(dim=1)
        return dist

    def uncertainty_from_z(self):
        x = self.VAE.regenerate(self.z, grad=True)

        if self.flatten_BNN:
            to_BNN = x.view(x.shape[0], -1)
        else:
            to_BNN = x

        if self.norm_MNIST:
            to_BNN = calculate_mnist_mean_std_norm(to_BNN)

        if self.prob_BNN:
            if self.regression:
                mu_vec, std_vec = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
                total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty = decompose_std_gauss(mu_vec, std_vec)
                preds = mu_vec.mean(dim=0)
            else:
                probs = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
                total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty = decompose_entropy_cat(probs)
                preds = probs.mean(dim=0)
        else:
            if self.regression:
                mu, std = self.BNN.predict(to_BNN, grad=True)
                total_uncertainty = std.squeeze(1)
                aleatoric_uncertainty = total_uncertainty
                epistemic_uncertainty = total_uncertainty * 0
                preds = mu
            else:
                probs = self.BNN.predict(to_BNN, grad=True)
                total_uncertainty = -(probs * torch.log(probs + 1e-10)).sum(dim=1, keepdim=False)
                aleatoric_uncertainty = total_uncertainty
                epistemic_uncertainty = total_uncertainty * 0
                preds = probs

        return total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds

    def get_objective(self, x, total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, preds):
        objective = self.uncertainty_weight * total_uncertainty + self.aleatoric_weight * aleatoric_uncertainty + \
                    self.epistemic_weight * epistemic_uncertainty

        if self.VAE is not None and self.cond_mask is None and self.prior_weight > 0:
            try:
                prior_loglike = self.VAE.prior.log_prob(self.z).sum(dim=1)
            except:
                prior_loglike = self.VAEAC.get_prior(self.original_x, self.cond_mask, flatten=False).log_prob(self.z).sum(dim=1)
            objective = objective + self.prior_weight * prior_loglike

        if self.latent_L2_weight != 0 and self.latent_L2_weight is not None:
            latent_dist = F.mse_loss(self.z, self.z_init, reduction='none').view(x.shape[0], -1).sum(dim=1)
            objective = objective + self.latent_L2_weight * latent_dist

        if self.desired_preds is not None:
            pred_dist = self.pred_dist(preds).view(preds.shape[0], -1).sum(dim=1)
            objective = objective + self.prediction_similarity_weight * pred_dist

        if self.distance_metric is not None:
            dist = self.distance_metric(x, self.original_x).view(x.shape[0], -1).sum(dim=1)
            objective = objective + self.distance_weight * dist

            return objective, self.distance_weight * dist
        else:
            return objective, 0

    def optimise(self, min_steps=3, max_steps=25,
                 n_early_stop=3):
        z_vec = [self.z.data.cpu().numpy()]
        x_vec = []
        uncertainty_vec = np.zeros((max_steps, self.z.shape[0]))
        aleatoric_vec = np.zeros((max_steps, self.z.shape[0]))
        epistemic_vec = np.zeros((max_steps, self.z.shape[0]))
        dist_vec = np.zeros((max_steps, self.z.shape[0]))
        cost_vec = np.zeros((max_steps, self.z.shape[0])) 

        it_mask = np.zeros(self.z.shape[0])

        for step_idx in range(max_steps):
            self.optimizer.zero_grad()
            total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds = self.uncertainty_from_z()
            objective, w_dist = self.get_objective(x, total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, preds)

            objective.sum(dim=0).backward()

            self.optimizer.step()

            uncertainty_vec[step_idx, :] = total_uncertainty.data.cpu().numpy()
            aleatoric_vec[step_idx, :] = aleatoric_uncertainty.data.cpu().numpy()
            epistemic_vec[step_idx, :] = epistemic_uncertainty.data.cpu().numpy()
            dist_vec[step_idx, :] = (w_dist.data.cpu().numpy())
            cost_vec[step_idx, :] = (objective.data.cpu().numpy())
            x_vec.append(x.detach().cpu().numpy()) 
            z_vec.append(self.z.detach().cpu().numpy())

            it_mask = CLUE.update_stopvec(cost_vec, it_mask, step_idx, n_early_stop, min_steps)

        x = self.VAE.regenerate(self.z, grad=False).data
        x_vec.append(x)
        x_vec = [i.cpu().numpy() for i in x_vec]
        x_vec = np.stack(x_vec)
        z_vec = np.stack(z_vec)

        uncertainty_vec, epistemic_vec, aleatoric_vec, dist_vec, cost_vec, z_vec, x_vec = CLUE.apply_stopvec(it_mask,
                                                                                                             uncertainty_vec, epistemic_vec,
                                                                                                             aleatoric_vec, dist_vec, cost_vec, z_vec,
                                                                                                             x_vec,
                                                                                                             n_early_stop)
        return z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec

    @staticmethod
    def update_stopvec(cost_vec, it_mask, step_idx, n_early_stop, min_steps):
        asymptotic_rel = np.abs(cost_vec[step_idx - n_early_stop, :] - cost_vec[step_idx, :]) < cost_vec[0, :] * 1e-2
        asymptotic_abs = np.abs(cost_vec[step_idx - n_early_stop, :] - cost_vec[step_idx, :]) < 1e-3

        if step_idx > min_steps:
            condition_sum = asymptotic_rel + asymptotic_abs
        else:
            condition_sum = np.array([0])

        stop_vec = condition_sum.clip(max=1, min=0)

        to_mask = (it_mask == 0).astype(int) * stop_vec
        it_mask[to_mask == 1] = step_idx

        if (it_mask == 0).sum() == 0 and n_early_stop > 0:
            print('it %d, all conditions met, stopping' % step_idx)
        return it_mask

    @staticmethod
    def apply_stopvec(it_mask, uncertainty_vec, epistemic_vec, aleatoric_vec, dist_vec, cost_vec, z_vec, x_vec, n_early_stop):
        it_mask = (it_mask - n_early_stop + 1).astype(int)
        for i in range(uncertainty_vec.shape[1]):
            if it_mask[i] > 0 and n_early_stop > 0:
                uncertainty_vec[it_mask[i]:, i] = uncertainty_vec[it_mask[i], i]
                epistemic_vec[it_mask[i]:, i] = epistemic_vec[it_mask[i], i]
                aleatoric_vec[it_mask[i]:, i] = aleatoric_vec[it_mask[i], i]
                cost_vec[it_mask[i]:, i] = cost_vec[it_mask[i], i]
                dist_vec[it_mask[i]:, i] = dist_vec[it_mask[i], i]
                z_vec[it_mask[i]:, i] = z_vec[it_mask[i], i]
                x_vec[it_mask[i]:, i] = x_vec[it_mask[i], i]
        return uncertainty_vec, epistemic_vec, aleatoric_vec, dist_vec, cost_vec, z_vec, x_vec

    def sample_explanations(self, n_explanations, init_std=0.15, min_steps=3, max_steps=25,
                            n_early_stop=3):
        full_x_vec = []
        full_z_vec = []
        full_uncertainty_vec = []
        full_aleatoric_vec = []
        full_epistemic_vec = []
        full_dist_vec = []
        full_cost_vec = []

        for i in range(n_explanations):
            self.randomise_z_init(std=init_std)

            torch.autograd.set_detect_anomaly(False)

            z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = self.optimise(
                min_steps=min_steps, max_steps=max_steps,
                n_early_stop=n_early_stop)

            full_x_vec.append(x_vec)
            full_z_vec.append(z_vec)
            full_uncertainty_vec.append(uncertainty_vec)
            full_aleatoric_vec.append(aleatoric_vec)
            full_epistemic_vec.append(epistemic_vec)
            full_dist_vec.append(dist_vec)
            full_cost_vec.append(cost_vec)

        full_x_vec = np.concatenate(np.expand_dims(full_x_vec, axis=0), axis=0)
        full_z_vec = np.concatenate(np.expand_dims(full_z_vec, axis=0), axis=0)
        full_cost_vec = np.concatenate(np.expand_dims(full_cost_vec, axis=0), axis=0)
        full_dist_vec = np.concatenate(np.expand_dims(full_dist_vec, axis=0), axis=0)
        full_uncertainty_vec = np.concatenate(np.expand_dims(full_uncertainty_vec, axis=0), axis=0)
        full_aleatoric_vec = np.concatenate(np.expand_dims(full_aleatoric_vec, axis=0), axis=0)
        full_epistemic_vec = np.concatenate(np.expand_dims(full_epistemic_vec, axis=0), axis=0)

        return full_x_vec, full_z_vec, full_uncertainty_vec, full_aleatoric_vec, full_epistemic_vec, full_dist_vec, full_cost_vec

    @classmethod
    def batch_optimise(cls, VAE, BNN, original_x, uncertainty_weight, aleatoric_weight, epistemic_weight, prior_weight,
                       distance_weight, latent_L2_weight, prediction_similarity_weight, lr, min_steps=3, max_steps=25,
                       n_early_stop=3, batch_size=256, cond_mask=None, desired_preds=None,
                       distance_metric=None, z_init=None, norm_MNIST=False, flatten_BNN=False, regression=False,
                       prob_BNN=True, cuda=True):
        full_x_vec = []
        full_z_vec = []
        full_uncertainty_vec = []
        full_aleatoric_vec = []
        full_epistemic_vec = []
        full_dist_vec = []
        full_cost_vec = []

        idx_iterator = generate_ind_batch(original_x.shape[0], batch_size=batch_size, random=False, roundup=True)
        for train_idx in idx_iterator:

            if z_init is not None:
                z_init_use = z_init[train_idx]
            else:
                z_init_use = z_init

            if desired_preds is not None:
                desired_preds_use = desired_preds[train_idx].data
            else:
                desired_preds_use = desired_preds

            CLUE_runner = cls(VAE, BNN, original_x[train_idx], uncertainty_weight, aleatoric_weight, epistemic_weight, prior_weight, distance_weight,
                              latent_L2_weight, prediction_similarity_weight, lr, cond_mask=cond_mask, distance_metric=distance_metric,
                              z_init=z_init_use, norm_MNIST=norm_MNIST, desired_preds=desired_preds_use,
                              flatten_BNN=flatten_BNN, regression=regression, prob_BNN=prob_BNN, cuda=cuda)

            z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = \
                CLUE_runner.optimise(min_steps=min_steps, max_steps=max_steps, n_early_stop=n_early_stop)

            full_x_vec.append(x_vec)
            full_z_vec.append(z_vec)
            full_uncertainty_vec.append(uncertainty_vec)
            full_aleatoric_vec.append(aleatoric_vec)
            full_epistemic_vec.append(epistemic_vec)
            full_dist_vec.append(dist_vec)
            full_cost_vec.append(cost_vec)

        full_x_vec = np.concatenate(full_x_vec, axis=1)
        full_z_vec = np.concatenate(full_z_vec, axis=1)
        full_cost_vec = np.concatenate(full_cost_vec, axis=1)
        full_dist_vec = np.concatenate(full_dist_vec, axis=1)
        full_uncertainty_vec = np.concatenate(full_uncertainty_vec, axis=1)
        full_aleatoric_vec = np.concatenate(full_aleatoric_vec, axis=1)
        full_epistemic_vec = np.concatenate(full_epistemic_vec, axis=1)

        return full_x_vec, full_z_vec, full_uncertainty_vec, full_aleatoric_vec, full_epistemic_vec, full_dist_vec, full_cost_vec


class conditional_CLUE(CLUE):

    def __init__(self, VAEAC, BNN, original_x, uncertainty_weight, aleatoric_weight, epistemic_weight, prior_weight, distance_weight,
                 lr, cond_mask=None, distance_metric=None, z_init=None, norm_MNIST=False, flatten_BNN=False,
                 regression=False, cuda=True):

        super(conditional_CLUE, self).__init__(VAEAC, BNN, original_x, uncertainty_weight, aleatoric_weight, epistemic_weight,
                                               prior_weight, distance_weight,
                                               lr, cond_mask, distance_metric, z_init, norm_MNIST, flatten_BNN,
                                               regression, cuda)
        self.cond_mask = cond_mask.type(original_x.type())
        self.VAEAC = VAEAC
        self.prior_weight = 0

    def uncertainty_from_z(self):

        x = self.VAEAC.regenerate(self.z, grad=True)
        x = x * self.cond_mask + self.original_x * (1 - self.cond_mask)

        if self.flatten_BNN:
            to_BNN = x.view(x.shape[0], -1)
        else:
            to_BNN = x

        if self.norm_MNIST:
            to_BNN = calculate_mnist_mean_std_norm(to_BNN)

        if self.regression:
            mu_vec, std_vec = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
            total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty = decompose_std_gauss(mu_vec, std_vec)
        else:
            probs = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
            total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty = decompose_entropy_cat(probs)

        return total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x

    def optimise(self, min_steps=3, max_steps=25,
                 n_early_stop=3):
        # Vectors to capture changes for this minibatch
        z_vec = [self.z.data.cpu().numpy()]
        x_vec = []
        uncertainty_vec = np.zeros((max_steps, self.z.shape[0]))
        aleatoric_vec = np.zeros((max_steps, self.z.shape[0]))
        epistemic_vec = np.zeros((max_steps, self.z.shape[0]))
        dist_vec = np.zeros((max_steps, self.z.shape[0]))
        cost_vec = np.zeros((max_steps, self.z.shape[0]))  # this one doesn't consider the prior

        it_mask = np.zeros(self.z.shape[0])

        for step_idx in range(max_steps):
            self.optimizer.zero_grad()
            total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x = self.uncertainty_from_z()
            objective, w_dist = self.get_objective(x, total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty)
            objective.mean(dim=0).backward()  # backpropagate

            self.optimizer.step()

            # save vectors
            uncertainty_vec[step_idx, :] = total_uncertainty.data.cpu().numpy()
            aleatoric_vec[step_idx, :] = aleatoric_uncertainty.data.cpu().numpy()
            epistemic_vec[step_idx, :] = epistemic_uncertainty.data.cpu().numpy()
            dist_vec[step_idx, :] = (w_dist.data.cpu().numpy())
            cost_vec[step_idx, :] = (objective.data.cpu().numpy())
            x_vec.append(x.data)  # we don't convert to numpy yet because we need x0 for L1
            z_vec.append(self.z.data.cpu().numpy())  # this one is after gradient update while x is before

            it_mask = CLUE.update_stopvec(cost_vec, it_mask, step_idx, n_early_stop, min_steps)

        #  Generate final (resulting) sample

        x = self.VAE.regenerate(self.z, grad=False).data
        x = x * self.cond_mask + self.original_x * (1 - self.cond_mask)
        x_vec.append(x)
        x_vec = [i.cpu().numpy() for i in x_vec]  # convert x to numpy
        x_vec = np.stack(x_vec)
        z_vec = np.stack(z_vec)

        # Recover correct indexes using mask
        uncertainty_vec, epistemic_vec, aleatoric_vec, dist_vec, cost_vec, z_vec, x_vec = \
            CLUE.apply_stopvec(it_mask, uncertainty_vec, epistemic_vec, aleatoric_vec, dist_vec, cost_vec, z_vec,
                               x_vec, n_early_stop)
        return z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec


def decompose_std_gauss(mu, sigma, sum_dims=True):
    aleatoric_var = (sigma ** 2).mean(dim=0)
    epistemic_var = ((mu ** 2).mean(dim=0) - mu.mean(dim=0) ** 2)
    total_var = aleatoric_var + epistemic_var
    if sum_dims:
        aleatoric_var = aleatoric_var.sum(dim=1)
        epistemic_var = epistemic_var.sum(dim=1)
        total_var = total_var.sum(dim=1)
    return total_var.sqrt(), aleatoric_var.sqrt(), epistemic_var.sqrt()


def decompose_entropy_cat(probs, eps=1e-10):
    posterior_preds = probs.mean(dim=0, keepdim=False)
    total_entropy = -(posterior_preds * torch.log(posterior_preds + eps)).sum(dim=1, keepdim=False)

    sample_preds_entropy = -(probs * torch.log(probs + eps)).sum(dim=2, keepdim=False)
    aleatoric_entropy = sample_preds_entropy.mean(dim=0, keepdim=False)

    epistemic_entropy = total_entropy - aleatoric_entropy

    # returns (batch_size)
    return total_entropy, aleatoric_entropy, epistemic_entropy

gaussian_mlp.py

class GaussianMLP(nn.Module):
    def __init__(self, input_dim, width, depth, output_dim, flatten_image):
        super(GaussianMLP, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.width = width
        self.depth = depth
        self.flatten_image = flatten_image

        layers = [nn.Linear(input_dim, width), nn.ReLU()]
        for i in range(depth - 1):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(width, 2 * output_dim))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        if self.flatten_image:
            x = x.view(-1, self.input_dim)
        x = self.block(x)
        mu = x[:, :self.output_dim]
        sigma = F.softplus(x[:, self.output_dim:])
        return mu, sigma

gaussian_bnn.py

class GaussianBNN(BaseNet):
    def __init__(self, model, N_train, lr=1e-2, cuda=True, eps=1e-3, grad_std_mul=20):
        super(GaussianBNN, self).__init__()
        # parts of the constructor removed for brevity
        self._create_network()
        self._create_optimizer()
        self.schedule = None  # [] #[50,200,400,600]
        self.epoch = 0

        self.grad_buff = []
        self.grad_std_mul = grad_std_mul
        self.max_grad = 1e20
        self.eps = eps

        self.weight_set_samples = []

    def _create_network(self):
        torch.manual_seed(42)
        if self.cuda:
            torch.cuda.manual_seed(42)

        if self.cuda:
            self.model.cuda()
            cudnn.benchmark = True

        print('Total params: %.2fM' % (self.get_nb_parameters() / 1000000.0))

    def _create_optimizer(self):
        self.optimizer = StochasticHamiltonMonteCarloSampler(params=self.model.parameters(), lr=self.lr, base_C=0.05, gauss_sig=0.1)

    def fit(self, x, y, burn_in=False, resample_momentum=False, resample_prior=False):
        self.set_model_mode(train=True)
        x, y = variable_to_tensor_list(variables=(x, y), cuda=self.cuda)

        self.optimizer.zero_grad()
        mu, sigma = self.model(x)
        sigma = sigma.clamp(min=self.eps)
        loss = -diagonal_gauss_loglike(y, mu, sigma).mean(dim=0) * self.N_train

        loss.backward()

        if len(self.grad_buff) > 100:
            self.max_grad = np.mean(self.grad_buff) + self.grad_std_mul * np.std(self.grad_buff)
            self.grad_buff.pop(0)

        self.grad_buff.append(nn.utils.clip_grad_norm_(parameters=self.model.parameters(),
                                                       max_norm=self.max_grad, norm_type=2))
        if self.grad_buff[-1] >= self.max_grad:
            print(self.max_grad, self.grad_buff[-1])
            self.grad_buff.pop()

        self.optimizer.step(burn_in=burn_in, resample_momentum=resample_momentum, resample_prior=resample_prior)

        return loss.data * x.shape[0] / self.N_train, mu.data, sigma.data

    def eval(self, x, y):
        self.set_model_mode(train=False)
        x, y = variable_to_tensor_list(variables=(x, y), cuda=self.cuda)
        mu, sigma = self.model(x)
        sigma = sigma.clamp(min=self.eps)
        loss = -diagonal_gauss_loglike(y, mu, sigma).mean(dim=0) * self.N_train

        return loss.data * x.shape[0] / self.N_train, mu.data, sigma.data

    @staticmethod
    def unnormalised_eval(pred_mu, pred_std, y, y_mu, y_std, gmm=False):
        ll = gaussian_mixture_model_loglike(pred_mu, pred_std, y, y_mu, y_std, gmm=gmm)  # this already computes sum
        if gmm:
            pred_mu = pred_mu.mean(dim=0)
        rms = get_root_mean_square(pred_mu, y, y_mu, y_std)  # this already computes sum
        return rms, ll

    def predict(self, x):
        self.set_model_mode(train=False)
        x, = variable_to_tensor_list(variables=(x,), cuda=self.cuda)
        mu, sigma = self.model(x)
        return mu.data, sigma.data

    def save_sampled_net(self, max_samples):

        if len(self.weight_set_samples) >= max_samples:
            self.weight_set_samples.pop(0)

        self.weight_set_samples.append(copy.deepcopy(self.model.state_dict()))

        logger.warning(f"Saving Samples: {len(self.weight_set_samples)} for Max Samples: {max_samples}")
        logger.warning(f"Samples: {self.weight_set_samples}")

        return None

    def sample_predict(self, x, num_samples, grad=False):
        self.set_model_mode(train=False)
        if num_samples == 0:
            num_samples = len(self.weight_set_samples)

        x, = variable_to_tensor_list(variables=(x,), cuda=self.cuda)

        if grad:
            self.optimizer.zero_grad()
            if not x.requires_grad:
                x.requires_grad = True

        mu_vec = x.data.new(num_samples, x.shape[0], self.model.output_dim)
        std_vec = x.data.new(num_samples, x.shape[0], self.model.output_dim)

        # iterate over all saved weight configuration samples
        for idx, weight_dict in enumerate(self.weight_set_samples):
            if idx == num_samples:
                break
            self.model.load_state_dict(weight_dict)
            # mu_vec[idx], std_vec[idx] = self.model(x)
            mu, std = self.model(x)
            # mu_vec, std_vec = mu_vec.copy(), std_vec.copy()
            mu_vec[idx], std_vec[idx] = mu, std

        if grad:
            return mu_vec[:idx], std_vec[:idx]
        else:
            return mu_vec[:idx].data, std_vec[:idx].data

    def get_weight_samples(self, Nsamples=0):
        weight_vec = []

        if Nsamples == 0 or Nsamples > len(self.weight_set_samples):
            Nsamples = len(self.weight_set_samples)

        for idx, state_dict in enumerate(self.weight_set_samples):
            if idx == Nsamples:
                break

            for key in state_dict.keys():
                if 'weight' in key:
                    weight_mtx = state_dict[key].cpu().data
                    for weight in weight_mtx.view(-1):
                        weight_vec.append(weight)

        return np.array(weight_vec)

    def save_weights(self, filename):
        save_object(self.weight_set_samples, filename)

    def load_weights(self, filename):
        self.weight_set_samples = load_object(filename)

Full Traceback of the Error Message:

/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:266: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 2235, in <module>
    main()
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 2217, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1527, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1534, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 210, in <module>
    main()
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 189, in main
    z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = CLUE_explainer.optimise(
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 191, in optimise
    total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds = self.uncertainty_from_z()
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 124, in uncertainty_from_z
    mu_vec, std_vec = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/utils/clue/bnn/gaussian_bnn.py", line 134, in sample_predict
    mu, std = self.model(x)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/utils/clue/gaussian_mlp.py", line 27, in forward
    x = self.block(x)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:118.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1534, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Users/lukasscholz/Applications/PyCharm Professional Edition.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 210, in <module>
    main()
  File "/Users/lukasscholz/repositorys/counterfactuals/examples/test.py", line 189, in main
    z_vec, x_vec, uncertainty_vec, epistemic_vec, aleatoric_vec, cost_vec, dist_vec = CLUE_explainer.optimise(
  File "/Users/lukasscholz/repositorys/counterfactuals/counterfactual_xai/methods/clue.py", line 194, in optimise
    objective.sum(dim=0).backward()  # backpropagate
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/Users/lukasscholz/repositorys/counterfactuals/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [200, 2]], which is output 0 of AsStridedBackward0, is at version 1401; expected version 1400 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi Lukas!

First, try using the inplace-modification-error debugging suggestions in the
following post:

Note that optimizer.step() counts as an inplace modification for the parameters
being optimized. Parts of your post have the flavor of optimizer.step() being
the cause of the inplace modification.

This error message is telling you that the tensor that was modified inplace has
shape [200, 2], and it appears to be a Linear. Does GaussianMLP.block
contain a Linear (200, 2), perhaps the last Linear in the Sequential?

(Note, autograd raises the error on the first inplace modification it detects, so,
if this is what’s going on, earlier Linears in block probably also have been
modified inplace.)

Anyway, given the error message you posted, start by locating any tensors of
shape [200, 2] you have (call them generically t) and print out t._version
at various places in your code to use a divide-and-conquer strategy (binary
search) to locate where t._version changes from 1400 to 1401. This is
where the relevant inplace modification occurs.
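
For reference, a minimal, self-contained illustration of that counter (generic tensors, not your code): every in-place modification increments _version, including indexed assignment.

import torch

t = torch.zeros(200, 2)
print(t._version)   # 0
t.add_(1.0)         # in-place op
print(t._version)   # 1
t[0] = 5.0          # indexed assignment is also in-place
print(t._version)   # 2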

If optimizer.step() is the cause, then ask yourself if you are ever calling
.backward() on a quantity that depends on some “old” data, that is, that
depends (in the sense of a still-live autograd computation graph) on some
data that was computed using some model Parameter (perhaps that Linear)
prior to the most recent optimizer.step() call.
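
If it helps, here is a minimal, self-contained sketch (my guess at the failure mode, not taken from your code) of how backpropagating through such “old” data after optimizer.step() produces exactly this kind of error:

import torch
import torch.nn as nn

lin = nn.Linear(200, 2)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x_old = torch.randn(8, 200, requires_grad=True)
out_old = lin(x_old)                       # the graph saves the current weight for backward

lin(torch.randn(8, 200)).sum().backward()  # some other loss
opt.step()                                 # modifies lin.weight inplace

out_old.sum().backward()                   # RuntimeError: ... modified by an inplace operation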

(As an aside, .data has been deprecated for public-facing use. Using it can
confuse autograd and cause errors. I didn’t notice any .data calls in your
code that looked like they would be causing trouble, but you should rewrite
your code to eliminate them anyway, if only because the semantics of .data
could silently change in future versions.)
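
In case it is useful, the usual replacements look something like this (a generic sketch, not specific to your code):

import torch

p = torch.nn.Parameter(torch.zeros(3))

# instead of: p.data = p.data + 1.0
with torch.no_grad():
    p.add_(1.0)          # in-place edit that autograd does not record in the graph

values = p.detach()      # instead of p.data, when you just want an untracked view of the values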

Good luck!

K. Frank

Thanks for the reply @KFrank !
I searched the last two days for the [200, 2] Tensor you mentioned, but had no luck in finding it.
Since the original CLUE project is implemented in PyTorch 1.3.1 and Python 2.7, I tried to get it running with those dependencies, and everything worked fine with the same code. My code base is built on Python >= 3.8 and torch 2.1.0.

Do you maybe have an intuition as to why the in-place change of the tensor happens with the newer PyTorch and Python versions?

Hi Lukas!

Keep looking for it. Finding the tensor that is being modified inplace opens up
probably the best avenue for debugging the issue. (I’ve never seen autograd
“hallucinate” the shape of a tensor it flags as having been modified inplace.
So I really bet you have such a tensor in your model.)

Note that a Linear (200, 2) has a weight of shape [2, 200] and autograd
will typically report its transpose (with shape [200, 2]) as being modified
inplace. So if you have a Linear.weight with shape [2, 200], that would
likely be your culprit.
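
A quick way to see that correspondence (illustrative snippet, not your code):

import torch.nn as nn

lin = nn.Linear(200, 2)
print(lin.weight.shape)      # torch.Size([2, 200])
print(lin.weight.t().shape)  # torch.Size([200, 2])  <- the transposed shape autograd reports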

I would recommend that you print out the shapes of the inputs, outputs and
weights of the Linears that make up GaussianMLP.block. The forward-call
traceback you posted strongly suggests that one of the tensors in block is
the one that gets modified inplace.

Note, the “version 1401” in your error message is potentially meaningful. This
suggests to me that the tensor in question is one being optimized, that you’ve
called optimizer.step() 1399 times (modifying the tensor inplace) without
error, but after your 1400th call, the error, for whatever reason, shows up. (For
example, do you have 1400 batches in an epoch, but then do something a bit
different at the end of an epoch or at the beginning of the next epoch?)

Also, if you call into any third-party code, you might need to “instrument” that
code as well (that is, add code to it that prints out the shapes of various tensors).

I don’t see any obvious explanation for this. 1.3.1 is pretty old by now. It’s
possible that a bug in the old version somehow let the inplace-modification
error slip by, but the newer version catches it.

Also (although, again, I don’t see any issues with your use of .data), a
backward-breaking change in the semantics of .data could hypothetically
make the two versions behave differently. (In any event, if only as a matter
of good housekeeping, you should get rid of the use of .data in your code.)

Best.

K. Frank

@KFrank I found the tensor and could trace back its origin, but had no luck with using clone() or removing inplace operators like add_(). Here is the way the tensors are created:

  1. total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds = self.uncertainty_from_z() creates the tensor x
  2. Inside uncertainty_from_z() the tensor x is created like this: x = self.VAE.regenerate(self.z, grad=True), where z is a tensor of shape [512, 16] which holds data from my valset that is created earlier; the output shape is [512, 12]
  3. Then tensor x is fed into my GaussianBNN: mu_vec, std_vec = self.BNN.sample_predict(x, num_samples=0, grad=True). The function sample_predict() looks like this:
    def sample_predict(self, x, num_samples, grad=False):
        self.set_model_mode(train=False)
        if num_samples == 0:
            num_samples = len(self.weight_set_samples)
        x, = variable_to_tensor_list(variables=(x,), cuda=self.cuda) # torch.Size([512, 12])
        if grad:
            self.optimizer.zero_grad()
            if not x.requires_grad:
                x.requires_grad = True
        mu_vec = x.new(num_samples, x.shape[0], self.model.output_dim) # torch.Size([100, 512, 1])
        std_vec = x.new(num_samples, x.shape[0], self.model.output_dim) # torch.Size([100, 512, 1])
        for idx, weight_dict in enumerate(self.weight_set_samples): # 100 iterations
            if idx == num_samples:
                break
            self.model.load_state_dict(weight_dict)
            mu, std = self.model(x.clone()) # x = torch.Size([512, 12])
            mu_vec[idx] = mu.clone() # torch.Size([99, 512, 1])
            std_vec[idx] = std.clone() # torch.Size([99, 512, 1])

        if grad:
            return mu_vec[:idx].clone(), std_vec[:idx].clone() # both torch.Size([99, 512, 1])
        else:
            return mu_vec[:idx].clone(), std_vec[:idx].clone()

Note: The for loop in sample_predict() is called 13 times before that; the 14th call comes from the uncertainty_from_z() function and also runs for 100 iterations, as did the other 13 calls. So the 14th call would create/modify version 1400 of the tensor from the error traceback.

  4. self.model(x.clone()) calls the GaussianMLP you mentioned, where the forward pass runs. This is executed 99 times instead of 100, therefore mu_vec and std_vec are torch.Size([99, 512, 1]) instead of torch.Size([100, 512, 1]).
    The forward pass looks like this:
    def forward(self, x):
        if self.flatten_image:
            x = x.view(-1, self.input_dim)


        for i, layer in enumerate(self.block):
            x = layer(x) # input is torch.Size([512, 12])
            # after the last layer the output size is torch.Size([512, 2])
            # ...and weight.shape is torch.Size([2, 200])  <-- THE TENSOR WE ARE LOOKING FOR
            if isinstance(layer, nn.Linear):
                print("Weight shape for layer {}: {}".format(i, layer.weight.shape))

        mu = x[:, :self.output_dim]
        sigma = F.softplus(x[:, self.output_dim:])
        return mu, sigma
  5. After going through the forward pass and the whole GaussianBNN, the function uncertainty_from_z() takes the mu_vec and std_vec, which should have one more dimension, and computes an uncertainty metric based on them:
                mu_vec, std_vec = self.BNN.sample_predict(to_BNN, num_samples=0, grad=True)
                total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty = decompose_std_gauss(mu_vec, std_vec)
                preds = mu_vec.mean(dim=0)
                print(f"PREDS: {preds.shape}")
            
        else:
            if self.regression:
                mu, std = self.BNN.predict(to_BNN, grad=True)
                total_uncertainty = std.squeeze(1)
                aleatoric_uncertainty = total_uncertainty
                epistemic_uncertainty = total_uncertainty * 0
                preds = mu

        return total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, x, preds
  6. The final step before the backward pass objective.sum(dim=0).backward() is getting the objective with objective, w_dist = self.get_objective(x, total_uncertainty, aleatoric_uncertainty, epistemic_uncertainty, preds). The objective and w_dist both have the shape torch.Size([512]).

  7. After that, objective.sum(dim=0).backward() happens and I get the error.

Sorry for all the text, but I’m really clueless about how to debug this problem, since I have tried to .clone() basically every tensor and to eliminate all in-place operations like add_() and in-place modifications like:

mu_vec[idx], std_vec[idx] = self.model(x)

Is the missing dim in mu_vec and std_vec the root cause of the issue? How can I debug this further?

Hi Lukas!

You have a lot of things going on here …

First, to confirm, does your error always show up with “expected version 1400,”
or does the error’s “expected version” vary somewhat?

Just to be clear, your set_model_mode (train = False) is probably only turning
off things like Dropout and not turning off autograd tracking.

Note that model.load_state_dict() modifies the model’s parameters inplace.
So you should track this as the possible cause of your error.
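
A quick self-contained check of that (illustrative names, not your code): the loaded values are copied into the existing parameter tensors, so their version counters increase even when the values are identical.

import copy
import torch.nn as nn

model = nn.Sequential(nn.Linear(12, 200), nn.ReLU(), nn.Linear(200, 2))
snapshot = copy.deepcopy(model.state_dict())

print(model[-1].weight._version)   # 0
model.load_state_dict(snapshot)
print(model[-1].weight._version)   # 1 -- copied in place by load_state_dict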

Just based on this numerology (your 14 calls of 100 iterations each line up with version 1400), this seems likely to be related to your error.

What is being done on the 1401st “iteration” that is different than the first 1400?

Your identification of that [2, 200] weight looks correct. Now you need to determine where this
tensor is being modified for the 1401st time.

From outside GaussianMLP.forward() this tensor can presumably be accessed
as something like model.block[-1].weight. (Note, the other Linears in
block probably also have the same issue. It’s just that autograd stops looking
after the first error is detected.)

So place print (model.block[-1].weight._version) at various places
in your code, especially just before and after your calls to

self.model.load_state_dict(weight_dict)

and

self.optimizer.step()

It’s not that the call that modifies block[-1].weight inplace is necessarily in
error. Both loading the state dict and calling the optimizer are logically sensible
things to do.

Let me call “old data” any data that was computed using the model prior to the
most recent modification to block[-1].weight. You are almost certainly
calling .backward() on some quantity that depends on some old data.

If this is the case, why is it happening, and why doesn’t it happen until the
1401st iteration? What difference in the code path on the 1401st iteration
is causing that old data to slip in?

If this is the cause of your error, it can be “fixed” by .detach()ing the old
data before using it in backward. Something like:

objective = some_new_objective + old_data.detach()
objective.sum(dim=0).backward()

Because of the .detach(), this code will not backpropagate through old_data
and the inplace-modification error won’t occur.

But this might not be logically correct. It’s up to the details of your use case
whether the gradients computed by objective.sum (dim = 0).backward()
should include the gradients due to old_data.

As you note, mu_vec[idx], std_vec[idx] = self.model(x) does modify
mu_vec and std_vec inplace, but this would seem not to be the (immediate)
cause of your issue because it isn’t modifying model.block[-1].weight
inplace.

To summarize:

Determine exactly which line of code (in which iterations of any enclosing loops)
is causing _version to be incremented to 1401. The next call to .backward()
(on outputs of the relevant model) should be the one that raises the error.

Think hard about what’s different about that 1401st iteration.

Look hard at objective.sum (dim = 0), the quantity on which you are calling
.backward(). Is there any way some old data can be creeping into objective?
If there is, think logically about why the old data is there. Can it be left out? Can
it be .detached()?

If you logically need gradients from the old data, you may need to modify the
forward pass of model so that autograd can hold on to a copy of the tensor
being modified inplace. (This can certainly be done, but you would need to
perform some surgery on your model.)
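
One possible shape of that surgery (a sketch under my assumptions, not a drop-in fix for your model): have the Linear compute through a clone of its weight, so that autograd saves the clone for backward and a later in-place update of the parameter no longer invalidates the saved graph.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CloneLinear(nn.Linear):
    def forward(self, input):
        # autograd saves the clones, not the parameters that get modified inplace later
        bias = self.bias.clone() if self.bias is not None else None
        return F.linear(input, self.weight.clone(), bias)

lin = CloneLinear(200, 2)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

out_old = lin(torch.randn(8, 200, requires_grad=True))
lin(torch.randn(8, 200)).sum().backward()
opt.step()                 # in-place update of lin.weight
out_old.sum().backward()   # no version-counter error, gradients still flow to lin.weight

Whether the gradients you get this way are the ones you logically want is, again, up to the details of your use case.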

Good luck!

K. Frank