I’m trying to implement ATSS (Adaptive Training Sample Selection) in PyTorch with as little code as possible. Here is my implementation so far:
```python
import torch
import torch.nn.functional as F
import torchvision
from einops import repeat, unpack

def assign_atss(
    anchors,  # Tensor[N,4]
    targets,  # Tensor[B,D,5]
    ps,       # List[3]
    nc,       # int
    topk,     # int
):
    # Prep
    B, D = targets.shape[:2]
    mask_gt = targets[..., -1] > -1  # [B,D], False for padded rows
    # Calculate all IOUs and distances between targets and anchors
    ious = torchvision.ops.box_iou(targets[..., :4].reshape([-1, 4]), anchors).reshape([B, D, -1])  # [B,D,N]
    sxy = (anchors[:, 0:2] + anchors[:, 2:4]) / 2  # anchor centres [N,2]
    tlt, trb = targets[..., 0:2], targets[..., 2:4]
    txy = (tlt + trb) / 2  # target centres [B,D,2]
    dists = torch.cdist(txy.reshape([-1, 2]), sxy).reshape([B, D, -1])  # [B,D,N]
    # Select topk for each level
    indices = []
    offset = 0
    for distl, num in zip(unpack(dists, ps, 'b a *'), ps):
        indices.append(distl.topk(topk, dim=-1, largest=False)[1] + offset)
        offset += num[0]
    indices = torch.cat(indices, -1)  # [B,D,3*topk]
    # IOU thresh mask
    ious_idx = ious.gather(2, indices)
    thresh = ious_idx.mean(2, keepdim=True) + ious_idx.std(2, keepdim=True)  # [B,D,1]
    mask_iou = ious > thresh  # [B,D,N]
    # Centre mask
    bbox_deltas = torch.cat((sxy[None, None] - tlt.unsqueeze(2), trb.unsqueeze(2) - sxy[None, None]), -1)  # [B,D,N,4]
    mask_centre = bbox_deltas.amin(3) > 1e-9  # [B,D,N]
    # Best IOU for each anchor
    indices = (ious * mask_centre).argmax(1)  # [B,N]
    indices = repeat(indices, 'b n -> b n f', f=5)  # [B,N,5]
    # Combine everything
    mask = mask_gt.unsqueeze(-1) * mask_iou * mask_centre  # [B,D,N]
    targets = targets.gather(1, indices)  # [B,N,5]
    targets_bbox = targets[..., :4]
    targets_cls = F.one_hot(targets[..., 4].long(), num_classes=nc)  # [B,N,nc]
    target_obj = mask.gather(1, indices[..., 0:1])  # [B,N,1]
    targets = torch.cat((targets_bbox, target_obj.float(), targets_cls), -1)  # [B,N,4+1+nc]
    return targets
```
Where `anchors` has `N` anchor boxes (`N` is the sum of `h*w*a` over the pyramid levels), `a == 3` is the number of priors per cell, and the last dimension of size 4 contains `x0, y0, x1, y1`.
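To make the expected `anchors` layout concrete, here is a minimal sketch that builds the anchors of one level in `(h, w, a)` order as `x0, y0, x1, y1`. The helper `make_level_anchors`, the half-cell centre offset and the three prior sizes (borrowed from YOLOv3's stride-8 priors) are assumptions for illustration only.

```python
import torch

def make_level_anchors(h, w, stride, priors):
    """Hypothetical helper: anchors of one pyramid level as [h*w*a, 4] (x0, y0, x1, y1)."""
    priors = torch.as_tensor(priors, dtype=torch.float32)  # [a, 2] prior widths/heights in pixels
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    cx = (xs.float() + 0.5) * stride  # [h, w] cell centres in pixels (assumed half-cell offset)
    cy = (ys.float() + 0.5) * stride
    centres = torch.stack((cx, cy), -1)[:, :, None, :]  # [h, w, 1, 2]
    half = priors[None, None] / 2                         # [1, 1, a, 2]
    boxes = torch.cat((centres - half, centres + half), -1)  # [h, w, a, 4]
    return boxes.reshape(-1, 4)  # flattened in (h, w, a) order -> [h*w*a, 4]

level_anchors = make_level_anchors(52, 52, 8, [(10, 13), (16, 30), (33, 23)])
print(level_anchors.shape, level_anchors[0].tolist())
```

Each level can be built this way and concatenated along dim 0 to get the `[N, 4]` tensor.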
`targets` holds at most `D` ground-truth boxes per image and is padded with `-1` values. The last dimension of size 5 contains `x0, y0, x1, y1, cls_index`.
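A minimal sketch of producing that padded `targets` tensor from a list of per-image boxes, using `torch.nn.utils.rnn.pad_sequence`. The helper name `pad_targets` and the toy boxes are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_targets(per_image):
    """Hypothetical helper: list of [n_i, 5] tensors (x0, y0, x1, y1, cls) -> [B, D, 5], D = max n_i."""
    # Rows beyond n_i are filled with -1, so the class column of padded rows is -1
    return pad_sequence(per_image, batch_first=True, padding_value=-1.0)

targets = pad_targets([
    torch.tensor([[0.0, 0.0, 10.0, 10.0, 2.0]]),
    torch.tensor([[5.0, 5.0, 20.0, 20.0, 0.0], [1.0, 1.0, 4.0, 4.0, 1.0]]),
])
mask_gt = targets[..., -1] > -1  # same validity test as in assign_atss
print(targets.shape, mask_gt.tolist())
```

One consequence of this padding: `F.one_hot` rejects negative class values, so if a padded row is ever gathered (for example in an image with no boxes, where the `argmax` over `D` falls back to row 0), the one-hot step in `assign_atss` would raise.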
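`ps` is a list of 3 packed shapes, one per feature pyramid level, each holding the number of anchors on that level (the form returned by `einops.pack`; the code indexes each entry with `num[0]`). For example, if your levels have strides 8, 16 and 32 and `N == 10647`, then `ps = [(8112,), (2028,), (507,)]`.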
Something is wrong somewhere, because `target_obj` is always full of `False` values.
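One step that can be checked in isolation is the indexing in the `target_obj` line. For a 3-D input, `torch.gather(input, 1, index)` returns `out[b][n][j] = input[b][index[b][n][j]][j]`, so an index of shape `[B, N, 1]` makes every anchor `n` read column `j = 0` of `mask` (anchor 0) instead of its own column. A toy example with hypothetical shapes `B=1, D=2, N=3`:

```python
import torch

# Toy mask [B=1, D=2, N=3]: gt 0 is positive for anchor 1, gt 1 for anchor 2, anchor 0 has no gt
mask = torch.tensor([[[False, True, False],
                      [False, False, True]]])
best_gt = torch.tensor([[0, 0, 1]])  # [B, N]: assigned gt index per anchor

# Index [B, N, 1] (the target_obj pattern): out[b, n, 0] = mask[b, best_gt[b, n], 0]
as_written = mask.gather(1, best_gt.unsqueeze(-1))  # [B, N, 1], always reads anchor column 0

# Index [B, 1, N]: out[b, 0, n] = mask[b, best_gt[b, n], n]
per_anchor = mask.gather(1, best_gt.unsqueeze(1)).squeeze(1)  # [B, N]
print(as_written.squeeze(-1).tolist(), per_anchor.tolist())
```

In `assign_atss`, the per-anchor form would be something like `mask.gather(1, indices[..., 0].unsqueeze(1)).transpose(1, 2)`, which keeps the `[B, N, 1]` shape.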
Can someone help?
I don’t really want to copy-paste implementations from mmdetection, ultralytics or other repositories. The point here is to come up with a tiny implementation of ATSS in pure PyTorch while keeping it fairly readable.
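For comparison, the ATSS paper applies the mean + std IoU threshold only to each ground truth's candidate set (the `topk` closest anchors per level), uses `>=` for the threshold, requires the candidate's centre to lie inside the box, and gives an anchor claimed by several ground truths to the one with the highest IoU. In `assign_atss`, `mask_iou` is computed over all `N` anchors instead. A minimal pure-PyTorch sketch of the paper's rule, using the same tensor layout; the function name `atss_assign_mask` and its argument names are hypothetical.

```python
import torch

def atss_assign_mask(ious, cand_idx, centre_inside, valid_gt):
    """Hypothetical sketch of ATSS positive selection.

    ious:          [B, D, N] IoU between every gt and every anchor
    cand_idx:      [B, D, K] long, candidate anchors per gt (topk closest per level)
    centre_inside: [B, D, N] bool, anchor centre lies inside the gt box
    valid_gt:      [B, D] bool, False for padded gt rows
    Returns fg [B, N] bool and best_gt [B, N] long.
    """
    cand_ious = ious.gather(2, cand_idx)  # [B, D, K]
    thresh = cand_ious.mean(2, keepdim=True) + cand_ious.std(2, keepdim=True)  # [B, D, 1]
    # Only the candidates of each gt are eligible, not every anchor
    is_cand = torch.zeros_like(ious, dtype=torch.bool)
    is_cand.scatter_(2, cand_idx, torch.ones_like(cand_idx, dtype=torch.bool))
    pos = is_cand & (ious >= thresh) & centre_inside & valid_gt.unsqueeze(-1)  # [B, D, N]
    # An anchor positive for several gts goes to the gt with the highest IoU
    best_gt = ious.masked_fill(~pos, -1.0).argmax(1)  # [B, N]
    fg = pos.any(1)  # [B, N]
    return fg, best_gt

# Toy usage: one real gt and one padded gt; anchor 4 has the highest IoU but is not a candidate
ious = torch.tensor([[[0.0, 0.0, 0.0, 0.7, 0.9],
                      [0.0, 0.0, 0.0, 0.0, 0.0]]])
cand_idx = torch.tensor([[[0, 1, 2, 3], [0, 1, 2, 3]]])
centre_inside = torch.tensor([[[True] * 5, [False] * 5]])
valid_gt = torch.tensor([[True, False]])
fg, best_gt = atss_assign_mask(ious, cand_idx, centre_inside, valid_gt)
print(fg.tolist(), best_gt.tolist())
```

Mapped onto `assign_atss`, `cand_idx` would be the concatenated `topk` indices, `centre_inside` would be `mask_centre` and `valid_gt` would be `mask_gt`.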
Cheers