Hi,
I’m trying to reproduce the paper:
Counting Out Time: Class Agnostic Video Repetition Counting in the Wild
https://openaccess.thecvf.com/content_CVPR_2020/papers/Dwibedi_Counting_Out_Time_Class_Agnostic_Video_Repetition_Counting_in_the_CVPR_2020_paper.pdf
After the first batch I’m getting NaN outputs, and `loss.grad` is None (note: `loss.grad` is expected to be None — gradients accumulate on the leaf parameters, not on the loss tensor).
My inputs are normalized.
I’m adding a snippet of the code.
class Encoder(nn.Module):
    """Per-frame ResNet-50 features followed by a 3D temporal-context conv.

    Input:  (B, F, C, H, W) video clips.
    Output: (B, F, 512) per-frame embeddings.
    """

    def __init__(self):
        super().__init__()
        # ResNet-50 truncated after layer3 -> 1024-channel feature maps.
        self.cnn = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-3])
        # 3D conv over (time, h, w); padding [1, 0, 0] preserves the temporal
        # length while shrinking each spatial dim by 2.
        self.temporal_context = nn.Sequential(
            nn.Conv3d(in_channels=1024, out_channels=512,
                      kernel_size=3, padding=[1, 0, 0]),
            nn.ReLU(),
        )

    def forward(self, x):
        b, f, c, h, w = x.size()
        features = self.cnn(x.view(b * f, c, h, w))  # (B*F, 1024, H', W')
        _, ch, fh, fw = features.size()
        # BUG FIX: the original `view(b, ch, f, fh, fw)` silently scrambled
        # batch/frame/channel ordering. Reshape to (B, F, C, H', W') first,
        # then permute channels before time for Conv3d's (B, C, T, H, W).
        features = features.view(b, f, ch, fh, fw).permute(0, 2, 1, 3, 4)
        temp = self.temporal_context(features)  # (B, 512, F, H'-2, W'-2)
        # BUG FIX: move time back next to batch before flattening, otherwise
        # the (B*F, 512, h, w) view mixes frames and channels again.
        temp = temp.permute(0, 2, 1, 3, 4).contiguous()
        temp = temp.view(b * f, 512, temp.shape[-2], temp.shape[-1])
        # Global average pool (original hard-coded AvgPool2d(5), which only
        # produced a 512-dim vector for one specific input resolution).
        out = nn.AvgPool2d((temp.shape[-2], temp.shape[-1]))(temp).view(b, f, -1)
        return out
class PeriodPredictor(nn.Module):
def __init__(self):
super().__init__()
self.cnn = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3), nn.ReLU())
self.projection = nn.Linear(1922,512)
self.transformer = nn.Transformer(nhead=4,dim_feedforward=128)
def forward(self, x):
out = self.cnn(x)
out = out.view(x.size(0), 64, -1)
out = self.projection(out)
out = self.transformer(out, out)
return out
class Classifier(nn.Module):
    """Two MLP heads on per-frame embeddings.

    Input:  (B, F, 512) features.
    Output: (length logits flattened to (B, F*32),
             periodicity logits flattened to (B, F)).
    """

    def __init__(self):
        super().__init__()
        # 32-way period-length head.
        self.l_network = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 32),
        )
        # Scalar periodicity head.
        self.p_network = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        batch, frames, _ = x.size()
        flat = x.view(batch * frames, -1)
        length_logits = self.l_network(flat)
        periodicity_logits = self.p_network(flat)
        return length_logits.view(batch, -1), periodicity_logits.view(batch, -1)
