For now, I only have the problem statement. I am trying to write a script for it, but I am still looking for the right torch APIs for an efficient implementation. Let me try to explain.
Input: X, a 4D tensor
“X” is a collection of health records for 10 patients. The record for each patient is an n x m tensor, where “n” is the number of visits the patient made to the hospital and “m” is the maximum number of diagnosis tests performed per visit. So X looks like below:
X[batch_size][number_of_visits][max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10][embedding_dimension]
Assuming below params:
batch_size = 16
no_of_visits = 10
max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10 = 20 (If a visit has fewer than 20 diagnosis codes, its tensor is padded with zeros; the mask for this padding is described below.)
Embedding dimension = 100 (total records)
Mask: masks, a 3D tensor
masks[batch_size][number_of_visits][max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10]
Sample values of X and Masks:
X = [ [ [10, 8, 21], [18, 5] ],
[ [1, 3, 62, 17], [1], [3, 1] ] ]
bs = 2 (two patients)
no_of_visits = 3 (the maximum; patient 1 has only 2 visits)
max_diagnosis_performed_in_a_visit = 4 (i.e., len([1, 3, 62, 17]))
Padded "X" looks like below:
[ [ [10, 8, 21, *0*], [18, 5, *0*, *0*], [*0*, *0*, *0*, *0*] ],
[ [1, 3, 62, 17], [1, *0*, *0*, *0*], [3, 1, *0*, *0*] ] ]
So the size of padded X = [2][3][4]
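To make the padding step concrete, here is a minimal sketch that builds the padded tensor from the ragged sample above. It assumes 0 is the padding value; the names `records` and `padded` are only for illustration.

```python
import torch

# Ragged sample from the post: patients -> visits -> diagnosis codes.
records = [
    [[10, 8, 21], [18, 5]],
    [[1, 3, 62, 17], [1], [3, 1]],
]

max_visits = max(len(patient) for patient in records)
max_codes = max(len(visit) for patient in records for visit in patient)

# Start from all zeros (assumed padding value) and copy the real codes in.
padded = torch.zeros(len(records), max_visits, max_codes, dtype=torch.long)
for p, patient in enumerate(records):
    for v, visit in enumerate(patient):
        padded[p, v, : len(visit)] = torch.tensor(visit)

print(padded.shape)  # torch.Size([2, 3, 4])
```

The Python loops only run once while preparing the data, so they should not matter much for the later tensor operations.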
Masks tensor for above X will be:
[ [ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ],
[ [1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0] ] ]
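If the masks are not already available, one possible way to derive them from the padded codes is a simple comparison. This sketch assumes 0 is reserved for padding and never appears as a real diagnosis code; `padded` is an illustrative name.

```python
import torch

padded = torch.tensor([
    [[10, 8, 21, 0], [18, 5, 0, 0], [0, 0, 0, 0]],
    [[1, 3, 62, 17], [1, 0, 0, 0], [3, 1, 0, 0]],
])

# Assumption: 0 is only ever padding, so any non-zero entry is a real code.
masks = (padded != 0).long()
print(masks.tolist())
```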
Now the question is:
For the padded "X" and "Masks" given above, I need to do the following:
1. Extract the last "true" entry (that is, the index of the last non-padded value) for each visit of each patient:
Ex: for patient 1 it will be [[3, 2]], per the mask of patient 1 "[ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ]":
the last true entry per visit is 1st row 3rd col, 2nd row 2nd col, and there is no true entry in the 3rd row.
2. I need to use the indices extracted above, for all patients in the batch (the example above is small, but I may have a batch size of 16 or 32), to construct a tensor from "X" using torch.gather().
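The two steps above could be sketched roughly as follows. This is one possible approach, not necessarily the most efficient one. It assumes that padding always sits at the end of each visit (so the mask sum equals the number of real codes) and that X is a 4D float tensor of shape [batch, visits, codes, embedding]. The function name `gather_last_code` is hypothetical. The indices it returns are 0-based, whereas the example in step 1 counts columns from 1, and fully padded visits are zeroed out.

```python
import torch

def gather_last_code(x, masks):
    """x: [B, V, C, E] float tensor; masks: [B, V, C] with 1 = real, 0 = padding.

    Returns the embedding of the last real code per visit ([B, V, E])
    and the 0-based index of that code ([B, V]).
    """
    lengths = masks.long().sum(dim=-1)        # [B, V] number of real codes
    has_code = lengths > 0                    # False for fully padded visits
    last_idx = (lengths - 1).clamp(min=0)     # clamp so gather gets a valid index
    emb_dim = x.size(-1)
    # gather needs an index with the same number of dims as x: [B, V, 1, E]
    index = last_idx[:, :, None, None].expand(-1, -1, 1, emb_dim)
    last = x.gather(2, index).squeeze(2)      # [B, V, E]
    # Zero out visits that contain no real code at all.
    last = last * has_code.unsqueeze(-1).to(last.dtype)
    return last, last_idx

masks = torch.tensor([
    [[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
    [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]],
])
torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)  # random stand-in for the embedded X, embedding dim 5

last, last_idx = gather_last_code(x, masks)
print(last_idx.tolist())  # [[2, 1, 0], [3, 0, 1]]
print(last.shape)         # torch.Size([2, 3, 5])
```

Everything here is batched, so the same call should work unchanged for a batch size of 16 or 32.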
Please let me know if you need more details.