I need to translate the following NumPy code to PyTorch, in particular the np.tile part (ideally PyTorch 0.4.0-compatible):
# k.size() = q.size()
batch_size, sent_len, vec_size = q.size() # i.e. [150, 96, 120], also k is same
q_ = np.repeat(q, sent_len, axis=1)
k_ = np.tile(k, (sent_len, 1))
concat_vecs = np.stack([q_, k_], axis=2).reshape(batch_size, sent_len, sent_len, -1)
Any pointers appreciated. Thanks.
This should yield the same result:
import numpy as np
import torch

q = np.random.randn(150, 96, 120)
k = np.random.randn(150, 96, 120)
batch_size, sent_len, vec_size = q.shape # i.e. [150, 96, 120], also k is same
q_ = np.repeat(q, sent_len, axis=1)
k_ = np.tile(k, (sent_len, 1))
concat_vecs = np.stack([q_, k_], axis=2).reshape(batch_size, sent_len, sent_len, -1)
a = torch.from_numpy(q)
b = torch.from_numpy(k)
bs, s, v = a.size()
a_ = a.repeat(1, 1, s).view(bs, s*s, v)  # equivalent to np.repeat(q, s, axis=1)
b_ = b.repeat(1, s, 1)                   # equivalent to np.tile(k, (s, 1))
concat_vec = torch.stack((a_, b_), 2).view(bs, s, s, -1)
(concat_vec == torch.from_numpy(concat_vecs)).all()
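For what it's worth, the same pairwise concatenation can also be written with broadcasting (`unsqueeze` + `expand`), which keeps the two expanded operands as views instead of materializing the repeated copies before the final `cat`. A minimal sketch (the broadcasting variant and its names are mine, not from the answer above):

```python
import torch

bs, s, v = 2, 3, 4
a = torch.randn(bs, s, v)
b = torch.randn(bs, s, v)

# repeat/tile route, as in the answer above
a_ = a.repeat(1, 1, s).view(bs, s * s, v)
b_ = b.repeat(1, s, 1)
ref = torch.stack((a_, b_), 2).view(bs, s, s, -1)

# broadcasting route: unsqueeze + expand return views, so nothing
# is copied until the final cat
a_exp = a.unsqueeze(2).expand(bs, s, s, v)  # pairs a[:, i] with every j
b_exp = b.unsqueeze(1).expand(bs, s, s, v)  # pairs b[:, j] with every i
alt = torch.cat((a_exp, b_exp), dim=3)

print(torch.equal(ref, alt))  # True
```

Both routes produce a `(bs, s, s, 2*v)` tensor whose `[b, i, j]` entry is the concatenation of `a[b, i]` and `b[b, j]`.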
Let me show you the same code using einops:
from einops import repeat, rearrange
q_ = repeat(q, 'b sent channel -> b (sent sent2) channel', sent2=sent_len)
k_ = repeat(k, 'b sent channel -> b (sent2 sent) channel', sent2=sent_len)
concat_vecs = rearrange([q_, k_], 'qk b (sent sent2) channel -> b sent sent2 (qk channel)', sent=sent_len)