Hi everyone. I’m trying to debug an issue with ONNX export and I’m looking for ideas on how to investigate it further.
I’m currently on PyTorch 2.0.
I’ve included a minimal runnable example below in case you’re interested in testing it on your own setup.
Context: I’m trying to export a network containing two nn.Module instances, both of which will be replaced by ONNX functions.
These functions are used to abstract away data-dependent computational graphs, so that each can later be replaced by
a custom op on the target hardware. This ensures that each nn.Module
receives fixed-size inputs and produces fixed-size outputs.
Problem: The problems start when I connect the output of FixedShapeUnique
to the input of CreateVoxelGrid.
FixedShapeUnique
reports having four output tensors instead of two and CreateVoxelGrid
reports six input tensors
instead of four.
If I don’t pass CreateVoxelGrid
as an argument into export_modules_as_functions
, FixedShapeUnique
still reports
having four output tensors instead of two.
If I don’t use export_modules_as_functions
, I also don’t see these bogus nodes appearing.
Any ideas on what is happening or where to look next?
Code:
from typing import Optional
import torch
from torch import nn
class CreateVoxelGrid(nn.Module):
    """Scatter per-voxel feature rows into a zero-initialised dense grid.

    The grid shape is fixed at construction time, so the module always returns
    a tensor of the same shape regardless of how many rows the masks select.
    """

    def __init__(self, shape: tuple[int, int, int, int]) -> None:
        """Remember the shape of the grid allocated on every forward call.

        Args:
            shape: Shape passed to ``new_zeros`` when building the output grid.
        """
        super().__init__()
        self.grid_shape = shape

    def forward(
        self,
        voxel_features: torch.Tensor,
        indices: torch.Tensor,
        voxel_features_mask: Optional[torch.Tensor] = None,
        indices_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Write ``voxel_features`` into the grid at the positions in ``indices``.

        Args:
            voxel_features: Feature rows to scatter; the grid inherits its dtype
                and device from this tensor.
            indices: 2-D integer tensor; columns 0, 1 and 2 address the first
                three grid dimensions. Any further columns are ignored.
            voxel_features_mask: Optional mask used to index ``voxel_features``
                (presumably boolean, one entry per row).
            indices_mask: Optional mask used to index ``indices``
                (presumably boolean, one entry per row).

        Returns:
            A tensor of shape ``self.grid_shape`` that is zero everywhere except
            at the addressed positions.
        """
        grid = voxel_features.new_zeros(self.grid_shape)
        if voxel_features_mask is not None:
            voxel_features = voxel_features[voxel_features_mask]
        if indices_mask is not None:
            indices = indices[indices_mask]
        # The selected feature rows and the selected index rows must line up
        # (same count, or broadcastable) for this indexed assignment to succeed.
        grid[indices[:, 0], indices[:, 1], indices[:, 2]] = voxel_features
        return grid
class FixedShapeUnique(nn.Module):
    """Row-wise ``torch.unique`` whose outputs keep the shape of the inputs.

    The number of unique rows depends on the data, so the unique rows are
    packed at the front of a zero-filled tensor shaped like the input and a
    companion mask marks which leading rows are real results.
    """

    def forward(
        self,
        tensor: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the unique rows of ``tensor[mask]`` padded to a fixed shape.

        Args:
            tensor: Input whose rows (dimension 0) are de-duplicated.
            mask: Optional mask selecting the rows to consider. When omitted,
                an all-True boolean mask over dimension 0 is used.

        Returns:
            A pair ``(output, valid)``. ``output`` has the shape of ``tensor``
            with the unique rows first and zeros after them. ``valid`` has the
            shape and dtype of ``mask`` and is True for exactly the leading
            entries that hold a unique row.
        """
        if mask is None:
            mask = torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device)

        # Fresh buffers: neither the input tensor nor the mask is modified.
        output = torch.zeros_like(tensor)
        valid = torch.zeros_like(mask)

        unique_tensor = torch.unique(tensor[mask], dim=0)
        # The count of unique rows is data-dependent; it never exceeds the
        # number of input rows, so both slices stay in bounds.
        output[: unique_tensor.shape[0]] = unique_tensor
        valid[: unique_tensor.shape[0]] = True
        return output, valid
class Network(nn.Module):
    """Chain ``FixedShapeUnique`` into ``CreateVoxelGrid``.

    The indices are de-duplicated first; the resulting padded indices and
    validity mask are then used to scatter the voxel features into the grid.
    """

    def __init__(self, grid_shape: tuple[int, int, int, int]) -> None:
        """Build both sub-modules.

        Args:
            grid_shape: Shape of the dense grid produced by ``CreateVoxelGrid``.
        """
        super().__init__()
        self.unique = FixedShapeUnique()
        self.voxel_grid = CreateVoxelGrid(grid_shape)

    def forward(
        self,
        voxel_features: torch.Tensor,
        indices: torch.Tensor,
        voxel_features_mask: Optional[torch.Tensor] = None,
        indices_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """De-duplicate ``indices`` and scatter ``voxel_features`` into the grid.

        Args:
            voxel_features: Feature rows forwarded unchanged to the voxel grid.
            indices: Grid coordinates, de-duplicated row-wise before use.
            voxel_features_mask: Optional mask over ``voxel_features`` rows.
            indices_mask: Optional mask over ``indices`` rows; replaced by the
                validity mask returned from ``FixedShapeUnique``.

        Returns:
            The dense grid produced by ``CreateVoxelGrid``.
        """
        # NOTE: after this call ``indices_mask`` marks one entry per *unique*
        # index row, so the rows picked by ``voxel_features_mask`` must still
        # match that count for the scatter in ``CreateVoxelGrid`` to succeed.
        indices, indices_mask = self.unique(indices, mask=indices_mask)  # <- the million dollar question
        return self.voxel_grid(
            voxel_features,
            indices,
            voxel_features_mask=voxel_features_mask,
            indices_mask=indices_mask,
        )
def main(path: str = "/tmp/playground.onnx") -> None:
    """Build the example network, run it once eagerly, and export it to ONNX.

    Args:
        path: Destination of the exported ONNX file. Defaults to the location
            used by the original example.
    """
    torch.manual_seed(24)

    channels = 8
    n_occupied_voxels = 20
    voxel_features = torch.randn(n_occupied_voxels, channels)

    batch_size = 1
    grid_shape = (batch_size, 256, 256, channels)
    # One random coordinate column per grid dimension, each bounded by its size.
    indices = torch.stack(
        [torch.randint(size, size=(n_occupied_voxels,)) for size in grid_shape],
        dim=1,
    )

    voxel_features_mask = torch.rand(voxel_features.shape[0]) > 0.5
    # just creating a new mask with the same number of True elements
    indices_mask = torch.flipud(voxel_features_mask)

    model = Network(grid_shape)
    # Eager sanity run before exporting.
    model(voxel_features, indices, voxel_features_mask=voxel_features_mask, indices_mask=indices_mask)

    torch.onnx.export(
        model=model.eval(),
        # A trailing dict in ``args`` is passed to the model as keyword arguments.
        args=(voxel_features, indices, {"voxel_features_mask": voxel_features_mask, "indices_mask": indices_mask}),
        f=path,
        opset_version=15,
        input_names=["voxel_features", "indices", "voxel_features_mask", "indices_mask"],
        export_modules_as_functions={FixedShapeUnique, CreateVoxelGrid},
    )
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sérgio