Pre-calculated distance matrix in Dataset

I’m working on a graph neural network that uses 3D position vectors R as nodes and the scalar distances D between them as edges, and I want to take derivatives of the network output with respect to the initial positions R. Here is a minimal working example that does not involve any neural network:

import torch

R = torch.rand(10, 5, 3, requires_grad=True)  # Position vectors, (n_data, n_point, 3)
D = torch.norm(R.unsqueeze(2) - R.unsqueeze(1), dim=-1)  # Distances matrix, (n_data, n_point, n_point)
gradient = torch.autograd.grad(D, R, grad_outputs=torch.ones_like(D))[0]
print(gradient)  # Returns tensor with shape (10, 5, 3)
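As a sanity check on what the example above computes: since ∂|R_i − R_j|/∂R_i = (R_i − R_j)/D_ij, and every pair (i, j) appears twice in the sum over D (as D_ij and D_ji), the gradient for point i should be 2 Σ_j (R_i − R_j)/D_ij, with the i = j term taken as zero. The sketch below compares that closed form with autograd; the seed and the names `diff`, `safe_D` and `analytic` are illustrative only.

```python
import torch

torch.manual_seed(0)
R = torch.rand(10, 5, 3, requires_grad=True)    # (n_data, n_point, 3)
diff = R.unsqueeze(2) - R.unsqueeze(1)          # diff[b, i, j] = R[b, i] - R[b, j]
D = torch.norm(diff, dim=-1)                    # (n_data, n_point, n_point)
gradient = torch.autograd.grad(D, R, grad_outputs=torch.ones_like(D))[0]

# Closed form: each pair contributes (R_i - R_j) / D_ij twice (via D_ij and D_ji).
# On the diagonal D_ii = 0 and diff is 0, so that term is set to 0 here,
# which matches what autograd returns for the norm of a zero vector.
with torch.no_grad():
    safe_D = D.clone()
    safe_D[safe_D == 0] = 1.0                   # avoid 0/0 on the diagonal
    analytic = 2 * (diff / safe_D.unsqueeze(-1)).sum(dim=2)

print(torch.allclose(gradient, analytic, atol=1e-5))
```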

Since the distance matrix D does not change during training, I want to pre-calculate it when I create the PyTorch Dataset, rather than recalculate it every time I load a batch. However, the autograd graph doesn’t let me do that, presumably because the batched R returned by the Dataset is a different tensor from the R that was used to build the graph of D.
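The presumption can be checked in isolation. Indexing a tensor, which is what a Dataset’s `__getitem__` typically does, returns a new tensor whose autograd node is a child of the original. A D built from the original therefore never passes through the indexed copy. The names `R_slice`, `D_slice`, `grad_wrt_slice` and `grad_wrt_leaf` are illustrative; `retain_graph=True` only keeps the graph alive for the second `grad` call.

```python
import torch

R = torch.rand(10, 5, 3, requires_grad=True)             # leaf tensor
D = torch.norm(R.unsqueeze(2) - R.unsqueeze(1), dim=-1)  # graph: R -> D

# Indexing creates new tensors. Both R_slice and D_slice descend from R,
# but D_slice does not descend from R_slice.
R_slice = R[:]
D_slice = D[:]

grad_wrt_slice = torch.autograd.grad(
    D_slice, R_slice, grad_outputs=torch.ones_like(D_slice),
    allow_unused=True, retain_graph=True,  # keep the graph for the call below
)[0]
print(grad_wrt_slice)  # None: R_slice is not on any path to D_slice

# Differentiating with respect to the tensor D was actually built from works.
grad_wrt_leaf = torch.autograd.grad(D_slice, R, grad_outputs=torch.ones_like(D_slice))[0]
print(grad_wrt_leaf.shape)
```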

import torch
from torch.utils.data import Dataset

class BatchDataset(Dataset):
    def __init__(self, R):
        self.R = R.clone().detach().requires_grad_(True)  # Position vectors, (n_data, n_point, 3); a new leaf tensor
        self.D = torch.norm(self.R.unsqueeze(2) - self.R.unsqueeze(1), dim=-1)  # Distances matrix, (n_data, n_point, n_point)
    def __getitem__(self, index):
        return self.R[index], self.D[index]
    def __len__(self):
        return self.R.size(0)

R = torch.rand(10, 5, 3, requires_grad=True)  # Position vectors, (n_data, n_point, 3)
batch_R, batch_D = BatchDataset(R)[:]
try:
    gradient = torch.autograd.grad(batch_D, batch_R, grad_outputs=torch.ones_like(batch_D))[0]
except RuntimeError as err:
    print(err)  # RuntimeError: One of the differentiated Tensors appears to not have been used in the graph.
gradient = torch.autograd.grad(batch_D, batch_R, grad_outputs=torch.ones_like(batch_D), allow_unused=True)[0]
print(gradient)  # Prints None

Is there a way to pre-calculate the distance matrix in a Dataset while preserving the graph structure needed for the gradient calculation? Or do the distance matrix and the edge features have to be recalculated whenever gradients are needed? Thanks.
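For comparison, here is a minimal sketch of the recalculate-per-batch alternative mentioned in the question. It assumes that recomputing the O(n_point²) distances for each batch is affordable. The Dataset stores only plain position data, and D is rebuilt after the batch itself is marked as requiring gradients, so the gradient is taken with respect to the batch tensor. `pairwise_distances` and `PositionDataset` are hypothetical names, not part of any library.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def pairwise_distances(R):
    """Hypothetical helper: (batch, n_point, 3) -> (batch, n_point, n_point)."""
    return torch.norm(R.unsqueeze(2) - R.unsqueeze(1), dim=-1)


class PositionDataset(Dataset):
    """Hypothetical Dataset that stores positions only, with no autograd history."""

    def __init__(self, R):
        self.R = R.detach().clone()

    def __getitem__(self, index):
        return self.R[index]

    def __len__(self):
        return self.R.size(0)


loader = DataLoader(PositionDataset(torch.rand(10, 5, 3)), batch_size=4)
shapes = []
for batch_R in loader:                      # default collate stacks to (B, 5, 3)
    batch_R.requires_grad_(True)            # the batch itself becomes the leaf
    batch_D = pairwise_distances(batch_R)   # graph: batch_R -> batch_D
    gradient = torch.autograd.grad(batch_D, batch_R, grad_outputs=torch.ones_like(batch_D))[0]
    shapes.append(tuple(gradient.shape))

print(shapes)
```

The trade-off is that D is recomputed on every pass, but the graph then always starts at the tensor being differentiated.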