One of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]]

Hello!

I’m running into the following error (with anomaly detection on):

UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/spawn.py", line 129, in _main
    return self._bootstrap(parent_sentinel)
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/apunt/repos/lhasa/FLuID_POC/attempt4.py", line 167, in train_and_evaluate_self_learning
    logits, loss = ff_nn(tensor_X, tensor_Y)
  File "/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/apunt/repos/lhasa/FLuID_POC/attempt4.py", line 35, in forward
    out = self.fc2(out)
  File "/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Process Process-4:
Traceback (most recent call last):
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/anaconda3/envs/opig/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/apunt/repos/lhasa/FLuID_POC/attempt4.py", line 220, in train_and_evaluate_self_learning
    loss.backward()
  File "/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/opt/anaconda3/envs/opig/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of AsStridedBackward0, is at version 6251; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

For context, I have a binary classifier (BinaryClassificationNN) that works when I run it normally. However, when I try to use it for iterative self-learning, it breaks at the backward step (loss.backward()).

I’ve read up on this error message, and it appears to be caused by an in-place operation on a tensor. I’ve gone through my code and seen no signs of in-place operations (I was looking for instances of += and functions like add_()), but I may very well be missing something. It looks like something is modifying the tensors in my forward step, but I can’t figure out what it is.
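For reference, PyTorch detects in-place modification with a per-tensor version counter: autograd records a tensor's version when it saves it in the forward pass and compares it again during backward. The minimal sketch below uses only a bare tensor, and shows that += and trailing-underscore methods such as add_() bump the counter, while out-of-place arithmetic builds a new tensor and leaves the original alone. Note that _version is an internal attribute, read here only for illustration, and that optimizer.step() performs the same kind of in-place update on parameters, which a search for += or add_() in your own code will not find.

```python
import torch

t = torch.zeros(3)
start = t._version          # internal counter, read only for illustration

t += 1                      # in-place (__iadd__ calls add_): version +1
t.add_(1)                   # in-place: version +1 again
after_inplace = t._version

alias = t                   # keep a handle on the same tensor object
t = t + 1                   # out-of-place: t now names a NEW tensor
after_out_of_place = alias._version   # the original tensor was not touched

print(after_inplace - start)        # 2
print(after_out_of_place - start)   # 2
```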

Let me know if anyone has any suggestions!

Here’s the relevant code:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pickle
from multiprocessing import Manager, Process

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef

import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import PandasTools

import torch
from torch.utils import data 
from torch import nn
from torch.nn import functional as F

import ss_utils

# Define the binary classification neural network
class BinaryClassificationNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(BinaryClassificationNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, y_hat=None):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        
        if y_hat is not None:
            # Use y_hat (corrected labels) during training in the post-prototype phase
            loss = F.binary_cross_entropy_with_logits(out, y_hat)
            return out, loss
        else:
            return out

# Feed the network fingerprint tensors
def transform_ff_nn(fp) -> torch.Tensor:
    fp_np = np.array(fp.tolist(), dtype=np.float32)
    tensor_fp = torch.tensor(fp_np, dtype=torch.float32)
    return tensor_fp

def calculate_density(similarity_matrix, threshold):
    # Calculate density based on the similarity matrix
    density = np.sum(similarity_matrix > threshold, axis=1)
    print(density)
    return density

def select_prototypes_with_similarity(similarity_matrix, density, num_prototypes):
    # Sort density values in descending order
    sorted_indices = np.argsort(density)[::-1]
    sorted_density = density[sorted_indices]

    prototypes = []

    for i in range(num_prototypes):
        selected_index = sorted_indices[i]
        prototypes.append(selected_index)

    return prototypes


def load_data_and_label_table(data_shared, label_table_shared):
    # Load data (load it once outside of multiprocessing)
    testSet = pd.read_pickle(os.path.join("data", "hERG_lhasa_test.pkl"))
    data_shared['testSet'] = testSet
    print("Original train/test sets loaded")

    # Load federated dataset and convert classification to numeric value (keep it once outside of multiprocessing)
    with open('label_table_tanimoto.pickle', 'rb') as f:
        label_table = pickle.load(f)
    # label_table = label_table[:1000]
    label_table['C-F8'] = label_table['C-F8'].map({'Inactive': 0, 'Active': 1})
    label_table_shared['label_table'] = label_table
    print("Label table loaded and corrected")

    # Filter out rows with invalid SMILES
    label_table = label_table.dropna(subset=['MOLECULE'])

    # Parse SMILES strings and handle errors (if needed)
    # label_table['SMILES'] = label_table['MOLECULE'].apply(lambda mol: Chem.MolToMolBlock(mol) if mol is not None else None)

    # Split the label_table
    x_train, x_val, y_train, y_val = train_test_split(label_table['FP'], label_table['C-F8'], test_size=0.2)
    train_df = pd.DataFrame({'smiles': x_train, 'binary_classification_label': y_train})
    val_df = pd.DataFrame({'smiles': x_val, 'binary_classification_label': y_val})
    data_shared['train_df'] = train_df
    data_shared['val_df'] = val_df
    data_shared['y_val'] = y_val

def train_and_evaluate_basic(data_shared, label_table_shared):
    data = data_shared['testSet']
    label_table = label_table_shared['label_table']
    train_df = data_shared['train_df']
    val_df = data_shared['val_df']
    y_val = data_shared['y_val']

    # Modify the neural network architecture
    ff_nn = BinaryClassificationNN(2048, 128)

    # Train and evaluate the binary classification model
    # Assuming you have train_df and val_df defined
    out = ss_utils.train_neural_network(train_df=train_df, val_df=val_df, 
                                        smiles_col="smiles", regression_column="binary_classification_label",
                                        transform=transform_ff_nn, 
                                        neural_network=ff_nn)

    # Extract predictions and targets from out
    val_predictions = out['val_predictions']

    accuracy = accuracy_score(y_val, (val_predictions > 0.5).astype(int))  # Assuming a threshold of 0.5 for binary classification
    precision = precision_score(y_val, (val_predictions > 0.5).astype(int))
    recall = recall_score(y_val, (val_predictions > 0.5).astype(int))
    f1 = f1_score(y_val, (val_predictions > 0.5).astype(int))
    mcc = matthews_corrcoef(y_val, (val_predictions > 0.5).astype(int))

    # Print the metrics
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')
    print(f'Balanced MCC: {mcc}')

def train_and_evaluate_self_learning(data_shared, label_table_shared):
    torch.autograd.set_detect_anomaly(True)
    
    data = data_shared['testSet']
    label_table = label_table_shared['label_table']
    train_df = data_shared['train_df']
    val_df = data_shared['val_df']
    y_val = data_shared['y_val']
    learning_rate = 0.002 # TODO: decrease by 10 every 5 epochs (?i)

    similarity_matrix = np.load('similarity_matrix.npy')

    # Modify the neural network architecture
    ff_nn = BinaryClassificationNN(2048, 128)
    optimizer = torch.optim.Adam(ff_nn.parameters(), lr=learning_rate)

    # Training parameters
    num_epochs = 10
    start_epoch = 5  # Adjust this to the desired start epoch for post-prototype phase
    alpha = 0.5  # Adjust this for the self-learning trade-off
    for epoch in range(num_epochs):
        tensor_X = transform_ff_nn(train_df['smiles'])
        print(f'tensor_X shape: {tensor_X.shape}')
        print(f'tensor_X data type: {tensor_X.dtype}')
        
        if epoch < start_epoch:
            # Pre-prototype phase
            X, Y = train_df['smiles'], train_df['binary_classification_label']
            tensor_Y_wrong_shape = torch.tensor(Y.values, dtype=torch.float32)
            # Reshape tensor_Y to match the output shape of your neural network
            expected_input_dim = ff_nn.fc1.in_features  # Get input_dim from the first layer of your network
            print(f'expected_input_dim in BinaryClassificationNN: {expected_input_dim}')
            tensor_Y = tensor_Y_wrong_shape.view(-1, 1)  # Reshape tensor_Y, adjust this if needed
            print(f'tensor_Y shape: {tensor_Y.shape}')
            print(f'tensor_Y data type: {tensor_Y.dtype}')
            print(f'tensor_Y values: {tensor_Y}')
            logits, loss = ff_nn(tensor_X, tensor_Y)
            # Train using original labels
            out = ss_utils.train_neural_network(train_df=train_df, val_df=val_df, 
                                    smiles_col="smiles", regression_column="binary_classification_label",
                                    transform=transform_ff_nn, 
                                    neural_network=ff_nn)
        else:
            # Calculate similarity scores σc for each class
            similarity_scores = []
            density_threshold = 0.01
            for c in range(num_classes):  # Loop over classes
                density = calculate_density(similarity_matrix, density_threshold) 
                num_prototypes = 8
                class_prototypes = select_prototypes_with_similarity(similarity_matrix, density, num_prototypes)  # Select prototypes for class c
                sigma_c = 0.0
                for i, x in enumerate(train_df['smiles']):
                    similarity_to_prototypes = []  # Calculate similarity to prototypes for sample x
                    for prototype in class_prototypes:
                        similarity = calculate_similarity(x, prototype)  # Use your similarity metric here
                        similarity_to_prototypes.append(similarity)
                    sigma_c += max(similarity_to_prototypes)  # Take the maximum similarity
                similarity_scores.append(sigma_c)

            # Assign corrected labels yˆ based on the class with the highest similarity score σc
            corrected_labels = [np.argmax(scores) for scores in similarity_scores]
            tensor_corrected_labels = torch.tensor(corrected_labels, dtype=torch.float32)
            print(f'tensor_corrected_labels shape: {tensor_corrected_labels.shape}')
            print(f'tensor_corrected_labels data type: {tensor_corrected_labels.dtype}')
            print(f'tensor_corrected_labels values: {tensor_corrected_labels}')
            tensor_corrected_labels = tensor_corrected_labels.view(-1, 1)  # Reshape corrected_labels if needed
            logits, loss = ff_nn(tensor_X, tensor_corrected_labels)

            # Continue with training using corrected labels
            out = ss_utils.train_neural_network(train_df=train_df, val_df=val_df, 
                                                smiles_col="smiles", regression_column="binary_classification_label",
                                                transform=transform_ff_nn, 
                                                neural_network=ff_nn,
                                                corrected_labels=corrected_labels)  # Pass corrected labels to your training function

        print(f'Output logits shape: {logits.shape}')

        # Print information before calling loss.backward()
        print(f'Epoch [{epoch + 1}/{num_epochs}]')
        print(f'  Loss: {loss.item()}')
        print(f'  out.shape: {logits.shape}')
        print(f'  tensor_Y.shape: {tensor_Y.shape}')  # Shape of tensor_Y used in loss calculation

        # Debugging: Print the tensors to check for any unusual values
        print(f'  tensor_Y values: {tensor_Y}')
        print(f'  tensor_X values: {tensor_X}')

        with torch.autograd.detect_anomaly():
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')

    # Print the metrics
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')
    print(f'Balanced MCC: {mcc}')

if __name__ == '__main__':
    # Create shared dictionaries for data and label_table
    data_manager = Manager()
    data_shared = data_manager.dict()
    label_table_manager = Manager()
    label_table_shared = label_table_manager.dict()

    # Load data and label_table in a separate process
    data_process = Process(target=load_data_and_label_table, args=(data_shared, label_table_shared))
    data_process.start()
    data_process.join()

    # Train and evaluate in a separate process
    # Can change to target to train_and_evaluate_basic for testing
    train_process = Process(target=train_and_evaluate_self_learning, args=(data_shared, label_table_shared))
    train_process.start()
    train_process.join()

And here’s ss_utils.py:

"""
This module contains some helper functions for our main notebook
"""

import collections
import itertools
import time
import typing
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from rdkit import Chem
from rdkit.Chem import AllChem

import torch
from torch.utils import data
from torch import nn
from torch import optim
from torch.nn import functional as F
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, RunningAverage

from ignite.contrib.handlers import ProgressBar


@dataclass
class TrainParams:
    batch_size: int = 64
    val_batch_size: int = 64
    learning_rate: float = 1e-3
    num_epochs: int = 10
    device: typing.Optional[str] = 'cpu'  # <-- I have not run this on the GPU yet, so that may need some debugging.


class SmilesRegressionDataset(data.Dataset):
    """
    Dataset that holds SMILES molecule data along with an associated single
    regression target.
    """

    def __init__(self, smiles_list: typing.List[str],
                 regression_target_list: typing.List[float],
                 transform: typing.Optional[typing.Callable] = None):
        """
        :param smiles_list: list of SMILES strings representing the molecules
        we are regressing on.
        :param regression_target_list: list of targets
        :param transform: an optional transform which will be applied to the
        SMILES string before it is returned.
        """
        self.smiles_list = smiles_list
        self.regression_target_list = regression_target_list
        self.transform = transform

        assert len(self.smiles_list) == len(self.regression_target_list), \
            "Dataset and targets should be the same length!"

    def __getitem__(self, index):
        x, y = self.smiles_list[index], self.regression_target_list[index]
        if self.transform is not None:
            x = self.transform(x)
        y = torch.tensor([y], dtype=torch.float32)
        return x, y

    def __len__(self):
        return len(self.smiles_list)

    @classmethod
    def create_from_df(cls, df: pd.DataFrame, smiles_column: str = 'smiles',
                       regression_column: str = 'y', transform=None):
        """
        Convenience method that takes in a Pandas DataFrame and turns it
        into an instance of this class.
        :param df: Dataframe containing the data.
        :param smiles_column: name of column that contains the x data
        :param regression_column: name of the column which contains the
        y data (i.e. targets)
        :param transform: a transform to pass to class's constructor
        """
        # smiles_list = [x.strip() for x in df[smiles_column].tolist()]
        smiles_list = df[smiles_column].tolist()
        # targets = [float(y) for y in df[regression_column].tolist()]
        targets = df[regression_column].tolist()
        return cls(smiles_list, targets, transform)


def train_neural_network(train_df: pd.DataFrame, val_df: pd.DataFrame,
                         smiles_col: str, regression_column: str,
                         transform: typing.Callable,
                         neural_network: nn.Module,
                         corrected_labels: typing.Optional[np.ndarray] = None,
                         params: typing.Optional[TrainParams]=None,
                         collate_func: typing.Optional[typing.Callable]=None):
    """
    Trains a PyTorch NN module on train dataset, validates it each epoch and returns a series of useful metrics
    for further analysis. Note that the network's parameters will be changed in place.

    :param train_df: data to use for training.
    :param val_df: data to use for validation.
    :param smiles_col: column name for SMILES data in Dataframe
    :param regression_column: column name for the data we want to regress to.
    :param transform: the transform to apply to the datasets to create new ones suitable for working with neural network
    :param neural_network: the PyTorch nn.Module to train
    :param corrected_labels: an array of corrected labels for training
    :param params: the training params eg number of epochs etc.
    :param collate_func: collate_fn to pass to dataloader constructor. Leave as None to use default.
    """
    if params is None:
        params = TrainParams()

    # Update the train and valid datasets with new parameters
    train_dataset = SmilesRegressionDataset.create_from_df(train_df, smiles_col, regression_column, transform=transform)
    val_dataset = SmilesRegressionDataset.create_from_df(val_df, smiles_col, regression_column, transform=transform)
    print(f"Train dataset is of size {len(train_dataset)} and valid of size {len(val_dataset)}")

    # Put into dataloaders
    train_dataloader = data.DataLoader(train_dataset, params.batch_size, shuffle=True,
                                       collate_fn=collate_func, num_workers=1)
    val_dataloader = data.DataLoader(val_dataset, params.val_batch_size, shuffle=False, collate_fn=collate_func,
                                       num_workers=1)

    # Optimizer
    optimizer = optim.Adam(neural_network.parameters(), lr=params.learning_rate)

    # Work out what device we're going to run on (ie CPU or GPU)
    device = params.device

    # We're going to use PyTorch Ignite to take care of the majority of the training boilerplate for us
    # see https://pytorch.org/ignite/
    # in particular we are going to follow the example
    # https://github.com/pytorch/ignite/blob/53190db227f6dda8980d77fa5351fa3ddcdec6fb/examples/contrib/mnist/mnist_with_tqdm_logger.py
    def prepare_batch(batch, device, non_blocking):
        x, y = batch
        return x.to(device), y.to(device)


    # TODO: what value of alpha should I be using? 
    # Replace the loss function here with your custom loss function
    alpha = 0.5  # You can adjust the alpha value as needed
    if corrected_labels is not None:
        def custom_loss(outputs, targets):
            original_loss = F.binary_cross_entropy_with_logits(outputs, targets)  # You can change this to your specific loss function
            corrected_loss = F.binary_cross_entropy_with_logits(outputs, corrected_labels)
            total_loss = (1 - alpha) * original_loss + alpha * corrected_loss
            return total_loss

        trainer = create_supervised_trainer(neural_network, optimizer, custom_loss, device=device, prepare_batch=prepare_batch)
    else:
        trainer = create_supervised_trainer(neural_network, optimizer, F.binary_cross_entropy_with_logits, device=device, prepare_batch=prepare_batch)

    evaluator = create_supervised_evaluator(neural_network,
                                            metrics={'loss': Loss(F.binary_cross_entropy_with_logits)},
                                            device=device, prepare_batch=prepare_batch)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names='all')

    train_loss_list = []
    val_lost_list = []
    val_times_list = []

    @trainer.on(Events.EPOCH_COMPLETED | Events.STARTED)
    def log_training_results(engine):
        evaluator.run(train_dataloader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        pbar.log_message("Epoch - {}".format(engine.state.epoch))
        pbar.log_message(
            "Training Results - Epoch: {}  Avg loss: {:.2f}"
                .format(engine.state.epoch, loss)
        )
        train_loss_list.append(loss)

    @trainer.on(Events.EPOCH_COMPLETED | Events.STARTED)
    def log_validation_results(engine):
        s_time = time.time()
        evaluator.run(val_dataloader)
        e_time = time.time()
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        pbar.log_message(
            "Validation Results - Epoch: {} Avg loss: {:.2f}"
                .format(engine.state.epoch, loss))

        pbar.n = pbar.last_print_n = 0
        val_lost_list.append(loss)
        val_times_list.append(e_time - s_time)

    # We can now train our network!
    trainer.run(train_dataloader, max_epochs=params.num_epochs)

    # Having trained it, we are now also going to run through the validation set one
    # last time to get the actual predictions
    val_predictions = []
    neural_network.eval()
    for batch in val_dataloader:
        x, _ = batch  # We don't need the original labels
        x = x.to(device)

        print(f"Shape of input tensor (x): {x.shape}")

        if corrected_labels is not None:
            # Use corrected labels if available
            corrected_labels_tensor = torch.tensor(corrected_labels, dtype=torch.float32).to(device)
            y_pred = neural_network(x, corrected_labels_tensor)  # Use corrected labels in the forward pass
        else:
            # Use original labels
            y_pred = neural_network(x)
        print(f"Shape of predicted tensor (y_pred): {y_pred.shape}")
        val_predictions.append(y_pred.cpu().detach().numpy())

    neural_network.train()
    val_predictions = np.concatenate(val_predictions)

    # Create a table of useful metrics (as part of the information we return)
    total_number_params = sum([v.numel() for v in  neural_network.parameters()])
    out_table = [
        ["Num params", f"{total_number_params:.2e}"],
        ["Minimum train loss", f"{np.min(train_loss_list):.3f}"],
        ["Mean validation time", f"{np.mean(val_times_list):.3f}"],
        ["Minimum validation loss", f"{np.min(val_lost_list):.3f}"]
     ]

    # We will create a dictionary of results.
    results = dict(
        train_loss_list=train_loss_list,
        val_lost_list=val_lost_list,
        val_times_list=val_times_list,
        out_table=out_table,
        val_predictions=val_predictions
    )
    return results

Hi Adele!

The error message is telling you that fc2.weight is the tensor that is being modified inplace.
Note that its ._version has jumped to 6251 (autograd expected version 1), indicating
that it’s been modified inplace 6250 times between the forward pass that saved it and
the call to .backward() that is triggering the error.
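A minimal sketch of that bookkeeping on a parameter, assuming a toy nn.Linear(128, 1) with the same shape as fc2 and plain SGD (swapped in here for brevity; your code uses Adam, which likewise updates each parameter inplace once per step). Each opt.step() bumps weight._version by one, while zero_grad() and backward() only touch .grad, not the weight itself:

```python
import torch
from torch import nn

torch.manual_seed(0)
fc2 = nn.Linear(128, 1)                        # same shape as fc2 in the post
opt = torch.optim.SGD(fc2.parameters(), lr=0.1)
x = torch.randn(8, 128)                        # hypothetical hidden activations

before = fc2.weight._version                   # internal counter, for illustration
num_steps = 3
for _ in range(num_steps):
    opt.zero_grad()                            # resets .grad only
    fc2(x).sum().backward()                    # writes .grad, not the weight
    opt.step()                                 # updates fc2.weight inplace
after = fc2.weight._version

print(after - before)                          # one bump per optimizer step
```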

You need to think through your training algorithm. As we’ll see below, you compute
a loss, call a training loop (that presumably calls an opt.step() multiple times),
and then call loss.backward(). opt.step() modifies inplace the parameters that
it is optimizing, causing the error.

Does it really make sense for your algorithm to stick a training loop in between
the computation of loss and the call to loss.backward()?
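To make the ordering problem concrete, here is a minimal reproduction under stated assumptions: an nn.Sequential stand-in for BinaryClassificationNN(2048, 128), random tensors in place of the fingerprints and labels, and a single Adam update standing in for the inner trainer.run(...) loop. Computing loss, then stepping the optimizer, then calling loss.backward() raises the same RuntimeError; the usual ordering (zero_grad, forward, backward, step) does not.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
# Stand-in for BinaryClassificationNN(2048, 128): fc1 -> ReLU -> fc2
model = nn.Sequential(nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, 1))
opt = torch.optim.Adam(model.parameters(), lr=0.002)
x = torch.randn(16, 2048)                   # hypothetical fingerprint batch
y = torch.randint(0, 2, (16, 1)).float()    # hypothetical 0/1 labels


def broken_order():
    """Forward, then an 'inner training loop', then backward on the stale loss."""
    loss = F.binary_cross_entropy_with_logits(model(x), y)  # graph saves fc2.weight.t()
    # One update standing in for trainer.run(...): modifies parameters inplace.
    opt.zero_grad()
    F.binary_cross_entropy_with_logits(model(x), y).backward()
    opt.step()
    try:
        loss.backward()                     # needs the OLD version of fc2.weight
    except RuntimeError as err:
        return str(err)
    return "no error"


def correct_order():
    """Standard loop: backward runs before the parameters are changed."""
    opt.zero_grad()
    loss = F.binary_cross_entropy_with_logits(model(x), y)
    loss.backward()
    opt.step()
    return loss.item()


message = broken_order()
print(message.splitlines()[0])
print(correct_order())
```

If the inner loop is meant to do the training, the outer loss is better recomputed after it returns rather than reused from before.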

The inplace operation is presumably a call to opt.step() that isn’t shown in
the code you’ve posted but is buried somewhere in:

    trainer.run(train_dataloader, max_epochs=params.num_epochs)

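Your line self.fc2 = nn.Linear(hidden_dim, 1), called with hidden_dim = 128, creates fc2
with a weight of shape [1, 128]. The inplace-modification error reports the transpose of
this shape, namely, [128, 1], because F.linear() multiplies by weight.t(), and that
transposed view is what autograd saved.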
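A quick check of that shape bookkeeping, assuming a 2-D input batch (for which F.linear() with a bias is computed with addmm, hence the AddmmBackward0 node in the traceback). The random input is a hypothetical stand-in for the ReLU output:

```python
import torch
from torch import nn

fc2 = nn.Linear(128, 1)              # hidden_dim = 128 -> one logit
hidden = torch.randn(4, 128)         # hypothetical stand-in for the ReLU output
out = fc2(hidden)

print(fc2.weight.shape)              # torch.Size([1, 128])
print(fc2.weight.t().shape)          # torch.Size([128, 1]), the shape in the error
print(out.grad_fn.name())            # AddmmBackward0
```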

Note that the docstring of train_neural_network() in the code you posted warns that
the network’s “parameters will be changed in place,” tipping you off to the possibility
of an inplace-modification error.

We noted above that the error message tells you that fc2.weight was
modified inplace 6250 times. This would make sense if your inner training
loop calls some opt.step() 6250 times as it iterates.
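For a sense of scale: with the inner TrainParams defaults (num_epochs=10, batch_size=64), 6250 steps works out to 625 batches per epoch, i.e. somewhere between 39,937 and 40,000 training rows. That row count is an inference from the version number, not something stated in the post. The sketch below assumes a small random dataset (hypothetical sizes) and writes out by hand roughly what an Ignite trainer built with create_supervised_trainer does on each iteration (zero_grad, forward, loss, backward, step), then checks that fc2.weight._version advances once per batch:

```python
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

torch.manual_seed(0)
num_rows, batch_size, num_epochs = 200, 64, 3    # hypothetical small run

dataset = data.TensorDataset(torch.randn(num_rows, 2048),
                             torch.randint(0, 2, (num_rows, 1)).float())
loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = nn.Sequential(nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, 1))
fc2 = model[2]
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

start = fc2.weight._version
for _ in range(num_epochs):
    for xb, yb in loader:                        # roughly one trainer iteration
        model.train()
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(model(xb), yb)
        loss.backward()
        opt.step()                               # one inplace update of fc2.weight
bumps = fc2.weight._version - start

expected_steps = num_epochs * math.ceil(num_rows / batch_size)
print(bumps, expected_steps)                     # 12 12

# Same arithmetic for the post: 6250 bumps over 10 epochs -> 625 batches/epoch.
print(6250 // 10, (625 - 1) * 64 + 1, 625 * 64)  # 625 39937 40000
```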

For some examples that illustrate how to debug inplace-modification errors,
see this post:

Good luck!

K. Frank