Hi all,
I’m a beginner with PyTorch.
I have been struggling to extract features from inside the forward function.
import torch
import torch.nn as nn
class CrystalGraphConvNet(nn.Module):
    """Graph-convolution regressor whose output layer is ``fc_out`` (1 output unit).

    The forward pass is split in two so the intermediate representation
    ``crys_fea`` (the tensor fed to ``fc_out``) can be read directly:

    * :meth:`extract_features` returns ``crys_fea``;
    * :meth:`forward` returns ``fc_out(crys_fea)``, exactly as before.

    Example with a fitted skorch net::

        net.module_.eval()
        with torch.no_grad():
            crys_fea = net.module_.extract_features(
                atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
    """

    def __init__(self, orig_atom_fea_len, atom_fea_len, h_fea_len):
        # NOTE(review): the original signature was elided ("..."); these are
        # the constructor names the body reads. Confirm against the real file.
        super().__init__()
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        # NOTE(review): ConvLayer is defined elsewhere and its constructor
        # arguments were not shown in the snippet -- fill them in.
        self.convs = nn.ModuleList([ConvLayer()])
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        self.fc_out = nn.Linear(h_fea_len, 1)

    def extract_features(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """Return ``crys_fea``, the hidden features passed to the output layer.

        Args:
            atom_fea: atom feature tensor, embedded by ``self.embedding``.
            nbr_fea: neighbour features, forwarded to every conv layer.
            nbr_fea_idx: neighbour indices, forwarded to every conv layer.
            crystal_atom_idx: forwarded to ``self.pooling`` to group atoms.

        Returns:
            The output of ``conv_to_fc`` (last dimension ``h_fea_len``).
        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        # NOTE(review): self.pooling is not defined in this snippet; presumably
        # it aggregates atom features per crystal using crystal_atom_idx.
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        return self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """Return ``fc_out`` applied to :meth:`extract_features` (last dim 1)."""
        crys_fea = self.extract_features(
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx
        )
        return self.fc_out(crys_fea)
from skorch import NeuralNetRegressor
from model import CrystalGraphConvNet
import torch

net = NeuralNetRegressor(
    CrystalGraphConvNet,
    # TODO: pass CrystalGraphConvNet's constructor arguments here as
    # module__<name>=value (elided as "modules..." in the original snippet).
)
net.initialize()
net.fit(SDT_training, target_training)

# Capture crys_fea -- the output of conv_to_fc, i.e. the tensor forward() feeds
# into fc_out -- with a forward hook, so the model class needs no changes.
crys_fea_batches = []


def _save_crys_fea(module, inputs, output):
    """Store one batch of conv_to_fc output."""
    # detach() so stored tensors do not keep the autograd graph alive
    crys_fea_batches.append(output.detach().cpu())


hook_handle = net.module_.conv_to_fc.register_forward_hook(_save_crys_fea)
try:
    # predict() runs forward on every batch; the hook records each batch.
    net.predict(SDT_training)
finally:
    hook_handle.remove()  # stop capturing on any later forward passes

# Batches are concatenated in iteration order, which matches SDT_training as
# long as the prediction iterator does not shuffle (skorch's default).
crys_fea = torch.cat(crys_fea_batches)
I want to extract the values of intermediate features computed in the forward function, such as `crys_fea`.
Please help me out. Thank you!