Reducing RAM usage when dealing with very large tensors

I am trying to build a function that does as follows:
"""
Calculates a 3D tensor with the minimum distance from each pixel to the data points.

Inputs:
* points: a numpy array of 2D point coordinates, normalized to
be in the range [0, 1]. The expected shape is [B, P, 2].
* res: the resolution of the output tensor.
Returns:
For each batch element, a res x res square tensor with floating-point values
corresponding to the Euclidean distance to the closest point in points.
The returned tensor shape is [B, res, res].
"""
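
For checking any chunked version, the specification above can be written directly for small res. This is a minimal sketch, assuming (as the code below does) that points[..., 0] is x and points[..., 1] is y, and that the output is indexed [b, y, x]. `min_dist_reference` is a hypothetical name, and the function materializes a [B, res, res, P] array, so it is only usable for tiny resolutions.

```python
import numpy as np


def min_dist_reference(points, res):
    """Direct, non-chunked version of the spec: [B, P, 2] -> [B, res, res].

    Builds a [B, res, res, P] array of distances, so use it only for small res.
    """
    coords = np.asarray(points, dtype=np.float64) * (res - 1)  # pixel units
    ys, xs = np.mgrid[0:res, 0:res]                            # row index = y, column index = x
    dx = xs[None, :, :, None] - coords[:, None, None, :, 0]    # [B, res, res, P]
    dy = ys[None, :, :, None] - coords[:, None, None, :, 1]    # [B, res, res, P]
    return np.sqrt(dx ** 2 + dy ** 2).min(axis=-1)             # nearest point per pixel
```

Comparing a chunked implementation against this on, say, res = 37 with a chunk size that does not divide it catches both orientation mistakes and edge-chunk bugs.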
However, the function must handle resolutions of up to 65,536. The method below works, but it is really slow, and the other approaches I have tried ran out of RAM. Is there a better way to deal with tensors this large, or is my algorithm itself too slow?
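
A quick back-of-the-envelope check shows where the memory goes. It assumes the example's B = 3 and P = 4 and the 4096-pixel chunks used in the code. The finished float32 output alone is 48 GiB (16 GiB per batch element), so it can never sit in RAM at once. By contrast, the cdist distances for one chunk need at least 0.75 GiB. The helper `float32_bytes` is a made-up name for this illustration.

```python
def float32_bytes(*shape):
    """Bytes needed for a dense float32 array of the given shape."""
    n = 1
    for s in shape:
        n *= s
    return n * 4


GiB = 2 ** 30
B, P, res, chunk = 3, 4, 65536, 4096

full_output = float32_bytes(B, res, res)           # the whole [B, res, res] result
cdist_chunk = float32_bytes(B, chunk * chunk, P)   # cdist output for one chunk
min_chunk = float32_bytes(B, chunk, chunk)         # one chunk of the result

print(f"output:      {full_output / GiB:.1f} GiB")  # 48.0 GiB
print(f"cdist chunk: {cdist_chunk / GiB:.2f} GiB")  # 0.75 GiB
print(f"min chunk:   {min_chunk / GiB:.2f} GiB")    # 0.19 GiB
```

So a disk-backed output is unavoidable at this resolution; what can still be reduced is the per-chunk working memory and the time spent per chunk.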
Here is my code. I process the grid chunk by chunk and write each chunk to disk:

import torch
import numpy as np
import matplotlib.pyplot as plt
import os

def create_points(batch_size, num_points):
  coords = np.random.rand(batch_size, num_points, 2)
  return coords

def min_dist(points, res):

  # Scale [0, 1] coordinates to pixel units
  data_coords = torch.tensor(points, dtype=torch.float32) * (res - 1)

  B, P, _ = data_coords.shape

  # Chunk edge length in pixels
  chunk_size = 4096

  # Number of chunks along each axis (ceiling division)
  num_chunks = (res + chunk_size - 1) // chunk_size

  # Initialize the final tensor as a disk-backed memmap; an in-RAM
  # torch.zeros(B, res, res) would need B * res * res * 4 bytes
  min_dist_tensor_full = np.memmap('large_tensor.dat', dtype='float32', mode='w+', shape=(B, res, res))

  for i in range(num_chunks):
    for j in range(num_chunks):
      x_start = i * chunk_size
      x_end = min((i + 1) * chunk_size, res)
      y_start = j * chunk_size
      y_end = min((j + 1) * chunk_size, res)

      grid_x, grid_y = torch.meshgrid(
          torch.arange(x_start, x_end, dtype=torch.float32),
          torch.arange(y_start, y_end, dtype=torch.float32),
          indexing='xy'
      )

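      # Flatten the [ny, nx] grids into a list of (x, y) pixel coordinates
      grid_x = grid_x.flatten()
      grid_y = grid_y.flatten()
      grid_coords = torch.stack([grid_x, grid_y], dim=-1)

      # Edge chunks are smaller (and non-square) when res is not a multiple
      # of chunk_size; keep chunk_size unchanged so later offsets stay right
      nx = x_end - x_start
      ny = y_end - y_start

      # Share one grid across the batch without copying it
      grid_coords_batch = grid_coords.unsqueeze(0).expand(B, -1, -1)

      # Pairwise distances, shape [B, ny * nx, P]
      dists = torch.cdist(grid_coords_batch, data_coords)

      # Distance from each pixel to its nearest point, shape [B, ny * nx]
      min_dists, _ = dists.min(dim=2)

      # indexing='xy' makes the grids [ny, nx], so rows correspond to y
      min_dist_chunk = min_dists.view(B, ny, nx)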

      min_dist_tensor_full[:, y_start:y_end, x_start:x_end] = min_dist_chunk.numpy()

  # Make sure every chunk has been written to the file on disk
  min_dist_tensor_full.flush()
  return min_dist_tensor_full
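
One possible way to shrink the per-chunk work, sketched here under the assumption that P stays small (4 in the example), is to skip the expanded grid and cdist's [B, ny * nx, P] tensor entirely. Instead, keep a running minimum of squared distances over the points and take one square root per pixel at the end, so only about two [B, ny, nx] arrays are alive per chunk. `min_dist_running_min` is a hypothetical name. It swaps torch for NumPy broadcasting so each chunk can be assigned straight into the memmap; the same idea should carry over to torch via `torch.minimum`. It does not change the size of the output file.

```python
import numpy as np


def min_dist_running_min(points, res, chunk_size=4096, path='large_tensor.dat'):
    """Nearest-point distance field written chunk by chunk into a memmap.

    points: array [B, P, 2] with (x, y) coordinates in [0, 1].
    Returns a float32 memmap of shape [B, res, res], indexed [b, y, x].
    """
    coords = np.asarray(points, dtype=np.float32) * (res - 1)  # pixel units
    B, P, _ = coords.shape
    out = np.memmap(path, dtype=np.float32, mode='w+', shape=(B, res, res))

    for y0 in range(0, res, chunk_size):
        y1 = min(y0 + chunk_size, res)
        ys = np.arange(y0, y1, dtype=np.float32)[None, :, None]      # [1, ny, 1]
        for x0 in range(0, res, chunk_size):
            x1 = min(x0 + chunk_size, res)
            xs = np.arange(x0, x1, dtype=np.float32)[None, None, :]  # [1, 1, nx]
            # Running minimum of squared distances over the P points
            best = np.full((B, y1 - y0, x1 - x0), np.inf, dtype=np.float32)
            for p in range(P):
                px = coords[:, p, 0][:, None, None]                  # [B, 1, 1]
                py = coords[:, p, 1][:, None, None]                  # [B, 1, 1]
                d2 = (xs - px) ** 2 + (ys - py) ** 2                 # broadcasts to [B, ny, nx]
                np.minimum(best, d2, out=best)
            # One sqrt per pixel instead of one per (pixel, point) pair
            out[:, y0:y1, x0:x1] = np.sqrt(best)

    out.flush()  # push all chunks to the file on disk
    return out
```

Because the loop bounds come straight from `range(0, res, chunk_size)`, edge chunks simply come out smaller, and nothing has to be reshaped back into a square.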

Here is how I visualise my result:

points = create_points(3, 4)
dist_data = min_dist(points, 65536)[:]

batch_size = dist_data.shape[0]
plt.figure(figsize=(12, 4 * batch_size))

for i in range(batch_size):
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(dist_data[i], cmap='viridis')
    plt.title(f'Batch {i + 1}')

plt.tight_layout()
plt.show()
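
A 65536 x 65536 image has far more pixels than any screen can show, and passing a full slice of the memmap to `imshow` will likely pull the whole 16 GiB slice into RAM. A possible alternative, assuming a coarse preview is enough, is to take every `step`-th row and column before plotting. Strided slicing of a memmap is a view, so copying it with `np.array` only touches the pages holding the sampled rows, roughly 1/step of the file. `show_downsampled` and `max_pixels` are hypothetical names.

```python
import numpy as np
import matplotlib.pyplot as plt


def show_downsampled(dist_data, max_pixels=1024):
    """Plot each batch element of a [B, res, res] array at reduced resolution."""
    B, H, W = dist_data.shape
    step = max(1, max(H, W) // max_pixels)  # sampling stride along both axes
    fig, axes = plt.subplots(1, B, figsize=(4 * B, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # Copy only the sampled elements into a small in-RAM array
        preview = np.array(dist_data[i, ::step, ::step])
        ax.imshow(preview, cmap='viridis')
        ax.set_title(f'Batch {i + 1}')
    fig.tight_layout()
    return fig
```

With the original call this would be `show_downsampled(min_dist(points, 65536))` followed by `plt.show()`.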