class RepNet(nn.Module):
    """RepNet: encoder -> self-similarity matrix -> period predictor -> heads."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.period_predictor = PeriodPredictor()
        self.classifier = Classifier()
        # Created last so self.parameters() covers every submodule's weights.
        self.optimizer = torch.optim.SGD(params=self.parameters(), lr=1e-3,
                                         momentum=0.9, weight_decay=5e-4)

    def forward(self, x):
        features = self.encoder(x)
        # BUG FIX: torch.matrix_power(D, 2) is the matrix product D @ D, not
        # an elementwise square. The paper's self-similarity matrix is the
        # *negated elementwise-squared* pairwise distance.
        dist = torch.cdist(features, features)
        sim_mat = -(dist ** 2).unsqueeze(1)
        period = self.period_predictor(sim_mat)
        l, p = self.classifier(period)
        return l, p
def train_epoch(self, loader, epoch, device):
    """Train for one epoch; logs the loss and a rolling MAE of repetition counts.

    Args:
        loader: iterable yielding (frames, period_length_labels, periodicity_labels).
        epoch:  epoch index, used only for the banner print.
        device: torch device for the batch tensors.
    """
    self.train()
    print(
        "\n" + " " * 10 + "****** Epoch {epoch} ******\n"
        .format(epoch=epoch)
    )
    training_losses = []
    mae = deque(maxlen=30)
    # BUG FIX (likely NaN source): Sigmoid followed by BCELoss is numerically
    # unstable; BCEWithLogitsLoss fuses them with the log-sum-exp trick.
    p_criterion = torch.nn.BCEWithLogitsLoss()
    l_criterion = torch.nn.CrossEntropyLoss()
    with tqdm(total=len(loader), ncols=80) as pb:
        for batch_idx, (frames, l, p) in enumerate(loader):
            frames, l, p = frames.to(device), l.to(device), p.to(device)
            # NOTE: removed `frames.requires_grad = True` — inputs don't need
            # gradients for training; it only wasted memory.
            self.optimizer.zero_grad()
            l_logits, p_logits = self.forward(frames)
            loss = p_criterion(p_logits.view(-1), p.float().view(-1))
            # assumes 64 frames and 32 period-length classes — TODO confirm
            loss = loss + l_criterion(l_logits.view(frames.shape[0] * 64, 32),
                                      l.view(-1))
            training_losses.append(loss.detach().cpu().numpy())
            loss.backward()
            clip_gradient(self.optimizer, 0.1)
            self.optimizer.step()
            counts_t = []
            counts_p = []
            # BUG FIX: derive predictions without hard-coded .cuda() so the
            # loop also works on CPU; the comparison already yields 0/1.
            p_preds = (torch.sigmoid(p_logits) > 0.5).long()
            l_preds = l_logits.view(frames.shape[0], 64, 32).argmax(2)
            for l_i, p_i, ll_i, pl_i in zip(l, p, l_preds, p_preds):
                reps = torch.where(l_i > 0)[0]
                counts_t.append(
                    torch.sum(p_i[reps].float() / l_i[reps].float())
                    .detach().cpu().numpy())
                # BUG FIX: clamp predicted period lengths to >= 1 so an
                # argmax of 0 can't put inf/NaN into the metric.
                counts_p.append(
                    torch.sum(pl_i[reps].float()
                              / ll_i[reps].float().clamp(min=1))
                    .detach().cpu().numpy())
            try:
                mae_i = mean_absolute_error(np.array(counts_t),
                                            np.array(counts_p))
                mae.append(mae_i)
                pb.update()
                pb.set_description(
                    f"Loss: {loss.item()}, MAE: {np.mean(mae)}")
            except Exception as e:
                print(e)
                continue
def validate(self, loader):
    """Run one validation pass over `loader`, reporting per-batch loss."""
    self.eval()
    criterion = torch.nn.CrossEntropyLoss()
    # BUG FIX: `for batch_idx, frames, l, p in enumerate(loader)` is a broken
    # unpacking — enumerate yields (index, batch) 2-tuples, so this raised a
    # ValueError on the first batch. Also wrap in no_grad for evaluation.
    with torch.no_grad():
        with tqdm(total=len(loader), ncols=80) as pb:
            for batch_idx, (frames, l, p) in enumerate(loader):
                l_logits, p_logits = self.forward(frames)
                # NOTE(review): training scores p_logits with BCE; using
                # CrossEntropyLoss here (as the original did) is inconsistent
                # — confirm which loss validation should report.
                loss = criterion(l_logits, l) + criterion(p_logits, p)
                pb.update()
                pb.set_description(
                    "Loss: {:.4f}".format(
                        loss.item()))