LeNet5 & Self-Organizing Maps - RuntimeError: Trying to backward through the graph a second time - PyTorch

I have a LeNet-5 CNN training with a Self-Organizing Map trained on MNIST data. The training code (for brevity) is:

# SOM (flattened) weights-
# m = 40, n = 40, n = 84 (LeNet's output shape/dim)
centroids = torch.randn(m * n, dim, device = device, dtype = torch.float32)

locs = [np.array([i, j]) for i in range(m) for j in range(n)]
locations = torch.LongTensor(np.asarray(locs)).to(device)
del locs

def get_bmu_distance_squares(bmu_loc):
    bmu_distance_squares = torch.sum(
        input = torch.square(locations.float() - bmu_loc),
        dim = 1
    )
    return bmu_distance_squares

distance_mat = torch.stack([get_bmu_distance_squares(loc) for loc in locations])

centroids = centroids.to(device)

num_epochs = 50
qe_train = list()

step = 1


for epoch in range(1, num_epochs + 1):
    qe_epoch = 0.0
    for x, y in train_loader:
        x = x.to(device)
        z = model(x)

        # SOM training code:

        batch_size = len(z)

        # Compute distances from batch to (all SOM units) centroids-
        dists = torch.cdist(x1 = z, x2 = centroids, p = p_norm)

        # Find closest (BMU) and retrieve the gaussian correlation matrix
        # for each point in the batch
        # bmu_loc is BS, num points-
        mindist, bmu_index = torch.min(dists, -1)
        # print(f"quantization error = {mindist.mean():.4f}")

        bmu_loc = locations[bmu_index]


        # Compute the SOM weight update:

        # Update LR
        # It is a matrix of shape (BS, centroids) or, (BS, mxn) and tells
        # for each input how much it will affect each (SOM unit) centroid-
        bmu_distance_squares = distance_mat[bmu_index]

        # Get current lr and neighbourhood radius for current step-
        decay_val = scheduler(it = step, tot =  int(len(train_loader) * num_epochs))
        curr_alpha = (alpha * decay_val).to(device)
        curr_sigma = (sigma * decay_val).to(device)

        # Compute Gaussian neighbourhood function-
        neighborhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, ((2 * torch.square(curr_sigma)) + 1e-5))))

        expanded_z = z.unsqueeze(dim = 1).expand(-1, grid_size, -1)
        expanded_weights = centroids.unsqueeze(0).expand((batch_size, -1, -1))

        delta = expanded_z - expanded_weights
        lr_multiplier = curr_alpha * neighborhood_func

        delta.mul_(lr_multiplier.reshape(*lr_multiplier.size(), 1).expand_as(delta))
        delta = torch.mean(delta, dim = 0)
        new_weights = torch.add(centroids, delta)
        centroids = new_weights

        # return bmu_loc, torch.mean(mindist)

        # Compute quantization error los-
        qe_loss = torch.mean(mindist)
        qe_epoch += qe_loss.item()

        # Empty accumulated gradients-
        optimizer.zero_grad()

        # Perform backprop-
        qe_loss.backward()

        # Update model trainable params-
        optimizer.step()
         
        step += 1


    qe_train.append(qe_epoch / len(train_loader))
    print(f"\nepoch = {epoch}, QE = {qe_epoch / len(train_loader):.4f}"
        f" & SOM wts L2-norm = {torch.norm(input = centroids, p = 2).item():.4f}"
    )

On trying to execute this code, I get the error:

line 252: qe_loss.backward()

Traceback (most recent call last): File
“c:\some_dir\som_lenet5.py”,
line 252, in
qe_loss.backward() File “c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch_tensor.py”,
line 522, in backward
torch.autograd.backward( File “c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\autograd_init_.py”,
line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Trying to backward
through the graph a second time (or directly access saved tensors
after they have already been freed). Saved intermediate values of the
graph are freed when you call .backward() or autograd.grad(). Specify
retain_graph=True if you need to backward through the graph a second
time or if you need to access saved tensors after calling backward.