Need help with numpy to PyTorch conversion, with np.tile

I need to translate the following numpy code to PyTorch, in particular the np.tile part (ideally, PyTorch 0.4.0 compatible):


import numpy as np
# q and k are numpy arrays of the same shape, e.g. (150, 96, 120)
batch_size, sent_len, vec_size = q.shape
q_ = np.repeat(q, sent_len, axis=1)
k_ = np.tile(k, (sent_len, 1))
concat_vecs = np.stack([q_, k_], axis=2).reshape(batch_size, sent_len, sent_len, -1)

Any pointers appreciated. Thanks.
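To make clear what the two numpy calls do to the sentence axis, here is a small sketch on toy shapes (batch 1, 3 positions, 2 features; the sizes and values are arbitrary and only for illustration). One detail that is easy to miss: for a 3-D input, `np.tile` pads `reps=(sent_len, 1)` with a leading 1 to `(1, sent_len, 1)`, so it tiles along axis 1, not axis 0.

```python
import numpy as np

# Toy sizes chosen only for illustration.
batch_size, sent_len, vec_size = 1, 3, 2
q = np.arange(batch_size * sent_len * vec_size).reshape(batch_size, sent_len, vec_size)
k = q + 100  # distinct values so rows of q and k are easy to tell apart

# np.repeat: each row of q appears sent_len times in a row -> q0 q0 q0 q1 q1 q1 q2 q2 q2
q_ = np.repeat(q, sent_len, axis=1)

# np.tile: reps (sent_len, 1) is padded to (1, sent_len, 1) for a 3-D array,
# so the whole sequence is repeated along axis 1 -> k0 k1 k2 k0 k1 k2 k0 k1 k2
k_ = np.tile(k, (sent_len, 1))

print(q_.shape, k_.shape)  # (1, 9, 2) (1, 9, 2)

concat_vecs = np.stack([q_, k_], axis=2).reshape(batch_size, sent_len, sent_len, -1)
# concat_vecs[b, i, j] is q[b, i] followed by k[b, j]
print(concat_vecs[0, 1, 2])  # [  2   3 104 105]
```

So the target is simply every (i, j) pair of a q row and a k row, concatenated along the feature axis.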

This should yield the same result:

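import numpy as np
import torch

q = np.random.randn(150, 96, 120)
k = np.random.randn(150, 96, 120)
batch_size, sent_len, vec_size = q.shape  # (150, 96, 120); k has the same shape
q_ = np.repeat(q, sent_len, axis=1)
k_ = np.tile(k, (sent_len, 1))
concat_vecs = np.stack([q_, k_], axis=2).reshape(batch_size, sent_len, sent_len, -1)


a = torch.from_numpy(q)
b = torch.from_numpy(k)
bs, s, v = a.size()
# Equivalent of np.repeat(q, s, axis=1): copy each row s times along the last axis,
# then view the result as s*s rows, so row i*s + j holds a[:, i]
a_ = a.repeat(1, 1, s).view(bs, s*s, v)
# Equivalent of np.tile(k, (s, 1)): repeat the whole sequence s times,
# so row i*s + j holds b[:, j]
b_ = b.repeat(1, s, 1)
concat_vec = torch.stack((a_, b_), 2).view(bs, s, s, -1)

print((concat_vec == torch.from_numpy(concat_vecs)).all())  # every element matches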
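A variant of the same idea, offered as an alternative sketch rather than a replacement: instead of materializing repeated copies with `repeat`, insert a singleton axis with `unsqueeze`, broadcast it with `expand` (which returns a view without copying), and concatenate along the feature axis with `torch.cat`. These calls should all be available in 0.4.0. The sizes below are small toy values, and the reference result is rebuilt from the numpy code in the question so the two can be compared.

```python
import numpy as np
import torch

# Toy sizes for a quick check; the question uses (150, 96, 120).
bs, s, v = 2, 4, 3
q = np.random.randn(bs, s, v)
k = np.random.randn(bs, s, v)

# Reference result from the numpy code in the question.
q_np = np.repeat(q, s, axis=1)
k_np = np.tile(k, (s, 1))
concat_np = np.stack([q_np, k_np], axis=2).reshape(bs, s, s, -1)

a = torch.from_numpy(q)
b = torch.from_numpy(k)

# a_[n, i, j] = a[n, i]  (plays the role of np.repeat)
a_ = a.unsqueeze(2).expand(bs, s, s, v)
# b_[n, i, j] = b[n, j]  (plays the role of np.tile)
b_ = b.unsqueeze(1).expand(bs, s, s, v)

# expand only creates views; torch.cat allocates the final tensor once.
concat_vec = torch.cat((a_, b_), dim=3)

print(concat_vec.shape)  # torch.Size([2, 4, 4, 6])
print(torch.equal(concat_vec, torch.from_numpy(concat_np)))  # True
```

This skips the intermediate `(bs, s*s, v)` tensors and the final `view`. On newer PyTorch releases (not 0.4.0), `torch.repeat_interleave(a, s, dim=1)` is the direct counterpart of `np.repeat(q, s, axis=1)`.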

Let me show you the same code using einops:

from einops import repeat, rearrange

q_ = repeat(q, 'b sent channel -> b (sent sent2) channel', sent2=sent_len)
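k_ = repeat(k, 'b sent channel -> b (sent2 sent) channel', sent2=sent_len)
# split the flattened (sent sent2) axis back into two axes: (batch, sent, sent, 2*channel)
concat_vecs = rearrange([q_, k_], 'qk b (sent sent2) channel -> b sent sent2 (qk channel)', sent2=sent_len)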