Hey
I am trying to run a simple enc/dec architecture, consisting of
- a frozen encoder
- a simple feed-forward on the outputs of (1)
- a trainable decoder on the outputs of (2)
- a simple feed-forward on the outputs of (3)
The fw/bw pipeline works fine on the first batch, yet breaks on the next backwards() call.
My train_batch() function returns nothing, and no intermediate results are stored.
Help would be much appreciated.
def __init__(self, num_embeddings: int, dim: int = 768, device: str = 'cpu'):
super(Parser, self).__init__()
self.dim = dim
self.num_embeddings = num_embeddings
self.device = device
self.encoder = RobertaModel.from_pretrained('pdelobelle/robBERT-base')
self.encoder_projection = FFN(d_model=dim, d_ff=2 * dim).to(device)
self.decoder = make_decoder(num_layers=2, num_heads=6, d_model=self.dim, d_k=self.dim, d_v=self.dim,
d_intermediate=self.dim, dropout=0.1).to(device)
self.embedding_matrix = Parameter(torch.rand(num_embeddings + 1, dim, device=device) * 0.02,
requires_grad=True) * 0.02
self.atom_embedder = lambda x: functional.embedding(x, self.embedding_matrix, padding_idx=0,
scale_grad_by_freq=True)
self.predictor = lambda x: x @ (self.embedding_matrix.transpose(1, 0) + 1e-10)
self.pos_encoder = PositionalEncoder(0.1)
def forward(self, lexical_token_ids: LongTensor, symbol_ids: LongTensor, pos_idxes: List[List[LongTensor]],
neg_idxes: List[List[LongTensor]]) -> Any:
pass
def encode(self, lexical_token_ids: LongTensor, encoder_mask: LongTensor) -> Tensor:
with torch.no_grad():
encoder_output, _ = self.encoder(lexical_token_ids, attention_mask=encoder_mask)
encoder_output = encoder_output.to(self.device)
return self.encoder_projection(encoder_output)
def make_output_repr(self, lexical_token_ids: LongTensor, symbol_ids: LongTensor) -> Tensor:
b, s_in = lexical_token_ids.shape
s_out = symbol_ids.shape[1]
encoder_mask = self.make_encoder_mask(lexical_token_ids)
encoder_output = self.encode(lexical_token_ids, encoder_mask)
decoder_mask = self.make_decoder_mask(b=b, n=symbol_ids.shape[1])
atom_embeddings = self.atom_embedder(symbol_ids.to(self.device))
pos_encodings = self.pos_encoder(b, s_out, self.dim, 1024).to(self.device)
atom_embeddings = atom_embeddings + pos_encodings
extended_encoder_mask = encoder_mask.view(b, 1, s_in).repeat(1, s_out, 1).to(self.device)
return self.decoder((encoder_output, extended_encoder_mask, atom_embeddings, decoder_mask))[2]
def predict(self, symbol_reprs: Tensor) -> Tensor:
return self.predictor(symbol_reprs)
@staticmethod
def make_mask(inps: LongTensor, padding_id: int) -> LongTensor:
mask = torch.ones_like(inps)
mask[inps == padding_id] = 0
return mask
def make_encoder_mask(self, lexical_ids: LongTensor) -> LongTensor:
return self.make_mask(lexical_ids, 1)
def make_decoder_mask(self, b: int, n: int) -> LongTensor:
return torch.triu(torch.ones(b, n, n), diagonal=1).to(self.device)
def train_batch(self, samples: List[Sample], atom_map: Mapping[str, int], tokenizer: Tokenizer,
optimizer: Optimizer, max_difficulty: int = 20):
words, types, pos_idxes, neg_idxes = samples_to_batch(samples, atom_map, tokenizer)
output_reprs = self.make_output_repr(words, types)
# supertagging
type_predictions = self.predict(output_reprs)
type_predictions = type_predictions[:, :-1].permute(0, 2, 1).reshape(-1, self.num_embeddings + 1)
types = types[:, 1:].to(self.device).reshape(-1)
supertagging_loss = functional.cross_entropy(type_predictions, types)
supertagging_loss.backward()
print('Backpropped.')
optimizer.step()
optimizer.zero_grad()
return
def train_epoch(self, dataloader: DataLoader, atom_map: Mapping[str, int], tokenizer: Tokenizer,
optimizer: Optimizer):
for i, samples in enumerate(dataloader):
self.train_batch(samples, atom_map, tokenizer, optimizer)```