I would like to classify 1d binary array (340 ) to 8 classes (each is list hot vector of length 8), I want to add conditional random field at the top my model. I have 256 batch size . the output should be list of length 8, each is probability for each class. I apply it as the following code . But I got an error from mthat line
loss = - self.crf(out, labels). The error is “emission must have dimension of 3 and got 2”. PS: out and labels are both [256,8]. Is there something wrong with my code because it is my first time to use conditional random field layer.
class MM(nn.Module): def __init__(self, M=1): super(MM, self).__init__() self.numb_label = 8 #input layer self.layer1 = nn.Sequential( nn.Conv1d(340, 256, kernel_size=1, stride=1, padding=0), nn.ReLU()) self.layer2 = nn.Sequential( nn.Conv1d(256, 128, kernel_size=1, stride=1, padding=0), nn.ReLU()) self.layer3 = nn.Sequential( nn.Conv1d(128, 64, kernel_size=1, stride=1, padding=0), nn.ReLU()) self.drop1 = nn.Sequential(nn.Dropout()) self.batch1 = nn.BatchNorm1d(64) # LSTM #self.multihead_attn1 = nn.MultiheadAttention(680, num_heads=2) self.lstm1=nn.Sequential(nn.LSTM( input_size=64, hidden_size=32, num_layers=1, bidirectional=True, batch_first= True)) self.multihead_attn2 = nn.MultiheadAttention(32 * 2, num_heads=2) self.gru = nn.Sequential(nn.GRU( input_size=32*2, hidden_size=16, num_layers=2, bidirectional=True, batch_first=True)) self.multihead_attn3 = nn.MultiheadAttention(16*2, num_heads=2) self.fc1 = nn.Linear(16 * 2, 8) self.sot = nn.LogSoftmax(dims=-1) self.crf = CRF(self.numb_label, batch_first=True) def forward(self, x,labels= None): #print(x.shape) out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.drop1(out) out = self.batch1(out) out = out.squeeze() out = out.unsqueeze(0) out,_ = self.lstm1(out) out, _ = self.multihead_attn2(out, out, out) out,_ = self.gru(out) out = out.squeeze() out = self.fc1(out) out = self.soft(out) loss = -self.crf(out, labels) output = loss print(out) pred_list = self.crf.decode(out) if output is None: return out else: return (output,pred_list)