Selecting indices based on tensor values and then using them with torch.gather

I have two tensors
X: torch.Size([16, 10, 100])
Mask: torch.Size([16, 10, 20])

16 → Batch_size
10 → total IDs
100 → Total events for each ID
20 → Length of each event

Requirement: how do I create a 1D indices tensor that stores, for each row of the mask tensor, the last index at which the mask value is 1? For example: Mask[0][0] = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] → the last index at which the value is 1 is 3. So I need a 1D array of indices like [[3,…]]

Once I have the 1D indices, I need to select the values in “X” at those indices using torch.gather.
I'd appreciate any pointers on how to do this efficiently.
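For a single mask row, one possible reading of the requirement is sketched below. It multiplies the mask by its own positions and takes the argmax, so it does not rely on the 1s being contiguous. The names `row`, `positions` and `last` are illustrative only, and an all-zero row would return 0, which is indistinguishable from a row whose only 1 is at index 0.

```python
import torch

row = torch.tensor([1, 1, 1, 1] + [0] * 16)   # Mask[0][0] from the example
positions = torch.arange(row.shape[-1])       # 0, 1, ..., 19
last = (row * positions).argmax(dim=-1)       # largest position that holds a 1

print(last.item())  # 3
```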

Hi Pradeep!

Could you post a fully-self-contained, runnable script that performs an
example of the calculation you want to do, using python for loops and
if blocks, as necessary, together with its output?

Best.

K. Frank

For now I only have the problem statement. I am still building the script for it and am looking for the right torch APIs for an efficient implementation. Let me try to explain…

Input: X, a 4D tensor

“X” is a collection of health records for 10 patients. The record for each patient is an n x m tensor (each patient had “n” visits to the hospital, and “m” is the maximum number of diagnosis tests performed per visi…So, X looks like below

X[batch_size][number_of_visits][max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10][embedding_dimension]

Assuming below params:
batch_size = 16
no_of_visits = 10
max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10 = 20 (If a visit has fewer than 20 diagnosis codes, its tensor is padded with zeros; the mask for this is described below).
Embedding dimension = 100 (total records)

Mask: masks, a 3D tensor

masks[batch_size][number_of_visits][max.no.of.diagnosis.performed.per.visit_for_a_patient_in_the_group_of_10]

Sample values of X and Masks:
X = [ [ [10, 8, 21], [18, 5] ],
[ [1, 3, 62, 17], [1], [3, 1] ] ]
bs = 1
no_of_visits = 2
max_diagnosis_performed_in_a_visit = 4 (i.e, len([1, 3, 62, 17]))


Padded "X" looks like below:
[ [ [10, 8, 21, *0*], [18, 5, *0*, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [1, 3, 62, 17], [1, *0*, *0*, *0*], [3, 1, *0*, *0*] ] ]

So size of padded X = [2][3][4]

Masks tensor for above X will be:
[ [ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0] ] ]
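The padding and mask shown above can be reproduced with a short sketch like the following. It assumes the ragged sample is held as nested Python lists; the names `records`, `max_visits` and `max_codes` are illustrative and not from the original post.

```python
import torch

# Ragged sample from the post: patients -> visits -> diagnosis codes.
records = [
    [[10, 8, 21], [18, 5]],
    [[1, 3, 62, 17], [1], [3, 1]],
]

max_visits = max(len(patient) for patient in records)
max_codes = max(len(visit) for patient in records for visit in patient)

# Start from all zeros, then copy each real visit into place.
x = torch.zeros(len(records), max_visits, max_codes, dtype=torch.long)
mask = torch.zeros_like(x)
for p, patient in enumerate(records):
    for v, visit in enumerate(patient):
        x[p, v, : len(visit)] = torch.tensor(visit)
        mask[p, v, : len(visit)] = 1

print(x.shape)  # torch.Size([2, 3, 4])
```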

Now the question is:
For the padded "X" and "Masks" given above, I need to do the following:

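1. Extract the last true position (that is, the position of the last non-padded value) in each visit of each patient:

Ex: for patient 1 it will be [[3, 2]], as per that patient's mask "[ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ]":
the last true value is at the 3rd column in the 1st row and the 2nd column in the 2nd row, and there is no true value in the 3rd row. (These are 1-based positions; the corresponding 0-based indices are [[2, 1]].)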

2. The indices extracted above need to be computed for all patients in the batch (the example above only uses batch size = 1, but I might have 16/32 as bs) and then used to construct a tensor from "X" using torch.gather().
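A plain-loop reference for step 1, in the spirit of the for-loop script requested earlier in the thread, might look like this. It is only a sketch: `last_true_index` is a hypothetical helper, it returns 0-based indices, and it uses -1 (an assumed convention) for visits that are entirely padding.

```python
import torch

mask = torch.tensor([
    [[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
    [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]],
])

def last_true_index(mask):
    """Per visit, the 0-based index of the last 1, or -1 if the row is all padding."""
    out = torch.full(mask.shape[:-1], -1, dtype=torch.long)
    for p in range(mask.shape[0]):          # patients
        for v in range(mask.shape[1]):      # visits
            for d in range(mask.shape[2]):  # diagnosis slots
                if mask[p, v, d] == 1:
                    out[p, v] = d           # keeps the latest slot holding a 1
    return out

last_idx = last_true_index(mask)
print(last_idx.tolist())  # [[2, 1, -1], [3, 0, 1]]
```

A loop like this is slow on large tensors, but it is useful as a ground truth when checking a vectorized version.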

Please let me know, if need more details.

If the 1s in your mask are contiguous (there is no 0 between two 1s) and every vector starts with 1, that is, only masks like [1, 1, 0, 0, 0] are supported and not [0, 1, 1, 0, 0] or [1, 0, 1, 0, 0], you can try the following code.

indices = torch.sum(mask, dim=-1, keepdim=True) - 1
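As a rough illustration, here is that one-liner applied to the padded sample from earlier in the thread and followed by torch.gather. The handling of all-padding visits (an index of -1, clamped to 0 before the gather and then zeroed out) is an added assumption, and `empty`, `safe` and `last_vals` are illustrative names.

```python
import torch

x = torch.tensor([
    [[10, 8, 21, 0], [18, 5, 0, 0], [0, 0, 0, 0]],
    [[1, 3, 62, 17], [1, 0, 0, 0], [3, 1, 0, 0]],
])
mask = torch.tensor([
    [[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
    [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]],
])

indices = torch.sum(mask, dim=-1, keepdim=True) - 1   # shape [2, 3, 1]
empty = indices < 0                   # visits that are entirely padding
safe = indices.clamp(min=0)           # gather needs non-negative indices
last_vals = torch.gather(x, dim=-1, index=safe)       # shape [2, 3, 1]
last_vals = last_vals.masked_fill(empty, 0).squeeze(-1)

print(indices.squeeze(-1).tolist())   # [[2, 1, -1], [3, 0, 1]]
print(last_vals.tolist())             # [[21, 5, 0], [17, 1, 1]]
```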

Thank you for the code snippet, Eta_C. I tried it, but the output size is not exactly what I am looking for: the snippet gives indices of size torch.Size([16, 10, 1]), whereas I am looking for indices of size [1, 16].

i.e.

  1. Input masks[batch_size][visits][total diagnosis]
  2. Flatten it to mask[batch_size][visits * diagnosis]
  3. Sum each row of mask[batch_size] (i.e., add up all the (visits * diagnosis) values for each batch element)
  4. This gives a final matrix/tensor of size [1, 16]
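The four steps above could be sketched as follows. The mask here is randomly generated but contiguous, purely for illustration, and `lengths`, `flat` and `per_batch` are assumed names. Note that the result is the count of valid entries per batch element rather than a last index per visit.

```python
import torch

batch_size, visits, diagnoses = 16, 10, 20
torch.manual_seed(0)

# Hypothetical contiguous mask: each visit has a random number of leading 1s.
lengths = torch.randint(0, diagnoses + 1, (batch_size, visits, 1))
masks = (torch.arange(diagnoses) < lengths).long()   # step 1: [16, 10, 20]

flat = masks.view(batch_size, visits * diagnoses)    # step 2: [16, 200]
per_batch = flat.sum(dim=1)                          # step 3: [16]
per_batch = per_batch.unsqueeze(0)                   # step 4: [1, 16]

print(per_batch.shape)  # torch.Size([1, 16])
```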

Note: yes, the masks are contiguous.

I was able to implement what I was looking for.
I can't share the implementation here for now due to restrictions, but I used the APIs below:

torch.sum(x, dim=2)
torch.view
torch.expand
torch.gather
torch.squeeze
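Since the actual implementation was not shared, the following is only a guess at how those five calls might fit together for the 4D "X" described earlier in the thread. It picks the embedding of the last valid diagnosis in each visit. The random data, the assumption that every visit has at least one diagnosis, and the names `lengths`, `last`, `index` and `picked` are all illustrative.

```python
import torch

torch.manual_seed(0)
B, V, D, E = 16, 10, 20, 100   # batch, visits, diagnoses per visit, embedding size

x = torch.randn(B, V, D, E)
lengths = torch.randint(1, D + 1, (B, V))    # assumption: at least one code per visit
masks = (torch.arange(D) < lengths.unsqueeze(-1)).long()   # [B, V, D]

last = torch.sum(masks, dim=2) - 1                    # [B, V], last valid slot
index = last.view(B, V, 1, 1).expand(B, V, 1, E)      # [B, V, 1, E]
picked = torch.gather(x, 2, index)                    # [B, V, 1, E]
picked = torch.squeeze(picked, 2)                     # [B, V, E]

print(picked.shape)  # torch.Size([16, 10, 100])
```

The expand step matters because torch.gather needs one index per output element, so the per-visit index is broadcast across the embedding dimension.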