import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from torchsummary import summary
# Synthetic training data: 300 single-channel 256x256 images, 300 feature
# vectors of length 10, and one integer class label in [0, 4) per sample.
# CrossEntropyLoss needs class indices of shape (N,) that index the model's
# 4 output logits; random floats of shape (300, 10) truncated by .long()
# would fail the loss's target-shape check (and could be negative).
X1 = torch.randn(300, 1, 256, 256)
X2 = torch.randn(300, 10)
y = torch.randint(0, 4, (300,))
class MyModel(nn.Module):
    """Two-branch network fusing a single-channel image with a feature vector.

    Image branch: Conv2d(1 -> 3, 3x3, stride 1, padding 1) -> MaxPool2d(2) -> ReLU.
    Vector branch: Linear(num_features -> 5) -> ReLU.
    Both branch outputs are flattened per sample, concatenated, and passed
    to a single Linear layer that produces ``num_classes`` logits.

    Args:
        image_size: (height, width) of the image input. It determines the
            classifier's input width. The default matches the 256x256 images
            used in this script.
        num_features: length of the non-image feature vector.
        num_classes: number of output logits.
    """

    def __init__(self, image_size=(256, 256), num_features=10, num_classes=4):
        # Must be ``__init__`` (not ``init``): nn.Module only runs this on
        # construction, and only then are the sub-layers registered.
        super().__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.features2 = nn.Sequential(
            nn.Linear(num_features, 5),
            nn.ReLU(),
        )
        height, width = image_size
        # padding=1 keeps HxW through the 3x3 conv, MaxPool2d(2) halves each
        # side (floor), and the conv emits 3 channels.
        image_flat = 3 * (height // 2) * (width // 2)
        self.classifier = nn.Linear(image_flat + 5, num_classes)

    def forward(self, x1, x2):
        """Return logits of shape (batch, num_classes).

        Args:
            x1: image batch of shape (batch, 1, height, width).
            x2: feature batch whose last dimension is ``num_features``.
        """
        x1 = torch.flatten(self.features1(x1), start_dim=1)
        x2 = torch.flatten(self.features2(x2), start_dim=1)
        return self.classifier(torch.cat((x1, x2), dim=1))
class MultiTaskDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing an image input, a non-image input and a target.

    Item ``idx`` is ``([image_input[idx], non_image_input[idx]], target[idx])``.
    The first axis of every input is the sample axis.

    Args:
        image_input: indexable of image samples (a tensor in this script).
        non_image_input: indexable of non-image samples, same length.
        target: indexable of targets, same length.

    Raises:
        ValueError: if the three inputs do not have the same length.
    """

    def __init__(self, image_input, non_image_input, target):
        # Must be ``__init__``: a method named ``init`` is never called, and
        # object() then rejects the three constructor arguments.
        n_samples = len(image_input)
        if len(non_image_input) != n_samples or len(target) != n_samples:
            raise ValueError(
                f"inputs must have the same number of samples, got "
                f"{len(image_input)}, {len(non_image_input)} and {len(target)}"
            )
        self.image_input = image_input
        self.non_image_input = non_image_input
        self.target = target

    def __len__(self):
        # Count samples along the first axis. shape[-1] would be the image
        # width (256), not the number of samples (300).
        return len(self.image_input)

    def __getitem__(self, idx):
        return ([self.image_input[idx], self.non_image_input[idx]],
                self.target[idx])
# Instantiate the network and show its layer summary.
model = MyModel()
print(model)
# Wrap the synthetic tensors X1 / X2 / y and iterate them one shuffled
# sample per batch.
ds = MultiTaskDataset(X1, X2, y)
train_loader = DataLoader(dataset=ds, batch_size=1, shuffle=True)
# Draw one batch eagerly; presumably a smoke test of the dataset/loader
# wiring (the value is not used elsewhere in this file).
first_batch = next(iter(train_loader))
def train(epoch, *, net=None, loader=None, opt=None, loss_fn=None,
          log_shapes=False):
    """Run one training epoch over the loader and return the mean batch loss.

    Each batch is expected to be ``([amp, phase], target)``. The model is
    called as ``net(amp, phase)`` and the loss as ``loss_fn(output, target.long())``.

    Args:
        epoch: epoch index, used only for logging.
        net: model to train. Defaults to the module-level ``model``.
        loader: batch iterable. Defaults to the module-level ``train_loader``.
        opt: optimizer. Defaults to the module-level ``optimizer``.
        loss_fn: loss function. Defaults to the module-level ``criterion``.
        log_shapes: if true, print both input shapes for every batch.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if the loader is empty.
    """
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    net.train()
    # Batches must live on the same device as the model's parameters.
    # Sending only the inputs to CUDA while the model stays on CPU raises a
    # device-mismatch error.
    device = next(net.parameters()).device
    print("epoch", epoch)
    n_samples = len(loader.dataset)
    n_batches = len(loader)
    seen = 0
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(loader):
        amp, phase = data
        if log_shapes:
            print("amp shape, phase shape:", amp.shape, phase.shape)
        amp = amp.to(device)
        phase = phase.to(device)
        target = target.to(device)
        opt.zero_grad()
        output = net(amp, phase)
        loss = loss_fn(output, target.long())
        loss.backward()
        opt.step()
        # ``data`` is the [amp, phase] pair, so len(data) is always 2. Count
        # samples from the target's batch dimension instead; this also
        # handles a short final batch.
        seen += target.size(0)
        batch_loss = loss.item()
        total_loss += batch_loss
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, n_samples,
            100. * (batch_idx + 1) / n_batches, batch_loss))
    return total_loss / n_batches if n_batches else float("nan")
# Optimizer and loss used by train() through its module-level defaults.
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()
n_epochs = 5
for epoch in range(n_epochs):
    train(epoch)

# Persist only the learned parameters and reload them into a fresh model.
# Pickling the whole nn.Module is brittle: recent PyTorch versions default
# torch.load to weights_only=True, which refuses arbitrary pickled classes.
# Plain ASCII quotes are required here; typographic quotes are a SyntaxError.
torch.save(model.state_dict(), "model.pth")
model = MyModel()
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
print(model)