Build a neural composer using an RNN

import torch
import torch.nn as nn
import helpers.dataset as dataset
from torch.autograd import Variable


input_size = 128
hidden_size = 128
num_layers = 2
output_size = 128
batch_size = 1
num_epochs = 2
learning_rate = 0.01

# Dataset
train_dataset = dataset.pianoroll_dataset_batch('./datasets/training/piano_roll_fs1')  # pianoroll_dataset_batch instance
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes=128, n_layers=num_layers):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.n_layers = n_layers
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=n_layers, batch_first=False)

    def forward(self, input_sequence, hidden):

        # Output shape=(seq_length, batch_size, hidden_size)
        output, hidden = self.rnn(input_sequence, hidden)
        # output = output.reshape(-1, self.num_classes)         #Predict which classes should be pressed
        return output, hidden



model = RNN(input_size, hidden_size, num_layers)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Init hidden state with shape=(num_layers*num_directions, batch_size, hidden_size)
hidden = Variable(torch.Tensor(num_layers*1, batch_size, hidden_size))

for i, (features, _, targets) in enumerate(train_loader):
    # Seconds in each input stream
    seq_len = features.size(1)

    # The input dimensions are (seq_len, batch, input_size)
    features = Variable(features.reshape(seq_len, -1, input_size))
    targets  = Variable(targets.reshape(seq_len, hidden_size))

    # Forward pass
    outputs, hidden = model(features, hidden)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('i: {}, Loss: {:.4f}'.format(i+1, loss.item()))

    if i % 10 == 0:
        print(outputs)

I'm trying to train a model which learns how to play piano based on some MIDI files (handled by the helper class which inherits from torch's Dataset). I've read countless posts trying to debug my errors, but I'm left with even more questions.

  1. Do I need to reshape the targets?
  2. I keep getting this error (in the for-loop under "Forward pass"):

<class 'tuple'>: (<class 'RuntimeError'>, RuntimeError("Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #2 'target'",), None)

I've tried debugging, and inspecting both outputs and targets shows that both have the same type (torch.float32). Converting targets with .long() just produces another error:

<class 'tuple'>: (<class 'RuntimeError'>, RuntimeError("Assertion `cur_target >= 0 && cur_target < n_classes' failed. at c:\programdata\miniconda3\conda-bld\pytorch_1533096106539\work\aten\src\thnn\generic/SpatialClassNLLCriterion.c:110",), None)

  3. The dataset I'm using has a piano with 128 keys (so 128 classes). Where do I define the output layer?
  4. When I define any other hidden_size than 128 (same as input_size), I get an error saying that the dimensions aren't right. What am I missing here?
  5. The piano matrix which is fed in has binary values. How do I obtain the same format in the final layer (and not a bunch of floats) while maintaining the multi-class property (multiple keys can be pressed at the same time)? Is outputs.long() on the last layer the right way to do it, or should I use some other loss function which has this built-in?

I apologize for the beginner questions, but I'm new to this, and watching tutorials and reading countless posts has only helped me so much, so I'd appreciate anyone who takes the time to explain this so I can fill in the gaps.

  1. It depends on the shape of your model output. If you are using nn.CrossEntropyLoss the targets should have the same shape as the output without the channel dimension. See the docs for more information.

  2. Your target should be a torch.LongTensor containing the class indices (see the minimal sketch after this list). The second error message means that your target contains invalid indices, i.e. smaller than 0 or bigger than number_of_classes - 1.

  3. I'm not sure I'm understanding the question properly.

  4. Most likely your hidden_size is fixed, since you are using it for your target tensor. The shape mismatch will probably be thrown in your criterion.

  5. Based on this use case (multiple keys can be pressed at the same time), it looks like you are dealing with a multi-label classification. A way to deal with such a task would be to use nn.BCELoss instead of nn.CrossEntropyLoss to allow more than one "active" class. Have a look at the docs for more information about the shape etc.
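If you instead stick with nn.CrossEntropyLoss (one active class per sample), a minimal sketch of the expected shapes and dtypes (with hypothetical sizes) would be:

batch_size = 4
nb_classes = 128
output = torch.randn(batch_size, nb_classes, requires_grad=True)  # raw logits, shape [N, C]
target = torch.randint(0, nb_classes, (batch_size,))              # class indices in [0, nb_classes - 1], dtype torch.long
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()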

Thank you for your feedback, @ptrblck!

  1. I've changed targets to
    targets = targets.reshape(seq_len, input_size).long()
    So now outputs is torch.Size([226, 1, 128]) while targets is torch.Size([226, 128]), if I understood your answer to 1. correctly. However, I keep getting the second error about invalid indices. I'm not sure how to fix this.

How many classes do you have?
Currently your output channel is only 1, which seems wrong. You should have at least 2 channels for a binary classification.

Input: (N, C) where C = number of classes, or
(N, C, d1, d2, ..., dK) with K ≥ 2 in the case of K-dimensional loss.

Target: (N) where each value is 0 ≤ targets[i] ≤ C - 1, or
(N, d1, d2, ..., dK) with K ≥ 2 in the case of K-dimensional loss.

I have 128 piano keys that can be pressed at any time step, so 128 classes if I've understood it correctly?

Thanks for the info. That's a lot for a piano :wink:

It looks like your target might be one-hot encoded, since its dimension is currently [batch_size, nb_classes], which is wrong when using nn.CrossEntropyLoss.

However, let's stick to your use case and use a multi-label approach.
For this we would need "multi-hot" encoded targets, i.e. for each sample a tensor of this form [0., 0., 1., 0., 1., 0., ...].
Here is a dummy example:


batch_size = 10
nb_classes = 5
output = torch.randn(batch_size, nb_classes, requires_grad=True)
target = torch.empty(batch_size, nb_classes).random_(2)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, target)
loss.backward()

Basically, your model output and target should have the same shape in this case.
If your output contains raw logits, use nn.BCEWithLogitsLoss.
Alternatively, you could use torch.sigmoid on the output and use nn.BCELoss as the criterion.
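The explicit sigmoid variant would be a small change to the dummy example above:

probs = torch.sigmoid(output)       # squash the raw logits to [0, 1]
criterion = nn.BCELoss()
loss = criterion(probs, target)     # same multi-hot float target as before
loss.backward()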

It sure is a lot of keys!

The targets have shape torch.Size([1, 226, 1, 128]) when returned from the train_loader in the for-loop, before applying any changes:

for i, (features, _, targets) in enumerate(train_loader):
    # Seconds in each input stream
    seq_len = features.size(1)

When I change the shape of the targets like you suggested:

for i, (features, _, targets) in enumerate(train_loader):
    # Seconds in each input stream
    seq_len = features.size(1)

    features = features.reshape(seq_len, -1, input_size)
    targets = targets.reshape(batch_size, num_classes).long()

I get an error saying that the shape is invalid for the input size

RuntimeError: shape '[1, 128]' is invalid for input of size 28928

I'm sure there is something I'm not seeing here, or I may have misunderstood something.

What does dim1 (226) stand for? I assumed it's the batch size, but apparently I'm wrong.
Is it the sequence length?

Yes, it's the sequence length, or seconds.

Assuming you meant seq_length instead of batch_size, this is how my code looks now:

    import torch
    import torch.nn as nn
    import helpers.dataset as dataset
    
    
    input_size = 128
    hidden_size = 128
    num_layers = 2
    output_size = 128
    batch_size = 1
    num_epochs = 2
    learning_rate = 0.01
    num_classes = 128
    
    # Dataset
    train_dataset = dataset.pianoroll_dataset_batch('./datasets/training/piano_roll_fs1')  # pianoroll_dataset_batch instance
    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False,
                                               drop_last=True)  # Drop-last
    
    
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes=128, n_layers=num_layers):
            super(RNN, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.num_classes = num_classes
            self.n_layers = n_layers
            self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=n_layers, batch_first=False)
    
        def forward(self, input_sequence, hidden):
    
            output, hidden = self.rnn(input_sequence, hidden)
    
            # Output reshaped to shape=(seq_length, num_classes)
            return output.reshape(seq_len, num_classes), hidden
    
        def init_hidden(self):
            pass
    
    
    model = RNN(input_size, hidden_size, num_layers)
    # Loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Init hidden state with shape=(num_layers*num_directions, batch_size, hidden_size)
    hidden = torch.Tensor(num_layers * 1, batch_size, hidden_size)
    for i, (features, _, targets) in enumerate(train_loader):
        # Seconds in each input stream
        seq_len = features.size(1)
    
        # The input dimensions are reshaped to (seq_len, batch, input_size)
        features = features.reshape(seq_len, -1, input_size)
        targets = targets.reshape(seq_len, num_classes)
    
        # Forward pass
        outputs, hidden = model(features, hidden)  # Runs forward()
        loss = criterion(outputs, targets)
    
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
    
        print('i: {}, Loss: {:.4f}'.format(i + 1, loss.item()))
    
        if i % 10 == 0:
            print(outputs)

I needed to add retain_graph=True in loss.backward().
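I assume this is needed because I keep feeding the hidden state from the previous iteration back into the model, so each new graph still references the old one. A rough sketch of detaching it at the start of every iteration, which I believe would make retain_graph=True unnecessary:

    # inside the training loop, before the forward pass
    hidden = hidden.detach()                  # cut the link to the previous iteration's graph
    outputs, hidden = model(features, hidden)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()                           # retain_graph should no longer be needed
    optimizer.step()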

In that case, your shapes should be [batch_size, seq_length, nb_classes].
Try to permute your output using output = output.permute(1, 0, 2) and squeeze the additional dimension in your target with target = target.squeeze(2).
Where does dim2 in your target come from?
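With the shapes you posted, the permute/squeeze would look something like this:

output = torch.randn(226, 1, 128)      # dummy stand-in: [seq_len, batch_size, nb_classes] from the RNN
target = torch.randn(1, 226, 1, 128)   # dummy stand-in: shape returned by your Dataset
output = output.permute(1, 0, 2)       # -> [1, 226, 128] = [batch_size, seq_len, nb_classes]
target = target.squeeze(2)             # -> [1, 226, 128]
print(output.shape == target.shape)    # True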

Ok, I'm going to try this.

Dim2 in targets (1) is just how the data is returned from the custom Dataset class.

Ok, I've changed the things you suggested and the code seems to be running now. But for some reason, sometimes when I run the code the return value from loss.item() is nan, while other times it returns a float...? It seems to be happening at random, do you have any idea why this might be happening?

i: 1, Loss: 0.6917
i: 2, Loss: 0.5840
i: 3, Loss: 0.4336
i: 4, Loss: nan
i: 5, Loss: nan
i: 6, Loss: nan
i: 7, Loss: nan
i: 8, Loss: nan
i: 9, Loss: nan
i: 10, Loss: nan
i: 11, Loss: nan
i: 12, Loss: nan

Could you check your input for NaNs or Inf values?
torch.isnan and torch.isinf should do the job.
If that's not the case, you could try the anomaly detection util to check your model.
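A quick sketch, reusing the variable names from your training loop:

if torch.isnan(features).any() or torch.isinf(features).any():
    print('Invalid values found in batch {}'.format(i))

# or wrap the forward and backward pass to locate the offending operation
with torch.autograd.detect_anomaly():
    outputs, hidden = model(features, hidden)
    loss = criterion(outputs, targets)
    loss.backward(retain_graph=True)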

This is what's returned using detect_anomaly():

i: 1, Loss: 0.6978
i: 2, Loss: 0.5917
sys:1: RuntimeWarning: Traceback of forward call that caused the error:
File "/neural-composer/legacy.py", line 69, in <module>
outputs, hidden = model(features, hidden)  # Runs forward()
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/neural-composer/legacy.py", line 35, in forward
output, hidden = self.rnn(input_sequence, hidden)
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\rnn.py", line 192, in forward
output, hidden = func(input, self.all_weights, hx, batch_sizes)
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\_functions\rnn.py", line 324, in forward
return func(input, *fargs, **fkwargs)
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\_functions\rnn.py", line 288, in forward
dropout_ts)

Traceback (most recent call last):
File "/neural-composer/legacy.py", line 74, in <module>
loss.backward(retain_graph=True)
File "\pirar\Anaconda3\envs\pytorch\lib\site-packages\torch\tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "\Anaconda3\envs\pytorch\lib\site-packages\torch\autograd\__init__.py", line 90, in backward
allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'CudnnRnnBackward' returned nan values in its 1th output.

The problem seems to go away when I change the learning rate from 0.01 to 0.001. Does that tell you anything?

Now, since the code is finally running: The output returned is going to be floats, but since I need a classification of either 0 or 1 in the output, what should I do to obtain this? Do I need to do some kind of normalization, or should I apply some activation function in the last output layer - if so, which one?

Yeah, a high learning rate might push some parameters to inf.

You could apply a sigmoid on the output and use a threshold to get your predictions.
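For example (the 0.5 threshold is just a starting point and could be tuned):

probs = torch.sigmoid(outputs)     # map the logits to probabilities in [0, 1]
preds = (probs > 0.5).float()      # binary piano roll: 1 = key pressed, 0 = key not pressed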

Ah, I see!

Ok, I'm gonna test this. Thank you for being so patient so far!

Another question: currently I've set the hidden size to 128 (matching both input_size and num_classes). If I change it to hidden_size=100, the dimensions won't match, leading to this error:

Target size (torch.Size([1, 226, 128])) must be the same as input size (torch.Size([1, 226, 100]))

But since the dimensions of the output and target should match, how can I change the hidden_size without affecting the relationship between output and target? @ptrblck

This is the output that is generated after training is done, which doesn't seem right.

All the values seem to be more or less the same...

If you want to use another hidden size, you could pass the RNN output to another linear layer like in this tutorial.
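A minimal sketch of that idea applied to your model (keeping your current constructor arguments):

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes=128, n_layers=2):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size,
                          num_layers=n_layers, batch_first=False)
        self.fc = nn.Linear(hidden_size, num_classes)  # projects hidden_size features to the 128 classes

    def forward(self, input_sequence, hidden):
        output, hidden = self.rnn(input_sequence, hidden)  # [seq_len, batch, hidden_size]
        output = self.fc(output)                           # [seq_len, batch, num_classes]
        return output, hidden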

Also, I've noticed another issue, which might explain the constant output and the NaN values.
Currently you are using torch.Tensor() to initialize your hidden tensor. This will use uninitialized memory, i.e. all values will be set to whatever is stored in the memory block.
You should rather use torch.zeros or torch.randn for random values.
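E.g.:

hidden = torch.zeros(num_layers * 1, batch_size, hidden_size)  # num_layers * num_directions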
