Model for 3D voxel reconstruction from n 2D projections

Hello.
I am still a beginner with PyTorch.
I created a script that generates 3D primitive objects in a voxel representation.
Each voxel is either transparent (outside the 3D object) or has a constant grayscale value and no transparency (inside the 3D object).
The primitives are spheres, oblate spheroids, hexahedra (stretched cubes), and random tetrahedra.
For each primitive, a voxel grid is generated and passed to a 3D renderer as a data file (POV-Ray is used here).
The 3D scene adds a light source at a fixed position.
The Python script generates a set of 4 camera positions (top, bottom, left, right), and for each voxel grid the cameras are rotated by the same amount, so their positions remain constant relative to each other.
The renderer then generates 4 PNG images.
These images are fed to the model. The voxel grid generated previously is used as the target of the model.
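To make the camera setup concrete, here is a minimal NumPy sketch of rotating a rig of four cameras by one shared rotation. The camera coordinates, the distance of 30 units and the x-then-y-then-z rotation order are assumptions for illustration; the actual POV-Ray scene may use different positions and axis conventions.

```python
import numpy as np

def rotation_matrix(rx_deg, ry_deg, rz_deg):
    """Rotation about x, then y, then z (angles in degrees)."""
    rx, ry, rz = np.radians([rx_deg, ry_deg, rz_deg])
    cx, sx = np.cos(rx), np.sin(rx)
    cy, sy = np.cos(ry), np.sin(ry)
    cz, sz = np.cos(rz), np.sin(rz)
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return Rz @ Ry @ Rx

# Hypothetical rig: four cameras at distance 30 from the origin.
cameras = np.array([[0, 30, 0], [0, -30, 0], [-30, 0, 0], [30, 0, 0]], dtype=float)

rng = np.random.default_rng(0)
angles = rng.integers(0, 360, size=3)   # one random rotation per rendered object
R = rotation_matrix(*angles)
rotated = cameras @ R.T                 # the same matrix is applied to every camera

def pairwise_distances(points):
    """Distance between every pair of points, as an (n, n) matrix."""
    diff = points[:, None, :] - points[None, :, :]
    return np.linalg.norm(diff, axis=-1)
```

Because every camera is multiplied by the same orthogonal matrix, `pairwise_distances(cameras)` and `pairwise_distances(rotated)` agree, which is the "constant relative to each other" property described above.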

The model used for now is a basic one: two linear transforms with a sigmoid activation in between.
For training, the loss criterion is MSELoss().
I have only 120 samples for now, so training will probably be mediocre. The loss starts at 0.6 and converges to 0.4.

Any ideas for a suitable model for this kind of task?
I can post the whole script if needed.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Input to hidden layer: 4 PNG images flattened, 4*640*480*3 values (4 images of 640*480, 24-bit color)
        self.hidden = nn.Linear(4 * 640 * 480 * 3, 16)
        # Output is a voxel grid of 16*16*16*2 values (16 voxels per axis, with grayscale and transparency)
        self.output = nn.Linear(16, 16 * 16 * 16 * 2)

    def forward(self, x):
        # No need to set requires_grad on the input: the layer parameters already track gradients
        x = self.hidden(x)
        x = torch.sigmoid(x)  # F.sigmoid is deprecated in favor of torch.sigmoid
        x = self.output(x)
        return x

I think I should experiment with nn.Conv3d

(How can this be backpropagated? Out of curiosity, I would be interested to see how you compute the loss, if you don't mind sharing the code.)
Another question out of curiosity: is the simplicity of the primitives (simple shapes) the reason why you chose such a small hidden size of 16?
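On the backpropagation question: as the first post describes it, the renderer only produces the input images ahead of time, and the loss compares the predicted voxel grid with the generated one, so nothing has to be differentiated through POV-Ray. The sketch below shows that with a scaled-down copy of the posted layout. `TinyNet` and its 8x8 input size are hypothetical stand-ins, and the random tensors replace the real renders and voxels. It also computes the parameter count of the full-size layers, for scale.

```python
import torch
import torch.nn as nn

# Parameter count of the posted layers (weights + biases); plain arithmetic, no allocation.
in_features = 4 * 640 * 480 * 3             # 3,686,400 input values
hidden, out_features = 16, 16 * 16 * 16 * 2
n_params = (in_features * hidden + hidden) + (hidden * out_features + out_features)

class TinyNet(nn.Module):
    """Same layout as the posted Net, with a hypothetical, much smaller input."""
    def __init__(self, n_in=4 * 8 * 8 * 3, n_hidden=16, n_out=16 * 16 * 16 * 2):
        super().__init__()
        self.hidden = nn.Linear(n_in, n_hidden)
        self.output = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        return self.output(torch.sigmoid(self.hidden(x)))

torch.manual_seed(0)
net = TinyNet()
images = torch.rand(2, 4 * 8 * 8 * 3)       # stand-in for flattened renders
target = torch.rand(2, 16 * 16 * 16 * 2)    # stand-in for ground-truth voxels

loss = nn.MSELoss()(net(images), target)    # compared directly on the voxel grid
loss.backward()                             # gradients reach both linear layers
```

With the full 640x480 inputs, almost all of the roughly 59 million parameters sit in the first linear layer, which is one practical reason a wider hidden layer gets expensive quickly.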

Maybe a good idea would be to use some Conv2d layers to extract the image features, flatten the features, and then use some Conv3d layers to create the voxel image.
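A minimal sketch of that idea, under several assumptions: the class name `MultiViewVoxelNet`, the channel widths and the latent size are made up for illustration, the views are averaged after encoding (which ignores which camera produced which image), and the decoder uses `nn.ConvTranspose3d` for upsampling rather than plain `nn.Conv3d`.

```python
import torch
import torch.nn as nn

class MultiViewVoxelNet(nn.Module):
    """Shared 2D encoder per view, then a 3D decoder producing a 16^3 grid."""
    def __init__(self, latent=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(16, 32, 5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),          # fixed 4x4 map whatever the image size
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, latent), nn.ReLU(),
        )
        self.to_volume = nn.Linear(latent, 64 * 2 * 2 * 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 32, 4, stride=2, padding=1), nn.ReLU(),  # 2 -> 4
            nn.ConvTranspose3d(32, 16, 4, stride=2, padding=1), nn.ReLU(),  # 4 -> 8
            nn.ConvTranspose3d(16, 1, 4, stride=2, padding=1),              # 8 -> 16
        )

    def forward(self, views):
        # views: (batch, n_views, 3, height, width)
        n, v, c, h, w = views.shape
        feats = self.encoder(views.reshape(n * v, c, h, w))   # (n*v, latent)
        feats = feats.reshape(n, v, -1).mean(dim=1)            # average over views
        vol = self.to_volume(feats).reshape(n, 64, 2, 2, 2)
        return self.decoder(vol).squeeze(1)                     # (n, 16, 16, 16)

torch.manual_seed(0)
model = MultiViewVoxelNet()
out = model(torch.rand(2, 4, 3, 64, 64))    # small random stand-in for 4 renders
```

If the camera order carries information, concatenating the per-view features instead of averaging them would keep it, at the cost of a larger linear layer.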

Your post reminds me of this paper: Unsupervised Learning of 3D Representations From Natural Images.
I don't know if it fits your work exactly, but maybe you will find good ideas in it. They use a mix of 3D and 2D convolutions.

I will dig into it. For now the model does not work at all, but the code scaffolding is done: generating different voxel shapes, rendering the left, behind, top and bottom viewports with POV-Ray, and supplying the voxel as target data. I just need to tidy the code a little, and I will post it here.

I will check the various examples of the GAN technique.
But first I need to get the convolution layer right, and that means creating a tensor of shape (640 * 480 * 3, 4) for the 4 images, instead of flattening all images into (640 * 480 * 3 * 4).
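One way to keep the views separate, sketched with random arrays in place of the decoded PNGs: stack them along a new view axis, move the color channel first as `nn.Conv2d` expects, and fold the view axis into the batch axis while the 2D convolution runs. The layer sizes here are placeholders, not a recommendation.

```python
import numpy as np
import torch
import torch.nn as nn

height, width, n_views = 480, 640, 4

# Stand-ins for four decoded PNGs, each (H, W, 3) uint8 as image readers usually return.
rng = np.random.default_rng(0)
renders = [(rng.random((height, width, 3)) * 255).astype(np.uint8)
           for _ in range(n_views)]

views = np.stack(renders, axis=0)                                     # (4, H, W, 3)
views = torch.from_numpy(views).permute(0, 3, 1, 2).float() / 255.0   # (4, 3, H, W)
sample = views.unsqueeze(0)                                           # (1, 4, 3, H, W)

conv = nn.Conv2d(3, 8, kernel_size=7, stride=2, padding=3)
n, v, c, h, w = sample.shape
features = conv(sample.reshape(n * v, c, h, w))           # views folded into the batch
features = features.reshape(n, v, *features.shape[1:])    # back to (1, 4, 8, H/2, W/2)
```

The same convolution weights are then applied to every view, and the view axis is still available afterwards for whatever combines the projections.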

I think the fastest approach would be to process each projection with Conv2d to extract the edges of the shape. These would give the vertices of the shape.
Then reconstruct the 3D shape by creating faces linking the vertices found in each projection.
That may work for polyhedra, but I wonder how it will do for spheroids.
And I have no idea how to tackle this with nn.
Maybe this step could be done easily with a simple algebraic algorithm.
I will probe the Stack Overflow and math forums.
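One candidate for such an algebraic step is silhouette intersection (often called a visual hull or space carving). This is a different technique from the vertex and face linking described above, and the sketch assumes three orthographic, axis-aligned silhouettes, which is simpler than rendered views from rotated cameras. A voxel is kept only if it lies inside all three silhouettes.

```python
import numpy as np

def silhouettes(voxel):
    """Orthographic silhouettes of a boolean (X, Y, Z) grid along each axis."""
    return voxel.any(axis=0), voxel.any(axis=1), voxel.any(axis=2)

def visual_hull(sil_x, sil_y, sil_z):
    """Keep a voxel only if it falls inside all three silhouettes."""
    return sil_x[None, :, :] & sil_y[:, None, :] & sil_z[:, :, None]

res = 16
u, v, w = np.indices((res, res, res)) - res / 2   # coordinates centered on the grid

box = (np.abs(u) < 5) & (np.abs(v) < 3) & (np.abs(w) < 6)
sphere = u**2 + v**2 + w**2 < 6**2

box_hull = visual_hull(*silhouettes(box))
sphere_hull = visual_hull(*silhouettes(sphere))
```

An axis-aligned box is recovered exactly, while the sphere comes back as the intersection of three cylinders, so it contains the sphere plus some extra voxels near the diagonals. That matches the concern about spheroids: three silhouettes alone cannot carve a round shape precisely.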

Full code with Conv2d. I checked the resulting voxels: the grayscale is somewhat biased; it should be 0 for a voxel outside the shape.
Any ideas welcome.
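As a quick check of the image sizes in the script below, the output-size formula quoted in its Net class, ((W-F+2*P)/S)+1, can be applied to the layers in the order forward() calls them. The helper name `conv_out` is made up for this sketch, and the division is taken as floor division.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) in call order: hidden1 once, hidden2 three times, hidden3 once.
layers = [(7, 2, 3), (7, 2, 3), (7, 2, 3), (7, 2, 3), (3, 2, 4)]

sizes = [400]   # the renders are 400x400
for kernel, stride, padding in layers:
    sizes.append(conv_out(sizes[-1], kernel, stride, padding))
```

`sizes` comes out as 400, 200, 100, 50, 25, 16, so each 400x400 render is reduced to the 16x16 map that the voxel resolution requires.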

#Workflow detail:

#0) Generate a batch of 50 voxels per category (category = spheroid, hexahedron (stretched cube), random tetrahedron).
    #This will be the target (check) data
    #Voxel shape is 16,16,16,2 (resolution 16 per coordinate, and 2 entries for transparency and grayscale); transparency info is necessary for POV-Ray to render correctly
    #Algorithm could be made easier by just not rendering blocks with grayscale=0, now that I think of it.
    #Target data is a slice of the voxel with only grayscale.
    #Result of nn is a voxel of 16,16,16,1 (contains only grayscale)
#1) Write voxel to dat file for each pov scene
#2) Render each voxel with the pov scene with a random camera rotation for each scene (same rotation per camera; 3 cameras: top, left, front)
    #This will be the input feature: 400,400,3 (size 400, 3 channels, RGB)
    #use clock from 0 to 3 inclusive to render 3 png per voxel (for each view)
#3) Use conv2d to detect the edges of the voxel and reduce it to a 16*16 image
#4) Pass it to ReLU
#5) Forward project each image (torch repeat)
#6) Obtain X,Y,Z projections by transposing the matrix
#7) Sum the matrices (a high value means a solid block, a low value is outside the shape)
#8) Pass it to the tanh activation function
#9) Compute loss with criterion (MSELoss). output = voxel generated by network, target = voxel generated above at #0)
#10) Training uses 7 epochs with a batch of 3 voxels sampled at random from the 150 voxels (using a custom sampler)

import imageio
import glob
import numpy as np
import os

import os.path

from shutil import copyfile
import math as m

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(threshold=10000)



from scipy.spatial import ConvexHull
from scipy.spatial import Delaunay

from matplotlib.path import Path

voxelres = 16
voxelspercat = 50
voxelscat = 3
np.random.seed(42)
image_height = 400
image_width = 400
proj_num = 3
channels_rgb = 3

def voxel_convexhull(u1,v1,w1,points,voxelres):

    hull = Delaunay(points)
    #print(hull.points)
    #print(hull.simplices)
    #print(points[hull.simplices])
    #print(hull.find_simplex([u1,v1,w1]))
  
    #hull_path = Path(points[hull.vertices])
    #hull_path = Path(hull.points)
    #print(hull.find_simplex([u1,v1,w1]))
    #quit()
    if hull.find_simplex([u1,v1,w1]) != -1:
        return [0,0.5]
    else:
        return [1,0]


def loadvoxel(inputvoxel_file):
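    # `long` does not exist in Python 3; the voxels hold values such as 0.5 and are
    # written as comma-separated text by writevoxel, so read them back as floats
    voxel = np.fromfile(inputvoxel_file,dtype=np.float64,count=-1,sep=',')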
    return voxel

def writevoxel(nparray,outputvoxel_file):
    nparray.tofile(outputvoxel_file,sep=',\n')

def writecampos(nparray,outputcampos_file):
    nparray.tofile(outputcampos_file,sep=',\n')

def voxeldiff(invoxel1,invoxel2):
    outvoxel = np.absolute(invoxel1 -invoxel2)
    return outvoxel

totalin = 0
totalout = 0

def voxel_spheroid(u1,v1,w1,r,a,c,voxelres):
    uc1=  u1- voxelres/2
    vc1=  v1- voxelres/2
    wc1=  w1- voxelres/2
    if (((uc1*uc1)/(a*a) + (vc1*vc1)/(a*a) + (wc1*wc1)/(c*c)) < r*r):
        global totalin
        totalin += 1
        return [0,0.5]
    else:
        global totalout
        totalout += 1
        return [1,0]

def voxel_cube(u1,v1,w1,u2,v2,w2,voxelres):
    if (abs(u2) > abs(u1-voxelres/2)) and (abs(v2) > abs(v1-voxelres/2)) and (abs(w2) > abs(w1-voxelres/2)):
        return [0,0.5]
    else:
        return [1,0]




rootpovfile = "C:\\WPy64-3740\\voxel\\pov_voxel"
rootpovpath = "C:\\WPy64-3740\\voxel\\"

def povrender(povfileindex,rootfile):

    global rootpovpath
    global image_height
    global image_width
    global voxelspercat
    global voxelscat
    global proj_num

    files = glob.glob(rootpovpath + '*.png')
    if files:
        filesnumber = sum(1 for _ in files)
        if filesnumber == voxelspercat * voxelscat * proj_num:
            print("projections already generated")
            return
    
    destpovfile = rootfile + '_' + str(povfileindex) + '_.pov'
    copyfile(rootfile +'.pov',destpovfile)
    argsroot = 'clock='
    writecampos(np.random.randint(0,360,size=(3,1)),rootpovfile + '_rot.dat')
    #for view in range(0,4):
        #args = argsroot + str(view) + '.0'
    os.system('"C:\\Program Files\\POV-Ray\\v3.7\\bin\\pvengine64.exe" +W' + str(image_width)  + ' +H' + str(image_height) + ' /EXIT /RENDER ' + destpovfile)
    print('"C:\\Program Files\\POV-Ray\\v3.7\\bin\\pvengine64.exe" +W' + str(image_width)  + ' +H' + str(image_height) + ' /EXIT /RENDER ' + destpovfile)
   
    
def genvoxel(voxelfunc,voxelres,p1,p2,p3,filepath,fileindex):

    if os.path.isfile(filepath + '_out_' + str(fileindex) + '.npy'):

        voxel = np.load(filepath + '_out_' + str(fileindex) + '.npy')

    else:
        
        voxel = np.zeros([voxelres,voxelres,voxelres,2])
        # range() already advances u, v and w, so no manual increments are needed
        for u in range(0,voxelres):
            for v in range(0,voxelres):
                for w in range(0,voxelres):
                    [transp,gray] = voxelfunc(u,v,w,p1,p2,p3,voxelres)
                    voxel[u,v,w,0] = transp
                    voxel[u,v,w,1] = gray
        #print(voxel)
        #print(np.shape(voxel))
        writevoxel(voxel,filepath + '.dat')
        np.save(filepath + '_out_' + str(fileindex) + '.npy',voxel)
        povrender(fileindex,filepath)
        print('fileindex:' + str(fileindex))
        #print("in" + str(totalin))
        #print("out" + str(totalout))
        #quit()

    return voxel

def genvoxelconvex(voxelres,filepath,fileindex):

    if os.path.isfile(filepath + '_out_' + str(fileindex) + '.npy'):

        voxel = np.load(filepath + '_out_' + str(fileindex) + '.npy')

    else:

        sizev = 4
        # randint needs integer bounds, so use floor division
        points = np.random.randint(voxelres//2 -sizev,voxelres//2 + sizev,size=(16,3))
        #print(points)
        #quit()
        voxel = np.zeros([voxelres,voxelres,voxelres,2])
        # range() already advances u, v and w, so no manual increments are needed
        for u in range(0,voxelres):
            for v in range(0,voxelres):
                for w in range(0,voxelres):
                    [transp,gray] = voxel_convexhull(u,v,w,points,voxelres)
                    voxel[u,v,w,0] = transp
                    voxel[u,v,w,1] = gray
        #quit()
        writevoxel(voxel,filepath + '.dat')
        np.save(filepath + '_out_' + str(fileindex) + '.npy',voxel)
        povrender(fileindex,filepath)
        print('fileindex:' + str(fileindex))

    return voxel


###NOT USED###
def flatten_concat_png(path):

    global proj_num
    global voxelspercat
    global voxelscat
    
    idx = 0
    idx_voxel = 0
    total_num_img = voxelscat*voxelspercat
    im_all = np.zeros([total_num_img,640*480*3*proj_num])
    files = glob.glob(path + '\\*.png')
    files.sort(key=os.path.getmtime)
    for im_path in files:
         print(im_path)
         im = imageio.imread(im_path)
         print(im.shape)
         if idx == 0:
             im_four = im.flatten()
             idx += 1
             print("init " + str(idx))
         elif idx < proj_num - 1:
             im_flat = im.flatten()
             im_four = np.concatenate((im_four,im_flat),axis=0)
             idx +=1
             print("inf4 " + str(idx))
         else:
             im_flat = im.flatten()
             im_four = np.concatenate((im_four,im_flat),axis=0)
             im_all[idx_voxel,:] = im_four
             idx_voxel += 1
             idx = 0
             print(idx_voxel)
             print(idx)
             
             
    return im_all
### END NOT USED ### 

def generate_png_tensor(path):

    global proj_num
    global voxelspercat
    global voxelscat
    global image_width
    global image_height
    global channels_rgb

    idx = 0
    idx_voxel = 0
    total_num_img = voxelscat*voxelspercat
    # Images are stored channel-first as (channels, height, width) after the transpose below
    im_all = np.zeros([total_num_img,channels_rgb,image_height,image_width,proj_num])
    im_four = np.zeros([channels_rgb,image_height,image_width,proj_num])
    files = glob.glob(path + '\\*.png')
    files.sort(key=os.path.getmtime)
    for im_path in files:
         #print(im_path)
         im = imageio.imread(im_path)
         im = np.transpose(im, (2,0,1))
         #print(im.shape)
         #quit()
         if idx == 0:
             #print("im_four")
             #print(im_four[:,0].shape)
             #print(im_four[:,1])
             im_four[:,:,:,idx] = im
             idx += 1
             #print("init " + str(idx))
         elif idx < proj_num - 1:
             im_four[:,:,:,idx] = im
             idx +=1
             #print("inf4 " + str(idx))
         else:
             im_four[:,:,:,idx] = im
             im_all[idx_voxel,:] = im_four
             idx_voxel += 1
             idx = 0
             #print(np.shape(im_four))
             #print(im_four)
             #print(idx_voxel)
             #print(idx)
             
             
    return im_all


"""

"""  

def gen_train_data(voxelsnum,filepath):

    allvoxels = np.zeros([voxelsnum*3,voxelres,voxelres,voxelres,2])
    voxellist = []
    voxelindex = 0
    
    for t in range(0,voxelsnum):

        voxel = genvoxelconvex(voxelres,filepath,voxelindex)
        voxellist.append(voxel[:,:,:,1])
        voxelindex += 1


    for t in range(0,voxelsnum):

        phase = m.pi*t/voxelsnum
        r = m.cos(phase+m.pi/3)*voxelres/2
        a = m.cos(phase+2*m.pi/3)*4
        c = m.cos(phase+m.pi)*4
        voxel = genvoxel(voxel_spheroid,voxelres,r,a,c,filepath,voxelindex)
        voxellist.append(voxel[:,:,:,1])
        voxelindex += 1

    for t in range(0,voxelsnum):

        phase = m.pi*t/voxelsnum
        u = m.sin(phase+m.pi/3)*voxelres/2
        v = m.sin(phase+2*m.pi/3)*voxelres/2
        w = m.sin(phase+m.pi)*voxelres/2

        voxel = genvoxel(voxel_cube,voxelres,u,v,w,filepath,voxelindex)
        voxellist.append(voxel[:,:,:,1])
        voxelindex += 1
        

    allvoxels = np.stack(voxellist,axis=0)
    #allimages = flatten_concat_png(rootpovpath)
    allimages = generate_png_tensor(rootpovpath)
    print ("voxelindex:" + str(voxelindex))
    return [allimages,allvoxels]
    
    #print(allvoxels[0,:,:,:,:])


[ims,vox] = gen_train_data(voxelspercat,rootpovfile)
ims = torch.from_numpy(ims)
vox = torch.from_numpy(vox)
#ims = ims.half()
#vox = vox.half()

#allimages = flatten_concat_png(rootpovpath)

#print(np.shape(ims))

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler


def subtrain(tloader,netsub,filepath):

    #training routine, give trainloader and network

    criterion = torch.nn.MSELoss()
    # Optimizers require the parameters to optimize and a learning rate
    optimizer = torch.optim.SGD(netsub.parameters(), lr=0.003)
    epochs = 7


    print("will train")
    idx = 0             
    for e in range(epochs):
        running_loss = 0
        for image, label, in tloader:
           
            #quit()
            #transfo2 = transfo.view(1,transfo.shape[0]*transfo.shape[1]*transfo.shape[2])
            #converting batch of 10 transformation to 1 row with 160 features (concatenation of 10 transformations)
            optimizer.zero_grad()
            output = netsub(image.float())
            loss = criterion(output,label.float())
            #print(labels[0])
            #sleep(0.5)
            loss.backward()
            optimizer.step()
            #print("out shape")
            #print(output.size())
            output2 = torch.squeeze(output)
            #print(output2.size())
            print(loss.item())
            running_loss += loss.item()
            outputvoxel = output2.detach().numpy()
            transpvoxel = np.ones(np.shape(outputvoxel))
            transpvoxel = transpvoxel - outputvoxel
            # Keep the same channel order as the generated voxels: index 0 = transparency, index 1 = grayscale
            outputvoxel = np.stack((transpvoxel,outputvoxel),axis=3)
            #print(np.shape(outputvoxel))
            
            writevoxel(outputvoxel,filepath + '_out_' + str(idx) + '.dat')
            idx += 1

            
        # Report the mean loss once all batches of the epoch are done
        print(f"Training loss: {running_loss/len(tloader)}")
            

class OrderedListSampler(Sampler):
    r"""Samples elements in the order specified by the list

    Arguments:
        indices : a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


def train(td,cd,nettotrain,rand,filepath):
    #print("train data size:" + str(td.size()))
    #print("target data size:" + str(cd.size()))
    train_dataset = torch.utils.data.TensorDataset(td,cd)

    indices1 = np.arange(0,50)# indices for random convex shapes (generated first in gen_train_data)
    indices2 = np.arange(50,100) # indices for spheroids
    indices3 = np.arange(100,150) # indices for hexahedra

    permut = 0
    batch_size = 1
    firstload = 1
    #random permutations between stretches,rotations and translations    
    for batch in np.arange(0,150):
        for permut in np.random.permutation(3):
            # Draw a plain int; the upper bound of randint is exclusive, hence the +1
            batch_start = np.random.randint(50*permut,50*(permut+1) -batch_size + 1)
            indices = np.arange(batch_start, batch_start + batch_size)
            if firstload ==1:
                indicesall = indices
                firstload = 0
            else:
                indicesall = np.concatenate((indicesall,indices),axis=0)
    #print(np.shape(indicesall))
    print(indicesall)
    #print(permut)
    #quit()

    if rand:
        sampler = OrderedListSampler(indicesall)           
        trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, sampler=sampler)
        subtrain(trainloader,nettotrain,filepath)



class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Output size per axis is ((W-F+2*P)/S)+1, so 400 -> 200 -> 100 -> 50 -> 25 -> 16
        # when hidden1 is applied once, hidden2 three times and hidden3 once.
        self.hidden1 = nn.Conv2d(3,1,7,stride=2, padding=3)
        self.hidden2 = nn.Conv2d(1,1,7,stride=2, padding=3)
        self.hidden3 = nn.Conv2d(1,1,3,stride=2, padding=4)

        # Activations
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        
    def encode(self, view):
        # Shared 2D convolutions: 400x400 render -> 16x16 map
        out = self.hidden1(view)
        out = self.hidden2(out)
        out = self.hidden2(out)
        out = self.hidden2(out)
        out = self.hidden3(out)
        return self.relu(out)

    def forward(self, x):
        # x has shape (batch, channels, height, width, view); the layer parameters
        # already track gradients, so the input does not need requires_grad
        xa = self.encode(x[:,:,:,:,0])
        xb = self.encode(x[:,:,:,:,1])
        xc = self.encode(x[:,:,:,:,2])

        # Forward project each 16x16 map through the volume
        xa = xa.repeat(1,16,1,1)
        xb = xb.repeat(1,16,1,1)
        xc = xc.repeat(1,16,1,1)

        # Give each projection its own extrusion axis. Which axis belongs to
        # which camera depends on the render setup.
        xb = torch.transpose(xb, 1, 2)
        # transpose(2, 3) only flips the image and keeps axis 1 as the extrusion
        # axis, so move the repeated axis to the end instead
        xc = xc.permute(0, 2, 3, 1)

        vx = xa + xb + xc

        x = self.tanh(vx)

        return x

net1 = Net()
train(ims,vox,net1,1,rootpovfile)

The model does not work so well.
I combined the forward projections (using torch repeat) along orthogonal axes by transposing the projection matrices (torch transpose) and then adding them. I think the best approach would be a logical AND operation, or a better activation function.
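A logical AND can be approximated in a differentiable way by multiplying the extruded maps instead of adding them, provided the maps are in [0, 1]. The sketch below compares the two on a synthetic box. The function names are made up, and the silhouettes are taken straight from a known grid rather than from the convolution output, so this only illustrates the combination step.

```python
import torch

def backproject_sum(a, b, c):
    """Extrude each (N, 16, 16) map along its own axis and add them."""
    return a[:, None, :, :] + b[:, :, None, :] + c[:, :, :, None]

def backproject_and(a, b, c):
    """Soft logical AND: product of the three extruded maps."""
    return a[:, None, :, :] * b[:, :, None, :] * c[:, :, :, None]

res = 16
grid = torch.zeros(1, res, res, res)
grid[0, 3:9, 5:12, 2:14] = 1.0      # a solid box as ground truth

a = grid.amax(dim=1)                # silhouette seen along axis 1
b = grid.amax(dim=2)                # silhouette seen along axis 2
c = grid.amax(dim=3)                # silhouette seen along axis 3

summed = torch.tanh(backproject_sum(a, b, c))
product = backproject_and(a, b, c)
```

With the sum, every voxel lying in the shadow of even one silhouette ends up positive after tanh, whereas the product is non-zero only where all three silhouettes agree, which for this box reproduces the grid exactly.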
There is a problem with Conv2d too. It detects edges but does not fill the inside of the shape outlined by the edges (with a positive grayscale value), so the resulting projection is not valid.

If anyone has an idea on how to detect shapes and assign a positive grayscale value inside of the enclosed face created from the edges, that would be nice.
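Outside the network, one option is `scipy.ndimage.binary_fill_holes`, which fills background regions that are fully enclosed by foreground pixels. The sketch assumes the edge map has already been thresholded to a closed boolean outline; it is a preprocessing step, not a differentiable layer, and a broken outline would be left unfilled.

```python
import numpy as np
from scipy import ndimage

# Hypothetical edge map: a hollow rectangle outline on a 16x16 image.
edges = np.zeros((16, 16), dtype=bool)
edges[3, 4:12] = edges[10, 4:12] = True     # top and bottom edges
edges[3:11, 4] = edges[3:11, 11] = True     # left and right edges

filled = ndimage.binary_fill_holes(edges)   # outline plus its enclosed interior
silhouette = filled.astype(np.float32) * 0.5   # 0.5 inside, 0 outside, as in the voxel data
```

Since the renders come from a known scene, thresholding each image against the background color may give the filled silhouette more directly than detecting edges first.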
On the other hand, a vertices + faces reconstruction after Conv2d could be more precise, but I still haven't figured out how to do it.

Hey, are you still working on this project?

Did you ever get a good prediction of the 3D voxel files from just the 2D projections? The shapes are simple, so I think it should work. I am curious about the loss computation, though. The backpropagation isn't impossible, but it does require some thought.

I noticed you used MSELoss. Did you ever try any others?
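For reference, one alternative that is common for occupancy grids is binary cross-entropy on an inside/outside label, optionally weighted because inside voxels are usually the minority. This is a sketch under assumptions, not something taken from the thread: the targets are binarised from the 0 / 0.5 grayscale values, the random `logits` stand in for raw network outputs (so a final tanh would be dropped), and `soft_iou` is a hypothetical helper for monitoring overlap.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 16, 16, 16, requires_grad=True)   # stand-in for raw network output
target_gray = torch.zeros(3, 16, 16, 16)
target_gray[:, 4:12, 4:12, 4:12] = 0.5                     # 0.5 inside, 0 outside

occupancy = (target_gray > 0).float()                      # binarise: 1 inside, 0 outside
n_inside = occupancy.sum()
pos_weight = (occupancy.numel() - n_inside) / n_inside     # up-weight the rarer inside voxels

bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = bce(logits, occupancy)
loss.backward()

def soft_iou(logits, occupancy, eps=1e-6):
    """Differentiable intersection-over-union between prediction and target."""
    probs = torch.sigmoid(logits)
    inter = (probs * occupancy).sum()
    union = (probs + occupancy - probs * occupancy).sum()
    return (inter + eps) / (union + eps)
```

Whether this trains better than MSELoss on this dataset would have to be checked experimentally.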