I’m trying to classify text in 128 classes.
OVERVIEW is text, and CAT3 is labels.
from torchtext.data import TabularDataset
train, validation = TabularDataset.splits(
path = location, train = ‘train_jh.csv’, validation = ‘validation_jh.csv’, format=‘csv’,
fields=[(‘overview’, OVERVIEW), (‘cat3’, CAT3)], skip_header=True)
import torch
from torchtext.vocab import Vectors
from torchtext.data import BucketIterator
vectors = Vectors(name = ‘/open/overview_tokens_w2v’)
vectors_cat3 = Vectors(name = ‘/open/cat3_tokens_w2v’)
OVERVIEW.build_vocab(train,
vectors = vectors, min_freq = 1, max_size = None)
CAT3.build_vocab(train,
vectors = vectors_cat3, min_freq = 1, max_size = None)
vocab = OVERVIEW.vocab
device = torch.device(‘mps:0’ if torch.backends.mps.is_available() else ‘cpu’)
train_iter, validation_iter = BucketIterator.splits(
datasets = (train, validation),
batch_size = 10,
device = device,
sort = False
)
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class TextCNN(nn.Module):
def init(self, vocab_built, emb_dim, dim_channel, kernel_wins, num_class):
super(TextCNN, self).__init__()
self.embed = nn.Embedding(len(vocab_built), emb_dim)
self.embed.weight.data.copy_(vocab_built.vectors)
self.convs = nn.ModuleList([nn.Conv2d(1, dim_channel, (w, emb_dim))
for w in kernel_wins])
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.4)
self.fc = nn.Linear(len(kernel_wins)*dim_channel, num_class)
def forward(self, x):
emb_x = self.embed(x)
emb_x = emb_x.unsqueeze(1)
con_x = [self.relu(conv(emb_x)) for conv in self.convs]
pool_x = [F.max_pool1d(x.squeeze(-1), x.size()[2])
for x in con_x]
fc_x = torch.cat(pool_x, dim=1)
fc_x = fc_x.squeeze(-1)
fc_x = self.dropout(fc_x)
logit = self.fc(fc_x)
return logit
def train(model, device, train_itr, optimizer):
model.train()
corrects, train_loss = 0.0, 0
for batch in train_itr:
text, target = batch.overview, batch.cat3
text = torch.transpose(text, 0 ,1)
target.data.sub_(1)
text, target = text.to(device), target.to(device)
optimizer.zero_grad()
logit = model(text)
loss = F.cross_entropy(logit, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
result = torch.max(logit,1)[1]
corrects += (result.view(target.size()).data == target.data).sum()
train_loss /= len(train_itr.dataset)
accuracy = 100.0* corrects / len(train_itr.dataset)
return train_loss, accuracy
def evaluate(model, device, itr):
model.eval()
corrects, test_loss = 0.0, 0
for batch in itr:
text = batch.overview
target = batch.cat3
text = torch.transpose(text, 0 ,1)
target.data.sub_(1)
text, target = text.to(device), target.to(device)
optimizer.zero_grad()
logit = model(text)
loss = F.cross_entropy(logit, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
result = torch.max(logit,1)[1]
corrects += (result.view(target.size()).data == target.data).sum()
train_loss /= len(train_itr.dataset)
accuracy = 100.0* corrects / len(train_itr.dataset)
return train_loss, accuracy
model = TextCNN(vocab, 100, 10, [3,4,5], 128).to(device)
print(model)
device = torch.device(‘mps:0’ if torch.backends.mps.is_available() else ‘cpu’)
optimizer = optim.Adam(model.parameters(), lr=0.001)
best_test_acc = -1
for epoch in range(1, 3+1):
tr_loss, tr_acc = train(model, device, train_iter, optimizer)
print('Train Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, tr_loss, tr_acc))
val_loss, val_acc = evaluate(model, device, validation_iter)
print('Valid Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, val_loss, val_acc))
if val_acc > best_test_acc:
best_test_acc = val_acc
print('model saves at {} accuracy'.format(best_test_acc))
torch.save(model.state_dict(), 'TextCNN_Best_Validation')
print('----------------------------------------------------------------')
My error is as follows.
TypeError Traceback (most recent call last)
Input In [47], in <cell line: 10>()
8 best_test_acc = -1
10 for epoch in range(1, 3+1):
—> 12 tr_loss, tr_acc = train(model, device, train_iter, optimizer)
14 print(‘Train Epoch: {} \t Loss: {} \t Accuracy: {}%’.format(epoch, tr_loss, tr_acc))
16 val_loss, val_acc = evaluate(model, device, validation_iter)
TypeError: ‘TabularDataset’ object is not callable
Why I can’t call ‘TabularDataset’ object?
If you can, please check my codes, and give me feedbacks.
Thank you.