Hi,
I have reduced the batch size. The error I am getting now is:
File "/home/raghad/anaconda3/envs/monet/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/raghad/anaconda3/envs/monet/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/raghad/Documents/MoNet/image/main.py", line 90, in <module>
device)
File "/home/raghad/Documents/MoNet/image/train_eval.py", line 18, in run
train_loss = train(model, train_loader, optimizer, device)
File "/home/raghad/Documents/MoNet/image/train_eval.py", line 42, in train
loss = F.nll_loss(model(data), data.y)
File "/home/raghad/anaconda3/envs/monet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/raghad/Documents/MoNet/image/main.py", line 61, in forward
data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
File "/home/raghad/anaconda3/envs/monet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/raghad/Documents/MoNet/conv/gmm_conv.py", line 47, in forward
out = self.propagate(edge_index, x=out, pseudo=pseudo)
File "/home/raghad/anaconda3/envs/monet/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 237, in propagate
out = self.message(**msg_kwargs)
File "/home/raghad/Documents/MoNet/conv/gmm_conv.py", line 61, in message
return (x_j * gaussian).sum(dim=1)
RuntimeError: The size of tensor a (22065) must match the size of tensor b (25) at non-singleton dimension 1
And my code (main.py) is:
import time
import os.path as osp
import argparse
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.utils import normalized_cut
from torch_geometric.nn import (graclus, max_pool, global_mean_pool)
from conv import GMMConv
from image import run
parser = argparse.ArgumentParser(description='superpixel MNIST')
parser.add_argument('--dataset', default='MNIST', type=str)
parser.add_argument('--device_idx', default=3, type=int)
parser.add_argument('--kernel_size', default=25, type=int)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lr_decay', type=float, default=0.99)
parser.add_argument('--decay_step', type=int, default=1)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--seed', type=int, default=1)
args = parser.parse_args()
args.data_fp = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
args.dataset)
#device = torch.device('cuda', args.device_idx)
device = torch.device('cuda', 0)
# deterministic
torch.manual_seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True
train_dataset = MNISTSuperpixels(args.data_fp, True, pre_transform=T.Polar())
test_dataset = MNISTSuperpixels(args.data_fp, False, pre_transform=T.Polar())
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)
def normalized_cut_2d(edge_index, pos):
row, col = edge_index
edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
class MoNet(torch.nn.Module):
def __init__(self, kernel_size):
super(MoNet, self).__init__()
self.conv1 = GMMConv(1, 32, dim=2, kernel_size=kernel_size)
self.conv2 = GMMConv(32, 64, dim=2, kernel_size=kernel_size)
self.conv3 = GMMConv(64, 64, dim=2, kernel_size=kernel_size)
self.fc1 = torch.nn.Linear(64, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, data):
data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
weight = normalized_cut_2d(data.edge_index, data.pos)
cluster = graclus(data.edge_index, weight, data.x.size(0))
data.edge_attr = None
data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
weight = normalized_cut_2d(data.edge_index, data.pos)
cluster = graclus(data.edge_index, weight, data.x.size(0))
data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
x = global_mean_pool(data.x, data.batch)
x = F.elu(self.fc1(x))
x = F.dropout(x, training=self.training)
return F.log_softmax(self.fc2(x), dim=1)
model = MoNet(args.kernel_size).to(device)
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
args.decay_step,
gamma=args.lr_decay)
print(model)
run(model, args.epochs, train_loader, test_loader, optimizer, scheduler,
device)
code for ‘train_eval.py’ is:
import time
import torch
import torch.nn.functional as F
def print_info(info):
message = ('Epoch: {}/{}, Duration: {:.3f}s, Train Loss: {:.4f}, '
'Test Loss: {:.4f}, Test Acc: {:.4f}').format(
info['current_epoch'], info['epochs'], info['t_duration'],
info['train_loss'], info['test_loss'], info['acc'])
print(message)
def run(model, epochs, train_loader, test_loader, optimizer, scheduler,
device):
for epoch in range(1, epochs + 1):
t = time.time()
train_loss = train(model, train_loader, optimizer, device)
t_duration = time.time() - t
scheduler.step()
acc, test_loss = test(model, test_loader, device)
info = {
'train_loss': train_loss,
'test_loss': test_loss,
'acc': acc,
'current_epoch': epoch,
'epochs': epochs,
't_duration': t_duration
}
print_info(info)
def train(model, train_loader, optimizer, device):
model.train()
total_loss = 0
for data in train_loader:
optimizer.zero_grad()
data = data.to(device)
loss = F.nll_loss(model(data), data.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
def test(model, test_loader, device):
model.eval()
correct = 0
total_loss = 0
with torch.no_grad():
for idx, data in enumerate(test_loader):
data = data.to(device)
out = model(data)
total_loss += F.nll_loss(out, data.y).item()
pred = out.max(1)[1]
correct += pred.eq(data.y).sum().item()
return correct / len(test_loader.dataset), total_loss / len(test_loader)
And code for gmm_conv.py is:
import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from .inits import reset, glorot, zeros
EPS = 1e-15
class GMMConv(MessagePassing):
def __init__(self,
in_channels,
out_channels,
dim,
kernel_size,
bias=True,
**kwargs):
super(GMMConv, self).__init__(aggr='add', **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.dim = dim
self.kernel_size = kernel_size
self.lin = torch.nn.Linear(in_channels,
out_channels * kernel_size,
bias=False)
self.mu = Parameter(torch.Tensor(kernel_size, dim))
self.sigma = Parameter(torch.Tensor(kernel_size, dim))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
glorot(self.mu)
glorot(self.sigma)
zeros(self.bias)
reset(self.lin)
def forward(self, x, edge_index, pseudo):
x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
out = self.lin(x).view(-1, self.kernel_size, self.out_channels)
out = self.propagate(edge_index, x=out, pseudo=pseudo)
print(out.size())
if self.bias is not None:
out = out + self.bias
return out
def message(self, x_j, pseudo):
(E, D), K = pseudo.size(), self.mu.size(0)
gaussian = -0.5 * (pseudo.view(E, 1, D) - self.mu.view(1, K, D))**2
gaussian = gaussian / (EPS + self.sigma.view(1, K, D)**2)
gaussian = torch.exp(gaussian.sum(dim=-1, keepdim=True)) # [E, K, 1]
return (x_j * gaussian).sum(dim=1)
def __repr__(self):
return '{}({}, {}, kernel_size={})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels,
self.kernel_size)