Segmentation of Organs at Risk Using a 3D U-Net

I am a CSE student doing my minor project in PyTorch: segmentation of organs at risk (OARs) from CT scans using a 3D U-Net.

I am very new to PyTorch and to deep learning in general. To learn PyTorch, I have read the first 8 chapters of the book "Deep Learning with PyTorch". I have created a demo Jupyter notebook for my project, but it has some errors that cause the GPU to run out of memory.
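
For context on the out-of-memory errors: in a 3D network, activation memory grows with the number of voxels, so a common remedy is to train on random patches instead of whole CT volumes. The sketch below is a rough illustration. It assumes volumes and labels are NumPy arrays shaped (D, H, W), and the sizes and the helpers `feature_map_mib` and `random_crop_3d` are hypothetical, not taken from the notebook. It counts only one float32 activation tensor. A forward/backward pass keeps many such maps (every encoder and decoder level, plus gradients), so real usage is a multiple of this figure.

```python
import numpy as np


def feature_map_mib(batch, channels, depth, height, width, bytes_per_elem=4):
    """Size in MiB of one float32 activation tensor shaped (N, C, D, H, W)."""
    return batch * channels * depth * height * width * bytes_per_elem / 2**20


# Hypothetical sizes: a whole 128x512x512 CT volume vs. a 96^3 training patch.
full = feature_map_mib(batch=2, channels=32, depth=128, height=512, width=512)
patch = feature_map_mib(batch=2, channels=32, depth=96, height=96, width=96)
print(f"full volume: {full:.0f} MiB, patch: {patch:.0f} MiB per feature map")
# prints "full volume: 8192 MiB, patch: 216 MiB per feature map"


def random_crop_3d(volume, label, size=(96, 96, 96), rng=None):
    """Crop the same random (D, H, W) window from a volume and its label map.

    Assumes every dimension of `volume` is at least as large as `size`.
    """
    if rng is None:
        rng = np.random.default_rng()
    # integers(low, high) excludes high, so each start lies in [0, dim - s].
    starts = [rng.integers(0, dim - s + 1) for dim, s in zip(volume.shape, size)]
    window = tuple(slice(st, st + s) for st, s in zip(starts, size))
    return volume[window], label[window]
```

Using the same `window` for both arrays keeps each image patch aligned with its label patch, which is the property a patch dataset's `__getitem__` needs.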

Project context in brief:

  1. Loads the data from Google Drive.
  2. Creates file paths for the volumes and labels.
  3. Creates two custom Dataset classes: one for whole 3D volumes and one for patches extracted from them.
  4. Implements a 3D U-Net from scratch.
  5. Defines some metric and loss functions (I am unsure which losses and metrics suit 3D segmentation).
  6. Runs the training loop.
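
For comparison with a from-scratch implementation like the one in step 4, here is a minimal 3D U-Net sketch in PyTorch. It is not the notebook's code, and the design choices are assumptions: BatchNorm3d, ConvTranspose3d upsampling, `base=16` and `depth=3`. With small 3D batch sizes (1 or 2), InstanceNorm3d or GroupNorm is often used instead of BatchNorm. Because every convolution uses `padding=1`, the output has the same spatial size as the input, provided D, H and W are divisible by `2**depth`. The model returns raw logits, not probabilities.

```python
import torch
import torch.nn as nn


class DoubleConv3d(nn.Module):
    """Two (Conv3d -> BatchNorm3d -> ReLU) layers, the basic U-Net building block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet3D(nn.Module):
    """Small 3D U-Net: encoder, bottleneck, decoder with skip connections."""

    def __init__(self, in_ch=1, num_classes=2, base=16, depth=3):
        super().__init__()
        chs = [base * 2**i for i in range(depth + 1)]  # e.g. 16, 32, 64, 128
        self.encoders = nn.ModuleList()
        prev = in_ch
        for c in chs[:-1]:
            self.encoders.append(DoubleConv3d(prev, c))
            prev = c
        self.pool = nn.MaxPool3d(2)  # halves D, H and W
        self.bottleneck = DoubleConv3d(chs[-2], chs[-1])
        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for c in reversed(chs[:-1]):
            # Upsample 2c -> c channels, then concatenate the skip (c) -> 2c input.
            self.ups.append(nn.ConvTranspose3d(c * 2, c, kernel_size=2, stride=2))
            self.decoders.append(DoubleConv3d(c * 2, c))
        self.head = nn.Conv3d(chs[0], num_classes, kernel_size=1)

    def forward(self, x):
        skips = []
        for enc in self.encoders:
            x = enc(x)
            skips.append(x)  # saved for the matching decoder level
            x = self.pool(x)
        x = self.bottleneck(x)
        for up, dec, skip in zip(self.ups, self.decoders, reversed(skips)):
            x = up(x)
            x = dec(torch.cat([x, skip], dim=1))
        return self.head(x)  # raw logits, shape (N, num_classes, D, H, W)
```

A quick shape check is a cheap way to test one's own implementation: `UNet3D(in_ch=1, num_classes=3)(torch.randn(1, 1, 32, 32, 32)).shape` should be `(1, 3, 32, 32, 32)`.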

I will add Markdown cells describing these steps to the notebook for better readability.

I have a few questions:

  1. Is my overall approach good enough?
  2. Is my 3D U-Net implementation correct?
  3. Which losses and metrics should I use for 3D segmentation?
  4. Is my training loop correct?
  5. My GPU runs out of memory. Can you suggest code changes to fix this?
  6. How else can I improve my code?
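
One common combination for questions 3 to 5, sketched under assumptions: a soft Dice loss added to cross-entropy for training, hard per-class Dice as a validation metric, and a training loop that saves GPU memory with mixed precision (`torch.autocast` plus `GradScaler`) and gradient accumulation. It assumes the loader yields `(images, labels)` with images shaped (N, 1, D, H, W) and integer labels shaped (N, D, H, W). The names `soft_dice_loss`, `dice_per_class` and `train_one_epoch` are hypothetical. Whether to include the background class in the Dice average is a judgment call; this version includes it.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss averaged over classes (background included).

    logits: (N, C, D, H, W) raw network outputs
    target: (N, D, H, W) integer labels in [0, C)
    """
    num_classes = logits.shape[1]
    probs = logits.float().softmax(dim=1)
    one_hot = F.one_hot(target.long(), num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)  # sum over batch and spatial axes, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    denom = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


@torch.no_grad()
def dice_per_class(logits, target, eps=1e-6):
    """Hard Dice score per class from argmax predictions (a validation metric)."""
    num_classes = logits.shape[1]
    pred = F.one_hot(logits.argmax(dim=1), num_classes).permute(0, 4, 1, 2, 3).float()
    true = F.one_hot(target.long(), num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)
    inter = (pred * true).sum(dims)
    return (2 * inter + eps) / (pred.sum(dims) + true.sum(dims) + eps)


def train_one_epoch(model, loader, optimizer, device, accum_steps=1):
    """One epoch of cross-entropy + Dice training with AMP and gradient accumulation."""
    model.train()
    use_amp = device.type == "cuda"  # mixed precision only on GPU
    # Newer PyTorch versions prefer torch.amp.GradScaler("cuda", ...).
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    optimizer.zero_grad(set_to_none=True)
    total, n_batches = 0.0, len(loader)
    for step, (images, labels) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).long()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = F.cross_entropy(logits, labels) + soft_dice_loss(logits, labels)
        # Dividing by accum_steps approximates the gradient of one larger batch.
        scaler.scale(loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0 or step + 1 == n_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        total += loss.item()  # a Python float, so no autograd graph is kept alive
    return total / max(n_batches, 1)
```

Note the `loss.item()` call: accumulating the loss tensor itself (`total += loss`) keeps every iteration's autograd graph alive and steadily consumes GPU memory. Validation should likewise run under `torch.no_grad()` with `model.eval()`.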

Link to my Jupyter notebook here.