ATSS implementation help

I’m trying to implement ATSS (Adaptive Training Sample Selection) in PyTorch using as little code as possible. Here is my implementation so far:

from einops import repeat, unpack

def assign_atss (
    anchors, # Tensor[N,4]
    targets, # Tensor[B,D,5]
    ps,      # List[3]
    nc,      # int
    topk     # int
):
    # Prep
    B, D    = targets.shape[:2]
    mask_gt = targets[...,-1] > -1
    
    # Calculate all IOUs and distances between targets and anchors
    ious        = torchvision.ops.box_iou(targets[...,:4].reshape([-1, 4]), anchors).reshape([B,D,-1]) # [B,D,N]
    sxy         = (anchors[:,0:2] + anchors[:,2:4]) / 2
    tlt, trb    = targets[...,0:2], targets[...,2:4]
    txy         = (tlt + trb) / 2
    dists       = torch.cdist(txy.reshape([-1,2]), sxy).reshape([B,D,-1])

    # Select topk for each level
    indices = []
    offset  = 0
    for distl, num in zip(unpack(dists, ps, 'b a *'), ps):
        indices.append(distl.topk(topk, dim=-1, largest=False)[1] + offset)
        offset += num[0]
    indices = torch.cat(indices, -1)

    # IOU thresh mask
    ious_idx = ious.gather(2, indices)
    thresh   = ious_idx.mean(2, keepdim=True) + ious_idx.std(2, keepdim=True)
    mask_iou = ious > thresh

    # Centre mask
    bbox_deltas = torch.cat((sxy[None,None] - tlt.unsqueeze(2), trb.unsqueeze(2) - sxy[None,None]), -1)
    mask_centre = bbox_deltas.amin(3) > 1e-9

    # Best IOU for each anchor
    indices = (ious * mask_centre).argmax(1)
    indices = repeat(indices, 'b n -> b n f', f=5)

    # Combine everything
    mask            = mask_gt.unsqueeze(-1) * mask_iou * mask_centre
    targets         = targets.gather(1, indices)
    targets_bbox    = targets[...,:4]
    targets_cls     = F.one_hot(targets[...,4].long(), num_classes=nc)
    target_obj      = mask.gather(1, indices[...,0:1])
    targets         = torch.cat((targets_bbox, target_obj.float(), targets_cls), -1)

    return targets

Here anchors holds N anchor boxes (N = a·Σ h·w summed over all pyramid levels, where a == 3 is the number of priors per cell), and its last dimension of size 4 contains the values x0,y0,x1,y1.
targets has a maximum of D ground truth labels per batch and is padded with -1 values. The last dim of 5 contains values x0,y0,x1,y1,cls_index.
ps is a list of 3 entries giving the number of anchors per feature-pyramid level, in the shape format that einops.unpack expects (one-element tuples). For example, if your levels have strides 8, 16 and 32 and N == 10647, then ps = [(8112,), (2028,), (507,)].

Something somewhere is wrong: target_obj always comes out full of False values.
Can someone help?
I don’t really want to copy-paste implementations from mmdetection, Ultralytics or other repositories. The point here is to come up with a tiny implementation of ATSS in pure PyTorch while keeping it fairly readable.

Cheers

I think the following is correct: I was missing mask_topk (the mask restricting positives to the top-k closest candidate anchors per level).

