Hello,
I’ve seen some posts about this topic, but my lack of knowledge of Lua is making it difficult for me to convert a segment of code to work with PyTorch. I have also never worked with PyTorch. Honestly, this whole task is entirely uncharted waters for me, so answers can be literally anything anyone thinks might be helpful.
I have pasted the relevant portion of the code below. I am trying to pull the CNN encoder out of a model from a paper so that I can use it with another model. The paper uses a pretrained Google word2vec model (GoogleNews-vectors-negative300), which, as you can see, is loaded into a LookupTable. This segment of code is drawn from oposum/hierMIL.lua at master · stangelid/oposum · GitHub, but I only need the sentenceEncoder part. I’m an undergraduate researcher and feel a bit overwhelmed, so any help would be great. Thank you.
require 'nn'
require 'nngraph'
require 'dpnn'
require 'rnn'
require 'model.Unsqueeze_nc'
require 'model.MixtureTableVarLen'
-- Register a torch class named 'ModelBuilder' and keep it in a file-local
-- variable; the methods defined below are attached to this class.
local ModelBuilder = torch.class('ModelBuilder')
--- Constructor: remember the embedding weights handed in by the caller.
-- @param w2v value stored as `self.w2v`; getSentenceEncoder later copies it
--   into the lookup table's weight when opt.model_type is not 'rand'.
function ModelBuilder:__init(w2v)
  self.w2v = w2v
end
--- Build the sentence encoder: an embedding lookup feeding three convolution
-- branches (kernel heights 3, 4 and 5) whose outputs are joined, replicated
-- inside an nn.Parallel container.
--
-- Reads the global `opt` table: `opt.cudnn` (1 selects the cudnn layers,
-- anything else the plain nn layers), `opt.model_type` ('rand' selects random
-- embedding initialisation, anything else copies `self.w2v`) and `opt.bn`
-- (1 inserts batch normalisation after each convolution).
--
-- @param num_replicas optional number of encoder copies to place in the
--   returned container. When omitted, the global `some_number` is used, so
--   existing zero-argument calls behave as before.
-- @return an nn.Parallel(2, 2) holding the encoder plus its shared clones.
function ModelBuilder:getSentenceEncoder(num_replicas)
  local use_cudnn = (opt.cudnn == 1)
  if use_cudnn then
    require 'cudnn'
    require 'cunn'
  end

  -- Resolve and validate the replica count before allocating the (very large)
  -- lookup table, so a missing value fails fast with a readable message
  -- instead of "'for' limit must be a number" at the end of the function.
  if num_replicas == nil then
    num_replicas = some_number
  end
  num_replicas = tonumber(num_replicas)
  assert(num_replicas ~= nil,
    'getSentenceEncoder: the number of encoder replicas must be a number ' ..
    '(pass it as an argument or define the global some_number)')

  -- Embedding table: 3,000,000 rows of 300 values.
  local lookup = nn.LookupTable(3000000, 300)
  if opt.model_type ~= 'rand' then
    lookup.weight:copy(self.w2v)
  else
    lookup.weight:uniform(-0.25, 0.25)
  end
  -- Row 1 is always zeroed; presumably index 1 is the padding token
  -- (TODO confirm against the data pipeline).
  lookup.weight[1]:zero()

  local kernels = {3, 4, 5}
  local kconcat = nn.ConcatTable()
  for i = 1, #kernels do
    -- Pick the convolution/activation implementation. Previously the whole
    -- branch was built only when opt.cudnn == 1, which left kconcat empty on
    -- the non-cudnn path; the nn modules take the same constructor arguments.
    local conv, relu
    if use_cudnn then
      conv = cudnn.SpatialConvolution(1, 100, 300, kernels[i])
      relu = cudnn.ReLU(true)
    else
      conv = nn.SpatialConvolution(1, 100, 300, kernels[i])
      relu = nn.ReLU(true)
    end
    conv.weight:uniform(-0.01, 0.01)
    conv.bias:zero()

    local single_conv = nn.Sequential()
    single_conv:add(conv)
    if opt.bn == 1 then
      single_conv:add(nn.SpatialBatchNormalization(100))
    end
    single_conv:add(nn.Squeeze(3, 3))
    single_conv:add(relu)
    single_conv:add(nn.Max(2, 2))
    single_conv:add(nn.Unsqueeze(1, 1))
    kconcat:add(single_conv)
  end

  -- Lookup -> add a dimension -> run all kernel branches -> join their outputs.
  local sent_conv = nn.Sequential()
  sent_conv:add(lookup)
  sent_conv:add(nn.Unsqueeze(1, 2))
  sent_conv:add(kconcat)
  sent_conv:add(nn.JoinTable(3))

  -- The first member is the encoder itself; members 2..num_replicas are
  -- clones made with sharedClone (from the dpnn package required by this
  -- file), which is expected to share parameters with the original —
  -- confirm in the dpnn documentation.
  local par = nn.Parallel(2, 2)
  par:add(sent_conv)
  for _ = 2, num_replicas do
    par:add(sent_conv:sharedClone())
  end
  return par
end