I have the following model:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
# Assumed sources for the base classes and filter functions used below:
from timm.models.vision_transformer import VisionTransformer
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p

class Model(nn.Module):
    def __init__(self, encoder, decoder, args):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.args = args

    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
        if not device_ids or len(device_ids) == 1:
            return self(x, **kwargs)
        if output_device is None:
            output_device = device_ids[0]
        replicas = nn.parallel.replicate(self, device_ids)
        inputs = nn.parallel.scatter(x, device_ids)  # slices tensors into approximately equal chunks and distributes them across the given GPUs
        kwargs = nn.parallel.scatter(kwargs, device_ids)  # duplicates references to objects that are not tensors
        replicas = replicas[:len(inputs)]
        kwargs = kwargs[:len(inputs)]
        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
        return nn.parallel.gather(outputs, output_device).mean()

    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
        encoded = self.encoder(x)
        out = self.decoder(tgt_seq, context=encoded, **kwargs)
        return out

    @torch.no_grad()
    def generate(self, x: torch.Tensor, temperature: float = 0.25):
        # Start every sequence with the BOS token, on the same device as the input
        # (hard-coding 'cpu' here breaks as soon as x lives on the GPU).
        start_tokens = torch.LongTensor([self.args.bos_token] * len(x))[:, None].to(x.device)
        return self.decoder.generate(start_tokens, self.args.max_seq_len,
                                     eos_token=self.args.eos_token,
                                     context=self.encoder(x),
                                     temperature=temperature)
class CustomVisionTransformer(VisionTransformer):
    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
        super().__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
        # img_size may be an int or an (H, W) tuple; normalize before unpacking
        # (unconditionally unpacking fails for the int default).
        self.height, self.width = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.patch_size = patch_size

    def forward_features(self, x):
        B, c, h, w = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # Select the positional embeddings of the top-left h x w block of the full
        # training grid, so variable-sized inputs keep consistent 2D positions.
        h, w = h // self.patch_size, w // self.patch_size
        pos_emb_ind = repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()  # index 0 is the CLS position
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
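To make the positional-embedding trick above concrete: the index arithmetic selects the embeddings of the top-left h x w block of the full training grid (consecutive indices within a row, a jump between rows, plus the CLS slot at index 0), so smaller inputs keep consistent 2D positions. A standalone sketch, with an assumed maximum grid width of 32 patches:

import torch
from einops import repeat

max_grid_width = 32   # plays the role of self.width // self.patch_size (assumed value)
h, w = 8, 8           # patch grid of the current input image

# Row r of the input maps to row r of the full grid: within a row the indices are
# consecutive, between rows they jump by (max_grid_width - w).
pos_emb_ind = repeat(torch.arange(h) * (max_grid_width - w), 'h -> (h w)', w=w) + torch.arange(h * w)
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind[:10])  # tensor([0, 1, ..., 8, 33]): CLS, first row, start of second row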
class CustomARWrapper(AutoregressiveWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)
        if num_dims == 1:
            start_tokens = start_tokens[None, :]
        b, t = start_tokens.shape
        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=device)
        for _ in range(seq_len):
            # Feed at most max_seq_len tokens of context per step.
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]
            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
            if filter_logits_fn in {top_k, top_p}:
                logits = filter_logits_fn(logits, thres=filter_thres)
            # (The original only defined the filtered logits inside the if-branch and
            # then used them unconditionally: a NameError for any other filter function.)
            probs = F.softmax(logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)
            # Stop once every sequence in the batch has emitted the EOS token.
            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
                break
        out = out[:, t:]  # strip the prompt, return only newly generated tokens
        if num_dims == 1:
            out = out.squeeze(0)
        self.net.train(was_training)
        return out
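For completeness, this is roughly how I drive generation (a sketch only; the model sizes and token ids below are made-up placeholders, assuming an x_transformers decoder with cross-attention):

import torch
from x_transformers import TransformerWrapper, Decoder

net = TransformerWrapper(
    num_tokens=8000, max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=4, heads=8, cross_attend=True))
decoder = CustomARWrapper(net)

bos_token, eos_token = 1, 2                # placeholder ids
start = torch.full((1, 1), bos_token, dtype=torch.long)
context = torch.randn(1, 197, 256)         # stand-in for the encoder output
tokens = decoder.generate(start, seq_len=64, eos_token=eos_token, context=context)
print(tokens.shape)                        # (1, <=64): generated ids, prompt stripped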
The inputs and outputs are:
- Input of the encoder: an image.
- Output of the encoder: a sequence of patch features with positional embeddings, e.g. a tensor of shape (1, seq_len, 256).
- Input of the decoder: the encoder output, used as cross-attention context.
- Output of the decoder: token ids, which are mapped back to text with the tokenizer.

I'm totally new to ONNX, so please help me figure out how to convert this model. Thanks!
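For reference, here is my rough, untested attempt so far, based on the torch.onnx.export docs: I export the encoder and the decoder as two separate graphs, because the sampling loop in generate is Python control flow that a single trace can't capture. All shapes, file names, and the opset below are placeholders.

import torch

model.eval()

# 1) Encoder: image in, feature sequence out. Note that forward_features computes
# the positional-embedding indices from Python ints taken off x.shape, so a traced
# graph is specialized to this image size; one export per input size (or padding)
# seems necessary.
dummy_image = torch.randn(1, 3, 224, 224)  # placeholder (B, C, H, W); match your encoder
torch.onnx.export(
    model.encoder, dummy_image, 'encoder.onnx',
    input_names=['image'], output_names=['context'],
    dynamic_axes={'image': {0: 'batch'}, 'context': {0: 'batch'}},
    opset_version=14)

# 2) Decoder: export a single step of the underlying net (token ids so far plus
# encoder context in, logits out); the sampling loop from generate() then stays
# in Python and calls this graph once per token, e.g. through onnxruntime.
dummy_tokens = torch.ones(1, 1, dtype=torch.long)  # placeholder: just a BOS token
dummy_context = torch.randn(1, 197, 256)           # placeholder: matches encoder output
torch.onnx.export(
    model.decoder.net, (dummy_tokens, {'context': dummy_context}), 'decoder.onnx',
    input_names=['tokens', 'context'], output_names=['logits'],
    dynamic_axes={'tokens': {1: 'seq'}, 'context': {1: 'ctx_seq'}, 'logits': {1: 'seq'}},
    opset_version=14)

Does that look like the right direction, or is there a way to export the whole generate loop in one graph?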