I am trying to perform some indexing operations on the features used in an attention calculation. Below is a simplified test: it maps the initial feature feat to [feat, feat] (so every point appears twice), runs a normal attention computation on the doubled features, and finally inverts the mapping back to the original length. Because multiple points map to the same position in the inverse step, I average their contributions. However, this causes NaN during backpropagation, even though I checked the feature values in every forward pass and found no NaN. How should I correctly perform these indexing operations?
def forward(self, point):
    ...
    if int(offset2bincount(point.offset).min()) <= self.patch_sizes[self.stage_idx]:
        ...
    else:
        # Map every index twice: L0 -> [L0, L0] (here L0 == feat).
        a = torch.arange(0, len(L0)).to(L0.device)
        b = torch.arange(0, len(L0)).to(L0.device)
        L0_mapped = torch.cat([a, b]).to(L0.device)
        L0_sorted = L0[L0_mapped]  # [feat, feat]

        # Cumulative sequence lengths for the doubled layout.
        cu_seqlens = torch.arange(1, first_cu_seqlens.shape[0] * 2) * self.patch_sizes[self.stage_idx]
        cu_seqlens = torch.cat([torch.tensor([0]), cu_seqlens]).to(dtype=torch.int32).cuda()

        H = self.num_heads
        K = self.patch_sizes[self.stage_idx]
        C = self.channels

        # Standard varlen flash attention in fp16, then cast back to the q dtype.
        q = self.q(L0_sorted)
        kv = self.kv(L0_sorted)
        feat = flash_attn.flash_attn_varlen_kvpacked_func(
            q.half().reshape(-1, H, C // H),
            kv.half().reshape(-1, 2, H, C // H),
            cu_seqlens,
            cu_seqlens,
            max_seqlen_q=self.patch_sizes[self.stage_idx],
            max_seqlen_k=K,
            dropout_p=self.attn_drop if self.training else 0,
            softmax_scale=self.scale,
        ).reshape(-1, C)
        feat = feat.to(q.dtype)

    if int(offset2bincount(point.offset).min()) > self.patch_sizes[self.stage_idx]:
        # Inverse mapping: [feat, feat] -> feat, averaging the duplicated points.
        target_size = len(L0)
        feat_inversed = torch.zeros(target_size, feat.size(1), dtype=feat.dtype, device=feat.device)
        L0_mapped_expanded = L0_mapped.unsqueeze(1).expand(-1, self.channels)
        feat_inversed.scatter_add_(0, L0_mapped_expanded, feat)

        counts = torch.zeros(target_size, feat.size(1), dtype=feat.dtype, device=feat.device)
        counts.scatter_add_(0, L0_mapped_expanded, torch.ones_like(feat))
        feat = feat_inversed / counts  # average the duplicated contributions
        feat = feat[first_inverse]

    # ffn
    ...
    if torch.isnan(feat).any() or torch.isinf(feat).any():
        raise ValueError("feat contains NaN or inf.")
    return point
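
For reference, the duplicate-and-average step can be exercised in isolation like this (no flash_attn involved; N and C are arbitrary). As far as I can tell, this path alone should backpropagate cleanly, which makes me suspect the problem lies elsewhere:

import torch

N, C = 7, 16
feat = torch.randn(N, C, requires_grad=True)

idx = torch.arange(N)
mapped = torch.cat([idx, idx])          # every point appears twice
doubled = feat[mapped]                  # [feat, feat], shape (2N, C)

expanded = mapped.unsqueeze(1).expand(-1, C)
summed = torch.zeros(N, C, dtype=doubled.dtype)
summed.scatter_add_(0, expanded, doubled)   # sum the duplicates back per slot

counts = torch.zeros(N, C, dtype=doubled.dtype)
counts.scatter_add_(0, expanded, torch.ones_like(doubled))
avg = summed / counts                       # counts == 2 everywhere here

avg.sum().backward()
print(torch.isnan(feat.grad).any())         # expected: tensor(False)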
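
One way to narrow down where the NaN first appears during backward is anomaly detection plus gradient hooks on the intermediate tensors; `watch` below is just a helper name I am using for illustration:

import torch

# Report the forward-pass traceback of the op whose backward produces NaN.
torch.autograd.set_detect_anomaly(True)

def watch(name, t):
    # Print whether the gradient flowing into tensor `t` contains NaN.
    if t.requires_grad:
        t.register_hook(
            lambda g, name=name: print(name, "grad has NaN:", torch.isnan(g).any().item())
        )
    return t

# Inside forward(), after each suspect step, e.g.:
#   q = watch("q", q)
#   kv = watch("kv", kv)
#   feat = watch("attn_out", feat)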
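
For the inverse step itself, a slightly more defensive variant clamps the counts so a target slot that receives no contribution yields 0 instead of 0/0 = NaN. That cannot happen with the [feat, feat] mapping above, where every slot gets exactly two contributions, but it is cheap insurance once the mapping becomes less regular. `average_inverse` is just an illustrative helper; the names mirror the snippet:

def average_inverse(feat, mapped, target_size):
    # Scatter rows of `feat` back into `target_size` slots and average duplicates.
    C = feat.size(1)
    expanded = mapped.unsqueeze(1).expand(-1, C)

    summed = torch.zeros(target_size, C, dtype=feat.dtype, device=feat.device)
    summed.scatter_add_(0, expanded, feat)

    counts = torch.zeros(target_size, C, dtype=feat.dtype, device=feat.device)
    counts.scatter_add_(0, expanded, torch.ones_like(feat))

    return summed / counts.clamp(min=1)

On recent PyTorch versions, Tensor.scatter_reduce_ with reduce="mean" can replace the explicit sum-and-divide, if that turns out to be more convenient.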