CrossEntropyLoss Multitarget error

Hi, I wanted to reproduce the network from this paper (a time delay neural network for speaker embeddings) in PyTorch.

The biggest struggle was implementing the stats pooling layer, where the mean and standard deviation over the consecutive frames are calculated. After this layer (plus a squeeze) I go from a 3D to a 2D tensor.

import torch
import torch.nn as nn


class StatsPool(nn.Module):
    def __init__(self):
        """stats pooling layer from: https://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(StatsPool, self).__init__()

    def forward(self, x):  # x.size() = [batch_size, frames, 1500]
        assert x.dim() == 3, "Should be a 3D tensor"
        # concatenate the mean and standard deviation over the frame dimension
        xs = torch.cat([x.mean(dim=1), x.std(dim=1)], dim=1).unsqueeze(1)
        return xs          # xs.size() = [batch_size, 1, 3000]
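
To make sure the pooling behaves as intended, here is a quick shape check (a minimal sketch with a random tensor; the 86 frames and 1500-dimensional features just mirror the comments in my forward pass below):

pool = StatsPool()
dummy = torch.rand(10, 86, 1500)   # [batch_size, frames, 1500]
pooled = pool(dummy)
print(pooled.size())               # torch.Size([10, 1, 3000])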

The total network layout is:

>>> print(model)
TDNN_xvector(
  (tdnn_1): Sequential(
    (0): TDNN()
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn_2): Sequential(
    (0): TDNN()
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn_3): Sequential(
    (0): TDNN()
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn_4): Sequential(
    (0): TDNN()
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn_5): Sequential(
    (0): TDNN()
    (1): ReLU()
    (2): BatchNorm1d(1500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (stats_pool): Sequential(
    (0): StatsPool()
  )
  (FC6): Sequential(
    (0): Linear(in_features=3000, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (FC7): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (softmax): Sequential(
    (0): Linear(in_features=512, out_features=4, bias=True)
    (1): Softmax()
  )
)

The forward function of the network is as follows:

class TDNN_xvector(nn.Module):
    def __init__(self, F=24, L=4):
        """network implementation from:  https://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        :param int F: each frame consists of F-dimensional features
        :param int L: the output dimension (in our case, the number of training languages)
        """
        super(TDNN_xvector, self).__init__()
        ... 

    def forward(self, x):                     # x.size() = [10, 100, 24]
        x = self.tdnn_1(x).permute(0, 2, 1)   # x.size()=  [10, 96, 512]
        x = self.tdnn_2(x).permute(0, 2, 1)   # x.size()=  [10, 92, 512]
        x = self.tdnn_3(x).permute(0, 2, 1)   # x.size()=  [10, 86, 512]
        x = self.tdnn_4(x).permute(0, 2, 1)   # x.size()=  [10, 86, 512]
        x = self.tdnn_5(x).permute(0, 2, 1)   # x.size()=  [10, 86, 1500]
        
        x = self.stats_pool(x)                # x.size() = [10, 1, 3000]
        x = x.squeeze()                       # x.size() =  [10, 3000] 
        
        x = self.FC6(x)                       # x.size() =  [10, 512] 
        x = self.FC7(x)                       # x.size() =  [10, 512] 
        x = self.softmax(x)                   # x.size() =  [10, 4]   (note: 4 = L, the number of languages)
        return x

The hyperparameters of the network are:

F = 24  # number of features per frame
L = 4   # number of languages
T = 100 # number of frames per segment (can be arbitrary)
BATCH_SIZE = 10

model = TDNN_xvector(F, L)
criterion = nn.CrossEntropyLoss()           # same loss function as in the paper
optimizer = optim.Adam(model.parameters())
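
Before training I can sanity-check the output shape on the CPU (a minimal sketch, assuming the elided TDNN layers match the printout above):

batch = torch.rand(BATCH_SIZE, T, F)   # [10, 100, 24]
with torch.no_grad():
    out = model(batch)
print(out.size())                      # torch.Size([10, 4]) = [BATCH_SIZE, L]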

This is my training loop:

model = model.cuda()

for i in range(n_epochs):
    batch = torch.rand((BATCH_SIZE, T, F)).cuda()
    labels = torch.zeros((BATCH_SIZE, L)).long().squeeze_().cuda()
    labels[:, i] = 1
    
    outputs = model(batch).cuda()
    loss = criterion(outputs, labels)
    
    loss.backward()
    optimizer.step()
    print(i, loss.item())

Which finally leads to this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-e0e465815b38> in <module>
      9 
     10     outputs = model(batch).cuda()
---> 11     loss = criterion(outputs, labels)
     12 
     13     loss.backward()

~/venv/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~/venv/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    902     def forward(self, input, target):
    903         return F.cross_entropy(input, target, weight=self.weight,
--> 904                                ignore_index=self.ignore_index, reduction=self.reduction)
    905 
    906 

~/venv/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   1968     if size_average is not None or reduce is not None:
   1969         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1970     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   1971 
   1972 

~/venv/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1788                          .format(input.size(0), target.size(0)))
   1789     if dim == 2:
-> 1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1791     elif dim == 4:
   1792         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

I think this has something to do with the stats pooling layer and the dimension reduction (the training loop works if I just use the layers before the stats pooling).

The error comes from the use of CrossEntropyLoss. It expects a LongTensor of shape [BATCH_SIZE] containing the class indices of the targets (not one-hot vectors); take a look at the documentation.
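
For example, something like this should run (a minimal sketch reusing BATCH_SIZE, L, criterion, and outputs from your code; the random indices are just placeholders):

labels = torch.randint(0, L, (BATCH_SIZE,)).cuda()   # LongTensor of shape [10] with class indices in [0, L)
# or, if you already have one-hot targets:
# labels = one_hot_labels.argmax(dim=1)
loss = criterion(outputs, labels)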
