Greetings, I want to extract the nonzero elements from a 2-dimensional tensor.
Example input:
# tensor([[ 0, 4535, 0, ..., 9446, 5344, 0],
# [ 0, 1037, 0, ..., 2391, 33243, 40207]])
Desired result:
# tensor([[4535, ..., 9446, 5344],
# [1037, ..., 2391, 33243]])
I thought it was a simple problem, but I encountered an error because of a shape mismatch.
I tried the following, but the output still contains zeros. Does anyone have a good idea?
result[:, torch.cat((result[0].nonzero()[:, 0], result[1].nonzero()[:, 0])).unique()]
#tensor([[ 4535, 34602, 31080, 2097, 30280, 30764, 0, 30471, 5728, 32816, #Unexpected 0
# 2012, 0, 6726, 30444, 0, 34984, 32871, 32039, 0, 32528, ...], #Unexpected 0
# [ 1037, 35914, 9275, 7511, 9624, 9925, 7217, 9039, 1386, 362,
# 2022, 2719, 0, 6996, 31594, 34346, 33777, 34213, 37084, 32728,...]]) #Unexpected 0
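The zeros probably survive because the attempt takes the union of the nonzero column indices from both rows, so a column is kept whenever either row is nonzero there. A minimal sketch of two alternatives follows, assuming the goal is to keep only the columns where every row is nonzero (which the desired result suggests, since 40207 is dropped). The small tensor `x` and the names `keep`, `filtered`, and `per_row` are made up for illustration.

```python
import torch

# Small stand-in for the real tensor (values are illustrative only).
x = torch.tensor([[0, 4535, 0, 9446, 5344, 0],
                  [0, 1037, 0, 2391, 33243, 40207]])

# Option 1: keep a column only if ALL rows are nonzero there
# (intersection of the nonzero columns, not the union).
keep = (x != 0).all(dim=0)   # boolean mask over columns
filtered = x[:, keep]        # still a rectangular 2-D tensor
print(filtered)

# Option 2: if each row should keep all of its own nonzero values,
# the rows can have different lengths, so the result has to be
# a list of 1-D tensors rather than a single 2-D tensor.
per_row = [row[row != 0] for row in x]
print(per_row)
```

Option 1 always returns a rectangular tensor. Option 2 does not, which would explain the shape mismatch when trying to stack the per-row results.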
Thank you for reading this question.