This is my model and training process. My accuracy seems same after every epoch. I think its re-initializing the weights every time. how?
import torch.nn as nn
import torch.nn.functional as F
class TDNN(nn.Module):
def __init__(
self,
input_dim=23,
output_dim=512,
context_size=5,
stride=1,
dilation=1,
batch_norm=False,
dropout_p=0.2
):
'''
TDNN as defined by https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf
Affine transformation not applied globally to all frames but smaller windows with local context
batch_norm: True to include batch normalisation after the non linearity
Context size and dilation determine the frames selected
(although context size is not really defined in the traditional sense)
For example:
context size 5 and dilation 1 is equivalent to [-2,-1,0,1,2]
context size 3 and dilation 2 is equivalent to [-2, 0, 2]
context size 1 and dilation 1 is equivalent to [0]
'''
super(TDNN, self).__init__()
self.context_size = context_size
self.stride = stride
self.input_dim = input_dim
self.output_dim = output_dim
self.dilation = dilation
self.dropout_p = dropout_p
self.batch_norm = batch_norm
self.kernel = nn.Linear(input_dim*context_size, output_dim)
self.nonlinearity = nn.ReLU()
if self.batch_norm:
self.bn = nn.BatchNorm1d(output_dim)
if self.dropout_p:
self.drop = nn.Dropout(p=self.dropout_p)
def forward(self, x):
'''
input: size (batch, seq_len, input_features)
outpu: size (batch, new_seq_len, output_features)
'''
_, _, d = x.shape
assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
x = x.unsqueeze(1)
# Unfold input into smaller temporal contexts
x = F.unfold(
x,
(self.context_size, self.input_dim),
stride=(1,self.input_dim),
dilation=(self.dilation,1)
)
# N, output_dim*context_size, new_t = x.shape
x = x.transpose(1,2)
x = self.kernel(x.float())
x = self.nonlinearity(x)
if self.dropout_p:
x = self.drop(x)
if self.batch_norm:
x = x.transpose(1,2)
x = self.bn(x)
x = x.transpose(1,2)
return x
import torch.nn as nn
# from models.tdnn import TDNN
import torch
import torch.nn.functional as F
class X_vector(nn.Module):
def __init__(self, input_dim = 24, num_classes=7):
super(X_vector, self).__init__()
self.tdnn1 = TDNN(input_dim=input_dim, output_dim=512, context_size=5, dilation=1,dropout_p=0.5)
self.tdnn2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=1,dropout_p=0.5)
self.tdnn3 = TDNN(input_dim=512, output_dim=512, context_size=2, dilation=2,dropout_p=0.5)
self.tdnn4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1,dropout_p=0.5)
self.tdnn5 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=3,dropout_p=0.5)
#### Frame levelPooling
self.segment6 = nn.Linear(1024, 512)
self.segment7 = nn.Linear(512, 512)
self.output = nn.Linear(512, num_classes)
# self.softmax = nn.Softmax(dim=1)
def forward(self, inputs):
tdnn1_out = self.tdnn1(inputs)
# return tdnn1_out
tdnn2_out = self.tdnn2(tdnn1_out)
tdnn3_out = self.tdnn3(tdnn2_out)
tdnn4_out = self.tdnn4(tdnn3_out)
tdnn5_out = self.tdnn5(tdnn4_out)
### Stat Pool
mean = torch.mean(tdnn5_out,1)
std = torch.std(tdnn5_out,1)
stat_pooling = torch.cat((mean,std),1)
segment6_out = self.segment6(stat_pooling)
x_vec = self.segment7(segment6_out)
predictions = self.output(x_vec)
return predictions
def train_valid_em(train_set,valid_set):
optimizer = optim.Adam(emrec.parameters(), lr = 0.0001)
criterion = nn.CrossEntropyLoss()
epochs = 10
min_validloss = np.inf
for epoch in range(epochs):
print(epoch)
train_loss=0
for data in train_set:
X, y = data
emrec.zero_grad()
output = emrec(reshapex(X))
y=reshapey(y)
loss=criterion(output, torch.argmax(y).view(-1))
loss.backward()
optimizer.step()
train_loss+=loss.item()
valid_loss = 0
for data in valid_set:
X, y = data
output = emrec(reshapex(X))
y=reshapey(y)
loss=criterion(output, torch.argmax(y).view(-1))
valid_loss+=loss.item()
if min_validloss>valid_loss:
min_validloss = valid_loss
torch.save(emrec.state_dict(), 'saved_model.pth')
corr = 0
total = 0
with torch.no_grad():
for data in test_set:
X,y=data
output = emrec(reshapex(X))
for idx,i in enumerate(output):
if torch.argmax(i)==torch.argmax(y):
corr+=1
total+=1
print(corr*100/total)
return corr*100/total