class TimeAutoEncoder(nn.Module):
    """Dilated 1-D convolutional autoencoder over a fixed-size sequence.

    Encoder: nine ``Conv1d(kernel_size=3) -> BatchNorm1d -> ReLU`` blocks
    (``conv1`` .. ``conv9``) with dilations 1, 2, 4, ..., 256.  ``forward``
    zero-pads the left side of the last axis by ``2 * dilation`` before each
    block, so every block preserves the sequence length.  The result is
    flattened and projected to a 128-dimensional code (``encoder_fc``).

    Decoder: ``decoder_fc`` expands the code back to ``(8, 1876)``; the last
    axis is reversed, nine mirrored blocks (``t_conv1`` .. ``t_conv9``,
    dilations 256 .. 1) are applied with the same left-only padding, and the
    last axis is reversed again.  ``t_conv9`` is a bare convolution without
    normalisation or activation.

    The network only works for inputs of shape ``(batch, 48, 1876)`` because
    the flatten/unflatten steps hard-code ``8 * 1876`` features.
    """

    # Expected size of axis 1 of the input (presumably the number of mel bins).
    N_MELS = 48
    # Expected size of axis 2 of the input (presumably the number of frames).
    SEQ_LEN = 1876
    # Size of the bottleneck code returned as the first output.
    LATENT_DIM = 128
    KERNEL_SIZE = 3
    # Channel widths along the encoder: input first, bottleneck last.
    CHANNELS = (48, 1876, 938, 512, 256, 128, 64, 32, 16, 8)
    # Dilation of the i-th encoder convolution; the decoder uses them reversed.
    DILATIONS = (1, 2, 4, 8, 16, 32, 64, 128, 256)

    def __init__(self):
        super(TimeAutoEncoder, self).__init__()
        channels = self.CHANNELS
        dilations = self.DILATIONS

        # Encoder blocks are registered as conv1 .. conv9 so that state_dict
        # keys (e.g. "conv3.0.weight") match previously saved checkpoints.
        encoder_specs = zip(channels[:-1], channels[1:], dilations)
        for idx, (c_in, c_out, dilation) in enumerate(encoder_specs, start=1):
            setattr(self, f"conv{idx}", self._conv_block(c_in, c_out, dilation))

        flat_features = channels[-1] * self.SEQ_LEN
        self.encoder_fc = nn.Sequential(
            nn.Linear(flat_features, self.LATENT_DIM),
            nn.BatchNorm1d(self.LATENT_DIM),
            nn.Tanh(),
        )
        self.decoder_fc = nn.Sequential(
            nn.Linear(self.LATENT_DIM, flat_features),
            nn.ReLU(),
        )

        # Decoder mirrors the encoder: channel widths and dilations reversed.
        rev_channels = channels[::-1]
        rev_dilations = dilations[::-1]
        last_idx = len(rev_dilations)
        decoder_specs = zip(rev_channels[:-1], rev_channels[1:], rev_dilations)
        for idx, (c_in, c_out, dilation) in enumerate(decoder_specs, start=1):
            # The final block emits the reconstruction directly, so it has no
            # BatchNorm/ReLU after the convolution.
            block = self._conv_block(c_in, c_out, dilation, activate=idx != last_idx)
            setattr(self, f"t_conv{idx}", block)

    @staticmethod
    def _conv_block(in_channels, out_channels, dilation, activate=True):
        """Build one unpadded dilated Conv1d, optionally followed by BatchNorm1d + ReLU."""
        layers = [
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=TimeAutoEncoder.KERNEL_SIZE,
                stride=1,
                padding=0,
                dilation=dilation,
            )
        ]
        if activate:
            layers += [nn.BatchNorm1d(out_channels), nn.ReLU()]
        return nn.Sequential(*layers)

    def _apply_left_padded(self, block, x, dilation):
        """Zero-pad the left of the last axis by (kernel_size - 1) * dilation, then run ``block``."""
        pad = (self.KERNEL_SIZE - 1) * dilation
        return block(F.pad(x, pad=(pad, 0)))

    def forward(self, mel_spec):
        """Encode ``mel_spec`` to a latent code and reconstruct it.

        Args:
            mel_spec: tensor of shape ``(batch, 48, 1876)``.

        Returns:
            Tuple ``(encode, x)`` where ``encode`` has shape ``(batch, 128)``
            and ``x`` is the reconstruction with shape ``(batch, 48, 1876)``.

        Raises:
            ValueError: if ``mel_spec`` is not 3-D with trailing shape
                ``(48, 1876)``.
        """
        # Fail early with a readable message.  Without this check an unbatched
        # or transposed input is reinterpreted by Conv1d/BatchNorm1d and only
        # fails later with "running_mean should contain ... elements", and a
        # wrong sequence length can be silently reshaped across batch items by
        # the hard-coded ``view`` below.
        expected_tail = (self.N_MELS, self.SEQ_LEN)
        if mel_spec.dim() != 3 or tuple(mel_spec.shape[1:]) != expected_tail:
            raise ValueError(
                f"mel_spec must have shape (batch, {self.N_MELS}, {self.SEQ_LEN}); "
                f"got {tuple(mel_spec.shape)}"
            )

        bottleneck_channels = self.CHANNELS[-1]

        x = mel_spec
        for idx, dilation in enumerate(self.DILATIONS, start=1):
            x = self._apply_left_padded(getattr(self, f"conv{idx}"), x, dilation)

        encode = self.encoder_fc(x.view(-1, bottleneck_channels * self.SEQ_LEN))

        x = self.decoder_fc(encode)
        x = x.view(-1, bottleneck_channels, self.SEQ_LEN)
        # Reverse the last axis so the left-only padding of the decoder blocks
        # acts on the opposite end of the sequence; undone after the stack.
        x = torch.flip(x, dims=(2,))
        for idx, dilation in enumerate(reversed(self.DILATIONS), start=1):
            x = self._apply_left_padded(getattr(self, f"t_conv{idx}"), x, dilation)
        x = torch.flip(x, dims=(2,))
        return encode, x
import time
# Train for `epochs` epochs and checkpoint the model whenever the validation
# loss improves on the best value seen so far.
# Start from +inf so that any first validation loss counts as an improvement,
# however large it is.
min_loss = float("inf")
for epoch in range(1, epochs + 1):
    # perf_counter is monotonic, so the elapsed time cannot be distorted by
    # system clock adjustments.
    start = time.perf_counter()
    train_loss = train(model=model, train_loader=train_batch_li)
    # NOTE: `val` receives its data through a keyword that is also called
    # `train_loader`; here it is given the validation batches.
    val_loss = val(model=model, train_loader=val_batch_li)
    end = time.perf_counter()
    # The trailing Korean label reads "training time".
    print(f'EPOCH:{epoch}, Train Loss:{train_loss}, Val Loss:{val_loss}, 학습 시간: {end - start}')
    if val_loss < min_loss:
        min_loss = val_loss
        torch.save(model.state_dict(), model_dir + 'TimeAutoEncoder_val.pt')
        # Message reads "model saved".
        print('모델 저장')
After running the code above, I get the following error: `running_mean should contain 48 elements not 1876`. What can I do?