@torch.no_grad()
def assign_atss (
    anchors, # Tensor[N,4]
    targets, # Tensor[B,D,5]
    ps,      # List[3]
    nc,      # int
    topk     # int
):
    """Assign ground-truth boxes to anchors with ATSS (Adaptive Training Sample Selection).

    Args:
        anchors: Tensor[N, 4] of anchor boxes (x0, y0, x1, y1); the anchors of
            all pyramid levels are concatenated level by level.
        targets: Tensor[B, D, 5] of ground truths (x0, y0, x1, y1, cls_index),
            padded with rows of -1 when an image has fewer than D boxes.
        ps: number of anchors per pyramid level, in the order of ``anchors``.
            Entries may be ints or shape tuples such as those produced by
            ``einops.pack`` (the product of a tuple is used).
        nc: number of classes.
        topk: number of centre-closest anchors kept per level and per ground
            truth; a level with fewer anchors contributes all of them.

    Returns:
        Tensor[B, N, 4 + 1 + nc]: per anchor, the assigned box, a positive
        flag (1.0 / 0.0) and a one-hot class vector. Rows of anchors that are
        not positive are all zeros.

    Raises:
        ValueError: if the level sizes in ``ps`` do not add up to N.
    """
    import math

    B, D, N = targets.shape[0], targets.shape[1], anchors.shape[0]
    sizes = [math.prod(p) if isinstance(p, (tuple, list)) else int(p) for p in ps]
    if sum(sizes) != N:
        raise ValueError(f"per-level sizes {sizes} do not sum to the {N} anchors")
    if D == 0:
        # No ground-truth slots at all: every anchor is negative.
        return anchors.new_zeros((B, N, 5 + nc))

    tlt, trb = targets[..., 0:2], targets[..., 2:4]
    alt, arb = anchors[:, 0:2], anchors[:, 2:4]

    # IoU between every target and every anchor -> [B,D,N]
    # (same formula as torchvision.ops.box_iou, without the extra dependency)
    wh    = (torch.minimum(trb.unsqueeze(2), arb) - torch.maximum(tlt.unsqueeze(2), alt)).clamp(min=0)
    inter = wh.prod(-1)
    union = (trb - tlt).prod(-1).unsqueeze(2) + (arb - alt).prod(-1) - inter
    ious  = inter / union

    # Centre distances between every target and every anchor -> [B,D,N]
    sxy   = (alt + arb) / 2
    txy   = (tlt + trb) / 2
    dists = torch.cdist(txy.reshape(-1, 2), sxy).reshape(B, D, N)

    # Candidates: the topk closest anchors on each level (clamped to the level size)
    indices, offset = [], 0
    for distl, num in zip(dists.split(sizes, dim=-1), sizes):
        indices.append(distl.topk(min(topk, num), dim=-1, largest=False).indices + offset)
        offset += num
    indices   = torch.cat(indices, -1)
    mask_topk = torch.zeros_like(ious, dtype=torch.bool)
    mask_topk.scatter_(2, indices, torch.ones_like(indices, dtype=torch.bool))

    # Adaptive IoU threshold per target: mean + (population) std of its candidates' IoUs
    ious_idx = ious.gather(2, indices)
    thresh   = ious_idx.mean(2, keepdim=True) + ious_idx.std(2, keepdim=True, correction=0)
    mask_iou = ious > thresh

    # Anchor centre must lie strictly inside the target box
    bbox_deltas = torch.cat((sxy - tlt.unsqueeze(2), trb.unsqueeze(2) - sxy), -1)
    mask_centre = bbox_deltas.amin(3) > 1e-9

    # Positive (target, anchor) pairs: real target, topk candidate, above threshold, centre inside
    mask_gt = (targets[..., 4] > -1).unsqueeze(-1)
    mask    = mask_gt & mask_topk & mask_iou & mask_centre

    # An anchor positive for several targets keeps the one with the highest IoU
    indices = ious.masked_fill(~mask, 0).argmax(1)   # [B,N]
    pos     = mask.any(1)                            # [B,N], exact "is positive" test

    targets_box   = targets[..., :4].gather(1, indices.unsqueeze(-1).expand(-1, -1, 4))
    # Padded slots carry class -1 (e.g. images without boxes); clamp so one_hot
    # accepts them. Those rows are negatives and are zeroed below.
    targets_cls   = targets[..., 4].long().gather(1, indices).clamp(min=0)
    targets_cls   = torch.nn.functional.one_hot(targets_cls, num_classes=nc)
    targets_score = pos.unsqueeze(-1).float()
    out           = torch.cat([targets_box, targets_score, targets_cls], -1)

    # Zero the rows of negative anchors with an elementwise masked_fill rather than
    # boolean-index assignment, which first has to materialise the selected indices.
    return out.masked_fill(~pos.unsqueeze(-1), 0)

This implementation is still quite slow.
Possible reasons:

  • We’re calculating IOUs on all anchors, not just the topk candidates.
  • The line targets[~mask] = 0 takes a long time to complete; I’ve seen it take longer than the rest of the algorithm. It isn’t strictly necessary, since the positive mask (targets[...,4] > 0 on the returned tensor) is already correct.