I am trying to build a function that does the following:
"""
Calculates a 3D tensor with the minimum distance from each pixel to data.
Inputs:
* points: a numpy array of 2D coordinates and point values, normalized to
be in the range [0,1]. The expected shape is [B, P, 2].
* res: the resolution of the output tensor.
Returns:
A res x res square tensor with floating point values corresponding to the
Euclidean distance to the closest point in points. The returned tensor
shape is [B, res, res].
"""
However, it needs to handle resolutions of up to 65,536. The method below works, but it is very slow, and the other approaches I have tried run out of RAM. Is there a better way to handle tensors this large, or is my algorithm itself too slow?
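For scale, here is a rough size estimate. It assumes float32 values and the batch of 3 with 4 points and the 4096-pixel chunks used in the code further down; the helper names are made up for illustration.

```python
def tensor_bytes(batch, res, bytes_per_value=4):
    """Size in bytes of a dense [batch, res, res] tensor."""
    return batch * res * res * bytes_per_value


def cdist_bytes(batch, chunk, num_points, bytes_per_value=4):
    """Size in bytes of the [batch, chunk * chunk, num_points] distance matrix for one chunk."""
    return batch * chunk * chunk * num_points * bytes_per_value


GIB = 1024 ** 3
full = tensor_bytes(batch=3, res=65536)
per_chunk = cdist_bytes(batch=3, chunk=4096, num_points=4)
print(f"full output: {full / GIB:.1f} GiB")                               # 48.0 GiB
print(f"cdist intermediate per 4096x4096 chunk: {per_chunk / GIB:.2f} GiB")  # 0.75 GiB
```

So the output alone is far larger than typical RAM, and the per-chunk intermediate grows linearly with the number of points.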
Here is my code. It processes the grid chunk by chunk and writes each chunk to disk:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os

def create_points(batch_size, num_points):
    coords = np.random.rand(batch_size, num_points, 2)
    return coords

def min_dist(points, res):
    # Scale normalized [0, 1] coordinates to pixel coordinates
    data_coords = torch.tensor(points, dtype=torch.float32) * (res - 1)
    B, P, _ = data_coords.shape
    chunk_size = 4096
    # Calculate number of chunks
    num_chunks = (res + chunk_size - 1) // chunk_size
    # min_dist_tensor_full = torch.zeros(B, res, res)
    # Initialize the final tensor on disk so it never has to fit in RAM
    min_dist_tensor_full = np.memmap('large_tensor.dat', dtype='float32', mode='w+', shape=(B, res, res))
    for i in range(num_chunks):
        for j in range(num_chunks):
            x_start = i * chunk_size
            x_end = min((i + 1) * chunk_size, res)
            y_start = j * chunk_size
            y_end = min((j + 1) * chunk_size, res)
            # With indexing='xy' both grids have shape [y_end - y_start, x_end - x_start]
            grid_x, grid_y = torch.meshgrid(
                torch.arange(x_start, x_end, dtype=torch.float32),
                torch.arange(y_start, y_end, dtype=torch.float32),
                indexing='xy'
            )
            grid_x = grid_x.flatten()
            grid_y = grid_y.flatten()
            grid_coords = torch.stack([grid_x, grid_y], dim=-1)
            # Keep the chunk extent in separate names: reassigning chunk_size here
            # would shift the start offsets whenever res is not a multiple of it
            chunk_w = x_end - x_start
            chunk_h = y_end - y_start
            grid_coords_batch = grid_coords.unsqueeze(0).expand(B, -1, -1)
            dists = torch.cdist(grid_coords_batch, data_coords)
            min_dists, _ = dists.min(dim=2)
            min_dist_chunk = min_dists.view(B, chunk_h, chunk_w)
            min_dist_tensor_full[:, y_start:y_end, x_start:x_end] = min_dist_chunk.numpy()
    return min_dist_tensor_full

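To check the chunked version on small inputs, a brute-force reference could look like the sketch below. `min_dist_reference` is a hypothetical helper, not part of my code above. It assumes the same conventions (points scaled by `res - 1`, column 0 is x, output indexed as `[b, y, x]`, distances in pixel units) and is only usable for small `res`, because it materialises a `[B, P, res, res]` array.

```python
import numpy as np


def min_dist_reference(points, res):
    """Brute-force [B, res, res] minimum distance in pixel units (small res only)."""
    pts = np.asarray(points, dtype=np.float32) * (res - 1)  # [B, P, 2] as (x, y)
    axis = np.arange(res, dtype=np.float32)
    # Per-axis offsets from every point to every column (dx) and every row (dy)
    dx = axis[None, None, :] - pts[:, :, 0, None]  # [B, P, res]
    dy = axis[None, None, :] - pts[:, :, 1, None]  # [B, P, res]
    # Broadcast to [B, P, res (y), res (x)], then reduce over the points axis
    d2 = dy[:, :, :, None] ** 2 + dx[:, :, None, :] ** 2
    return np.sqrt(d2.min(axis=1))


# One batch with two points at opposite corners of a 5 x 5 grid
points = np.array([[[0.0, 0.0], [1.0, 1.0]]])
ref = min_dist_reference(points, res=5)
print(ref.shape)     # (1, 5, 5)
print(ref[0, 0, 0])  # 0.0, this pixel coincides with the first point
```

Comparing this against `min_dist(points, res)` with `np.allclose` for a few small resolutions should catch indexing mistakes such as swapped x and y axes.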
Here is how I visualise my result:
points = create_points(3, 4)
dist_data = min_dist(points, 65536)[:]
batch_size = dist_data.shape[0]
plt.figure(figsize=(12, 4 * batch_size))
for i in range(batch_size):
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(dist_data[i], cmap='viridis')
    plt.title(f'Batch {i + 1}')
plt.tight_layout()
plt.show()
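Passing a full 65536 x 65536 slice to `imshow` is itself heavy, so for plotting a strided subsample may be enough. The sketch below is an assumption on my part rather than something I have benchmarked: `downsample_for_plot` and `max_side` are hypothetical names, and a small temporary memmap stands in for `large_tensor.dat`.

```python
import os
import tempfile

import numpy as np


def downsample_for_plot(dist_data, max_side=1024):
    """Return an in-memory strided copy of a [B, H, W] array, each side at most max_side."""
    longest = max(dist_data.shape[-2:])
    step = max(1, -(-longest // max_side))  # ceiling division
    # np.array copies only the selected elements out of the memmap
    return np.array(dist_data[:, ::step, ::step])


# Small stand-in for the large on-disk tensor
path = os.path.join(tempfile.mkdtemp(), "demo_tensor.dat")
demo = np.memmap(path, dtype="float32", mode="w+", shape=(2, 1024, 1024))
small = downsample_for_plot(demo, max_side=256)
print(small.shape)  # (2, 256, 256)
```

The result can be passed to `plt.imshow(small[i], cmap='viridis')` in place of `dist_data[i]`.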