Training procedure stuck

Hi, I’m trying to train a basic classifier.

My model is:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size=512, output_size=3, hidden_size=512):
        super(Model, self).__init__()
        self.cnn = CNN()  # defined elsewhere; maps spectrogram tiles to feature vectors
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True)
        self.hidden_size = hidden_size
        self.linear = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size),
                                    nn.ReLU(),
                                    nn.Linear(hidden_size, output_size),
                                    nn.Dropout(0.2))
        torch.nn.init.xavier_normal_(self.linear[0].weight, gain=1.0)
        torch.nn.init.xavier_normal_(self.linear[2].weight, gain=1.0)

    def forward(self, x, indices):
        features = self.cnn(x)
        preds = []
        # run the LSTM separately over the tiles that belong to each sample
        for i in torch.unique(indices):
            tiles = features[torch.where(indices == i)[0]].unsqueeze(1)  # (T, 1, input_size)
            h0 = torch.zeros(2, 1, self.hidden_size, device=features.device)
            c0 = torch.zeros(2, 1, self.hidden_size, device=features.device)
            p, _ = self.lstm(tiles, (h0, c0))
            out = self.linear(p.squeeze(1))  # (T, output_size)
            preds.append(out)
        preds = torch.stack(preds).squeeze(0)
        return preds
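
For context, the forward pass assumes the CNN maps each spectrogram tile to a 512-dim feature vector. A minimal shape check with a stand-in CNN (just for illustration, my real CNN is different):

class CNN(nn.Module):
    # stand-in for the real CNN, only for this check: it must map
    # (N, 1, 64, 128) tiles to (N, 512) feature vectors (the LSTM input size)
    def forward(self, x):
        return x.flatten(1)[:, :512]

net = Model()  # picks up the stand-in CNN defined above
x = torch.randn(6, 1, 64, 128)
indices = torch.tensor([0, 0, 0, 1, 1, 1])  # tiles 0-2 from sample 0, tiles 3-5 from sample 1
out = net(x, indices)
print(out.shape)  # torch.Size([2, 3, 3]): (samples, tiles per sample, classes)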

My dataset code is:

import os
from math import floor, ceil

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchaudio

class AudioDataset(data.Dataset):
    def __init__(self, root, indices):
        super(AudioDataset, self).__init__()
        prefixes = tuple(str(i) for i in indices)
        # sort both listings so audio files and annotation files line up by index
        self.audio_files = sorted(os.path.join(root, f) for f in os.listdir(root)
                                  if f.endswith('.wav') and f.startswith(prefixes))
        ann_fnames = sorted(f for f in os.listdir(root)
                            if f.endswith('.txt') and f.startswith(prefixes))
        self.annotations = []
        for file_name in ann_fnames:
            ann = []
            if file_name.split('_')[0].isdigit():
                with open(os.path.join(root, file_name), 'r') as fid:
                    for line in fid:
                        ann.append([float(f) for f in line.split('\t')])
                self.annotations.append(np.array(ann))

    def __getitem__(self, index):
        audio_file = self.audio_files[index]
        ann = self.annotations[index]
        spectorgrams = split_according_to_cycle(audio_file, ann)
        spectograms, labels = split_spectrodrams(spectorgrams, ann[:, 2:])
        return spectograms, labels

    def __len__(self):
        return len(self.audio_files)


def split_according_to_cycle(audio_file, ann):
    waveform, sample_rate = torchaudio.load(audio_file)
    channel = 0
    transformed = torchaudio.transforms.Resample(sample_rate, 16000)(waveform[channel, :].view(1, -1))
    base_len = 5 * 16000  # 5 s at 16 kHz
    spectorgrams = []
    for cycle in ann:
        # cut out one cycle; start/end times are given in seconds
        rasp_cycle = transformed[:, floor(cycle[0] * 16000):ceil(cycle[1] * 16000)]
        # tile the cycle along time until it is at least 5 s long
        while rasp_cycle.shape[1] < base_len:
            rasp_cycle = torch.cat([rasp_cycle, rasp_cycle], 1)
        spectorgrams.append(torchaudio.transforms.Spectrogram()(rasp_cycle))
    return spectorgrams

def split_spectrodrams(spectrograms, ann):
    split_spec = []
    tiled_labels = []
    for s, an in zip(spectrograms, ann):
        # split each spectrogram into 128-frame chunks and resize them to 64x128
        a = torch.stack([F.interpolate(chunk.unsqueeze(1), size=(64, 128), mode='bicubic')
                         for chunk in torch.split(s, 128, 2)])
        # one label per chunk, picked from the two annotation flags
        if an[0] > 0:
            tiled_labels.append(torch.tensor([1] * a.shape[0]))
        elif an[1] > 0:
            tiled_labels.append(torch.tensor([2] * a.shape[0]))
        else:
            tiled_labels.append(torch.tensor([0] * a.shape[0]))
        split_spec.append(a)
    return torch.cat(split_spec, dim=0).squeeze(1), torch.cat(tiled_labels)
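
So each dataset item is a stack of fixed-size spectrogram tiles plus one label per tile. Roughly (assuming the default torchaudio Spectrogram with n_fft=400, so a 5 s cycle gives about a (1, 201, 401) spectrogram before tiling):

specs, labels = trainset[0]  # trainset as defined below
print(specs.shape)   # (num_tiles, 1, 64, 128)
print(labels.shape)  # (num_tiles,), values in {0, 1, 2}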

My model and data loader initialization code is:

from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched

trainset = AudioDataset(root, train_indices)
validationset = AudioDataset(root, eval_indices)

trainloader = DataLoader(dataset=trainset,
                         batch_size=1,
                         shuffle=True,
                         collate_fn=collate_fn,  # custom collate function
                         pin_memory=True,
                         num_workers=0)

validationloader = DataLoader(dataset=validationset,
                              batch_size=1,
                              shuffle=False,
                              collate_fn=collate_fn,  # custom collate function
                              pin_memory=True,
                              num_workers=0)

logger.info('Building Model')
net = Model()
net = net.to("cuda" if torch.cuda.is_available() else "cpu")
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
scheduler = utils.LinearWarmupScheduler(optimizer, 10, lr_sched.CosineAnnealingLR(optimizer, total_epoch))
criterion = nn.CrossEntropyLoss()
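
For reference, collate_fn isn't shown above; something along these lines would produce the (inputs, targets, indices) triple the training loop expects (a sketch only, my actual function may differ):

def collate_fn(batch):
    # batch is a list of (specs, labels) pairs from AudioDataset.__getitem__;
    # concatenate all tiles and record which sample each tile came from
    specs = torch.cat([s for s, _ in batch], dim=0)
    labels = torch.cat([l for _, l in batch], dim=0)
    indices = torch.cat([torch.full((s.shape[0],), i, dtype=torch.long)
                         for i, (s, _) in enumerate(batch)])
    return specs, labels, indices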

And my training code is:

    net.train()
    train_loss = 0
    total = 0
    correct = 0
    optimizer.zero_grad()
    for batch_idx, (inputs, targets, indices) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = net(inputs, indices)
        loss = criterion(outputs.squeeze(1), targets)
        loss.backward()
        utils.clip_gradient(optimizer, 0.1)
        # if (batch_idx + 1) % 10 == 0:
        #     step only every 10 batches (effective batch size 10)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.shape[0]
        correct += predicted.eq(targets).cpu().sum()
        utils.progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                           % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))


After a couple of batches, my training procedure gets stuck. I don’t get any errors, it just hangs. Does anyone have any idea why?

With num_workers=0, if I interrupt the procedure, the traceback I get is:

Epoch: 0
^CTraceback (most recent call last):
  File "train.py", line 147, in <module>
    train(epoch)
  File "train.py", line 74, in train
    for batch_idx, (inputs, targets,indices) in enumerate(trainloader):
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/gal/mixture-of-experts/datasets/ichbi.py", line 39, in __getitem__
    spectorgrams = split_according_to_cycle(audio_file, ann)
  File "/home/gal/mixture-of-experts/datasets/ichbi.py", line 55, in split_according_to_cycle
    rasp_cycle = torch.cat([rasp_cycle, rasp_cycle] ,1)
KeyboardInterrupt

With num_workers=1 I get:

^CTraceback (most recent call last):
  File "train.py", line 147, in <module>
    train(epoch)
  File "train.py", line 74, in train
    for batch_idx, (inputs, targets,indices) in enumerate(trainloader):
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 798, in _get_data
    success, data = self._try_get_data()
  File "/home/gal/.conda/envs/audio/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/gal/.conda/envs/audio/lib/python3.6/queue.py", line 173, in get
    self.not_empty.wait(remaining)
  File "/home/gal/.conda/envs/audio/lib/python3.6/threading.py", line 299, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt

My guess is that there is a data lock or something. How can I fix this issue?
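
One way I can think of to narrow it down is to iterate the dataset directly, without the DataLoader or the model, and see whether a particular sample never finishes loading (a sketch using the trainset defined above):

for i in range(len(trainset)):
    print(i, flush=True)         # the last printed index is the suspect sample
    specs, labels = trainset[i]  # hangs here if loading this sample never returns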

Hi @galsk87, it seems I have a similar problem, explained here in detail.
I want to know whether you have solved your problem. If so, could you please share your solution?
