takiyu
1
My question is about how to perform row-wise masking in a batched manner.
I wrote the non-batch case below.
```python
import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5
storage = torch.rand((n_dim0, n_dim1))
mask = (threshold < storage)
print('storage:', storage.shape)  # [4, 100]
print('mask:   ', mask.shape)     # [4, 100]

mask_lens = torch.sum(mask, dim=1)         # [4]
max_mask_lens = int(torch.max(mask_lens))  # scalar
masked = torch.zeros((n_dim0, max_mask_lens))
masked[0, 0:mask_lens[0]] = storage[0, mask[0]]  #
masked[1, 0:mask_lens[1]] = storage[1, mask[1]]  # How to make these lines
masked[2, 0:mask_lens[2]] = storage[2, mask[2]]  # a batch operation?
masked[3, 0:mask_lens[3]] = storage[3, mask[3]]  #
print('masked: ', masked.shape)  # [4, X] (X is around 50)
```
I hope someone has a good idea.
Thank you.
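For reference, the four per-row assignments in the question can be written more compactly with `torch.nn.utils.rnn.pad_sequence` (a sketch; this still loops over rows at the Python level, so it is not a true batch operation, but it handles any number of rows and pads with zeros):

```python
import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5
storage = torch.rand((n_dim0, n_dim1))
mask = threshold < storage

# Gather each row's kept values, then right-pad shorter rows with zeros.
rows = [storage[i, mask[i]] for i in range(n_dim0)]
masked = torch.nn.utils.rnn.pad_sequence(rows, batch_first=True)
print('masked: ', masked.shape)  # [4, X]
```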
takiyu
2
One tricky idea came to mind. However, it is neither elegant nor efficient.
```python
import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5
storage = torch.rand((n_dim0, n_dim1))
mask = (threshold < storage)
print('storage:', storage.shape)  # [4, 100]
print('mask:   ', mask.shape)     # [4, 100]

mask_lens = torch.sum(mask, dim=1)         # [4]
max_mask_lens = int(torch.max(mask_lens))  # scalar

# Create mask indices
mask_idxs = torch.nonzero(mask)  # [X, 2] (2 == storage.ndim)

# Append padding to `mask_idxs` (each pad entry points at column 0 of its row)
pad_sizes = max_mask_lens - mask_lens
pads = []
for row_idx, pad_size in enumerate(pad_sizes.tolist()):
    pad = torch.zeros((pad_size, storage.ndim), dtype=torch.long)
    pad[:, 0] = row_idx
    pads.append(pad)
mask_idxs = torch.cat([mask_idxs] + pads, dim=0)

# Sort `mask_idxs` by row index
# (note: argsort is not guaranteed stable, so the order within a row may change)
mask_idxs_argsort = torch.argsort(mask_idxs[:, 0])
mask_idxs = mask_idxs[mask_idxs_argsort]

# Masking
masked = storage[mask_idxs[:, 0], mask_idxs[:, 1]]

# Restore original dimension
masked = masked.reshape(n_dim0, -1)
print('masked: ', masked.shape)  # [4, X] (X is around 50)
```
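A fully vectorized alternative to the padding trick above is a stable per-row sort that moves the kept entries to the front of each row (a sketch, assuming PyTorch ≥ 1.9 for `torch.sort(..., stable=True)`; padded positions are filled with zeros rather than duplicated values):

```python
import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5
storage = torch.rand((n_dim0, n_dim1))
mask = threshold < storage

mask_lens = torch.sum(mask, dim=1)
max_len = int(torch.max(mask_lens))

# Stable sort each row by key (0 = kept, 1 = dropped), so kept entries
# come first within each row and keep their original order.
keys = (~mask).long()
_, order = torch.sort(keys, dim=1, stable=True)
sorted_storage = torch.gather(storage, 1, order)
sorted_mask = torch.gather(mask, 1, order)

# Zero out the dropped entries and truncate to the longest row.
masked = torch.where(sorted_mask, sorted_storage,
                     torch.zeros_like(sorted_storage))[:, :max_len]
print('masked: ', masked.shape)  # [4, X]
```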