I would like to classify 1d binary array (340 ) to 8 classes (each is list hot vector of length 8), I want to add conditional random field at the top my model. I have 256 batch size . the output should be list of length 8, each is probability for each class. I apply it as the following code . But I got an error from mthat line loss = - self.crf(out, labels)
. The error is “emission must have dimension of 3 and got 2”. PS: out and labels are both [256,8]. Is there something wrong with my code because it is my first time to use conditional random field layer.
class MM(nn.Module):
def __init__(self, M=1):
super(MM, self).__init__()
self.numb_label = 8
#input layer
self.layer1 = nn.Sequential(
nn.Conv1d(340, 256, kernel_size=1, stride=1, padding=0),
nn.ReLU())
self.layer2 = nn.Sequential(
nn.Conv1d(256, 128, kernel_size=1, stride=1, padding=0),
nn.ReLU())
self.layer3 = nn.Sequential(
nn.Conv1d(128, 64, kernel_size=1, stride=1, padding=0),
nn.ReLU())
self.drop1 = nn.Sequential(nn.Dropout())
self.batch1 = nn.BatchNorm1d(64)
# LSTM
#self.multihead_attn1 = nn.MultiheadAttention(680, num_heads=2)
self.lstm1=nn.Sequential(nn.LSTM(
input_size=64,
hidden_size=32,
num_layers=1,
bidirectional=True,
batch_first= True))
self.multihead_attn2 = nn.MultiheadAttention(32 * 2, num_heads=2)
self.gru = nn.Sequential(nn.GRU(
input_size=32*2,
hidden_size=16,
num_layers=2,
bidirectional=True,
batch_first=True))
self.multihead_attn3 = nn.MultiheadAttention(16*2, num_heads=2)
self.fc1 = nn.Linear(16 * 2, 8)
self.sot = nn.LogSoftmax(dims=-1)
self.crf = CRF(self.numb_label, batch_first=True)
def forward(self, x,labels= None):
#print(x.shape)
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
out = self.drop1(out)
out = self.batch1(out)
out = out.squeeze()
out = out.unsqueeze(0)
out,_ = self.lstm1(out)
out, _ = self.multihead_attn2(out, out, out)
out,_ = self.gru(out)
out = out.squeeze()
out = self.fc1(out)
out = self.soft(out)
loss = -self.crf(out, labels)
output = loss
print(out)
pred_list = self.crf.decode(out)
if output is None:
return out
else:
return (output,pred_list)