@tom Thank you so much. I haven't fully understood your answer yet. The link is helpful, but I don't know how to apply it to for-loops, i.e. iterating over all possible task_id values as input. Could you please provide an example?
Below is a code snippet of the Model class. It contains ModuleDict containers whose keys are task_id values, and the forward function takes the task_id as an argument.
Is this a form of control flow, similar to if-else, since ModuleDict internally uses for-loops?
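To make the concern concrete, here is a toy sketch of what I think happens. It is not the real model: TinyMultiTask and its heads attribute are made-up names, and the only thing it shares with my code is the pattern of converting a tensor to a string key with .item() and indexing a ModuleDict. As far as I understand, torch.jit.trace only records the tensor operations that run for the example inputs, so the ModuleDict lookup is resolved once at trace time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyMultiTask(nn.Module):
    """Toy stand-in (hypothetical): one linear head per task_id."""

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleDict({"0": nn.Linear(4, 2), "1": nn.Linear(4, 2)})

    def forward(self, x, channel_tensor):
        # Same pattern as in my forward(): tensor -> Python int -> string key.
        channel = str(channel_tensor.item())
        return self.heads[channel](x)


model = TinyMultiTask().eval()
x = torch.randn(3, 4)
ch0 = torch.zeros(1, dtype=torch.int64)
ch1 = torch.ones(1, dtype=torch.int64)

# Trace with task_id 1 only (PyTorch emits a TracerWarning for .item()).
traced = torch.jit.trace(model, (x, ch1))

# The traced module returns the same result whatever channel is passed in.
same_for_both = torch.equal(traced(x, ch0), traced(x, ch1))
# For task_id 0 the traced output no longer agrees with the eager model.
matches_eager_ch0 = torch.allclose(traced(x, ch0), model(x, ch0))
print(same_for_both, matches_eager_ch0)
```

If this toy behaves like the real model, a trace made with a single channel value would silently always use that channel's attention, projection and classification modules.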
So, when tracing the model, should we trace ALL task_id values, or is tracing only one task_id (the variable channel in the code) enough, as shown below?
channel = torch.ones(1, dtype=torch.int64)
traced_script_module = torch.jit.trace(model, (premise, premise_length, hypotheses, hypotheses_length, channel))
output = traced_script_module(premise, premise_length, hypotheses, hypotheses_length, channel)
traced_script_module.save('deploy-trace-multitask.pt')
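If one trace per task_id turns out to be necessary, this is roughly how I imagine the for-loop would look. It is only a sketch on a toy model: TinyMultiTask, FixedChannel, task_ids and the file name pattern in the comment are all hypothetical, and I am assuming that pinning the channel inside a small wrapper before tracing is acceptable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyMultiTask(nn.Module):
    """Toy stand-in (hypothetical): one linear head per task_id."""

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleDict({"0": nn.Linear(4, 2), "1": nn.Linear(4, 2)})

    def forward(self, x, channel_tensor):
        channel = str(channel_tensor.item())
        return self.heads[channel](x)


class FixedChannel(nn.Module):
    """Hypothetical wrapper that pins one task_id before tracing."""

    def __init__(self, model, channel_id):
        super().__init__()
        self.model = model
        self.register_buffer(
            "channel", torch.tensor([channel_id], dtype=torch.int64)
        )

    def forward(self, x):
        return self.model(x, self.channel)


model = TinyMultiTask().eval()
x = torch.randn(3, 4)
task_ids = [0, 1]  # assumed list of all task_ids

# One traced module per task_id, each with its own frozen branch.
traced_by_task = {}
for channel_id in task_ids:
    wrapper = FixedChannel(model, channel_id).eval()
    traced_by_task[channel_id] = torch.jit.trace(wrapper, (x,))
    # Each one could then be saved separately, for example:
    # traced_by_task[channel_id].save(f"deploy-trace-task{channel_id}.pt")
```

The drawback I can see is that every saved file would carry its own copy of the shared embedding and encoder weights, which is why I am asking whether there is a better way.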
Code snippet of the Model class definition:
        self._word_embedding = nn.Embedding(self.vocab_size,
                                            self.embedding_dim,
                                            padding_idx=padding_idx,
                                            _weight=embeddings)
        if self.dropout:
            self._rnn_dropout = RNNDropout(p=self.dropout)  # shared by all tasks
            # self._rnn_dropout = nn.Dropout(p=self.dropout)
        self._encoding = Seq2SeqEncoder(nn.LSTM,
                                        self.embedding_dim,
                                        self.hidden_size,
                                        bidirectional=True)
        # multi-task
        self._attention = nn.ModuleDict({})
        self._projection = nn.ModuleDict({})
        self._classification = nn.ModuleDict({})
        for channel in channels_list:
            self.update(channel)
        # Initialize all weights and biases in the model.
        self.apply(_init_esim_weights)
    def update(self, channel):
        channel = str(channel)
        self._attention.update({channel: SoftmaxAttention()})
        self._projection.update({channel: nn.Sequential(nn.Linear(4*2*self.hidden_size, self.hidden_size), nn.ReLU())})
        self._classification.update({channel: nn.Sequential(nn.Dropout(p=self.dropout),
                                                            nn.Linear(4*self.hidden_size,
                                                                      self.hidden_size),
                                                            nn.Tanh(),
                                                            nn.Dropout(p=self.dropout),
                                                            nn.Linear(self.hidden_size,
                                                                      self.num_classes))})
    def forward(self,
                premises,
                premises_lengths,
                hypotheses,
                hypotheses_lengths,
                channel_tensor):  # must be a tensor
        """
        Args:
            premises: A batch of variable length sequences of word indices
                representing premises. The batch is assumed to be of size
                (batch, premises_length).
            premises_lengths: A 1D tensor containing the lengths of the
                premises in 'premises'.
            hypotheses: A batch of variable length sequences of word indices
                representing hypotheses. The batch is assumed to be of size
                (batch, hypotheses_length).
            hypotheses_lengths: A 1D tensor containing the lengths of the
                hypotheses in 'hypotheses'.
            channel_tensor: A tensor holding a single integer task_id that
                selects the task-specific modules.

        Returns:
            logits: A tensor of size (batch, num_classes) containing the
                logits for each output class of the model.
            probabilities: A tensor of size (batch, num_classes) containing
                the probabilities of each output class in the model.
        """
        channel_id = channel_tensor.item()
        channel = str(channel_id)
        premises_mask = get_mask(premises, premises_lengths).to(self.device)
        hypotheses_mask = get_mask(hypotheses, hypotheses_lengths)\
            .to(self.device)
        embedded_premises = self._word_embedding(premises)
        embedded_hypotheses = self._word_embedding(hypotheses)
        if self.dropout:
            embedded_premises = self._rnn_dropout(embedded_premises)
            embedded_hypotheses = self._rnn_dropout(embedded_hypotheses)
        encoded_premises = self._encoding(embedded_premises,
                                          premises_lengths)
        encoded_hypotheses = self._encoding(embedded_hypotheses,
                                            hypotheses_lengths)
        attended_premises, attended_hypotheses =\
            self._attention[channel](encoded_premises, premises_mask,
                                     encoded_hypotheses, hypotheses_mask)
        # ... the rest of the code is omitted ...