I find theano’s dot() and broadcasting of basic operators very convenient (ditto for keras, which is designed to be fully compatible with the theano API for these functions). It saves a lot of unsqueeze()ing and expand_as()ing and makes life a lot easier IMO. It also makes it easier to port code from theano and keras to pytorch.
In summary, dot() handles tensor arguments of nearly any size, contracting the last axis of the first argument with the second-to-last axis of the second argument. For +, -, *, and / it broadcasts any missing leading axes or unit axes as needed to make the arguments compatible. In case anyone is interested, here is a pytorch version of theano’s dot() and broadcasted operators - hope some folks find it useful!
def align(x, y, start_dim=2):
    """Broadcast *x* and *y* to a common shape and return the expanded pair.

    First, leading singleton axes are prepended to whichever tensor has
    fewer dimensions until both have the same rank.  Then every axis
    except the trailing ``start_dim`` axes is broadcast: wherever one
    side has size 1, it is expanded to the other side's size.

    With the default ``start_dim=2`` the last two (matrix) axes are left
    untouched, which is what batched matrix products need; with
    ``start_dim=0`` every axis is broadcast, which is what element-wise
    operators need.
    """
    # Prepend unit axes so both tensors have the same rank.
    while y.dim() < x.dim():
        y = y.unsqueeze(0)
    while x.dim() < y.dim():
        x = x.unsqueeze(0)
    xs, ys = list(x.size()), list(y.size())
    # Broadcast size-1 axes against each other, skipping the final
    # `start_dim` axes (equivalent to the original reverse-indexed loop).
    for axis in range(len(xs) - start_dim):
        if ys[axis] == 1:
            ys[axis] = xs[axis]
        elif xs[axis] == 1:
            xs[axis] = ys[axis]
    return x.expand(*xs), y.expand(*ys)
def dot(x, y):
    """Theano-style generalized dot product.

    Contracts the last axis of ``x`` with the second-to-last axis of
    ``y`` after aligning the two tensors with :func:`align` (leading /
    batch axes are broadcast; the trailing two matrix axes are not).

    Supports 2-D (mm), 3-D (bmm) and 4-D (looped bmm) arguments after
    alignment; other ranks trip the assert below.
    """
    x, y = align(x, y)
    assert 1 < y.dim() < 5
    if y.dim() == 2:
        return x.mm(y)
    if y.dim() == 3:
        return x.bmm(y)
    # 4-D case: batch-matmul each slice along the leading axis.
    xs, ys = x.size(), y.size()
    # Bug fix: allocate the result with the same dtype/device as the
    # inputs.  The original `torch.zeros(...)` always produced a float32
    # CPU tensor, silently casting e.g. float64 inputs (and failing for
    # tensors on other devices).
    res = x.new_zeros(*xs[:-1], ys[-1])
    for i in range(xs[0]):
        res[i] = x[i].bmm(y[i])
    return res
def aligned_op(x, y, f):
    """Broadcast *x* and *y* on every axis (start_dim=0), then apply *f*.

    Helper shared by the element-wise operator wrappers below.
    """
    a, b = align(x, y, 0)
    return f(a, b)
def add(x, y):
    """Element-wise addition with theano-style broadcasting."""
    return aligned_op(x, y, operator.add)


def sub(x, y):
    """Element-wise subtraction with theano-style broadcasting."""
    return aligned_op(x, y, operator.sub)


def mul(x, y):
    """Element-wise multiplication with theano-style broadcasting."""
    return aligned_op(x, y, operator.mul)


def div(x, y):
    """Element-wise true division with theano-style broadcasting."""
    return aligned_op(x, y, operator.truediv)
And here are some tests / examples:
# --- tests / examples ----------------------------------------------------
# check_eq uses torch.equal (exact, element-wise); this works because each
# expected value is built from the same floating-point operations the
# broadcasting helpers perform internally.
def Arr(*sz): return torch.randn(sz)  # convenience: random tensor of the given shape
m = Arr(3, 2)    # matrix
v = Arr(2)       # vector
b = Arr(4,3,2)   # batch of matrices
t = Arr(5,4,3,2) # 4-D: batch of batches of matrices
mt = m.transpose(0,1)  # (2, 3)
bt = b.transpose(1,2)  # (4, 2, 3)
tt = t.transpose(2,3)  # (5, 4, 2, 3)
def check_eq(x,y): assert(torch.equal(x,y))  # exact element-wise equality
# dot(): 2-D x 2-D dispatches to mm()
check_eq(dot(m,mt),m.mm(mt))
# a vector is promoted to a 1-row matrix before mm()
check_eq(dot(v,mt), v.unsqueeze(0).mm(mt))
# 3-D x 3-D dispatches to bmm()
check_eq(dot(b,bt),b.bmm(bt))
# the 2-D operand is broadcast across the batch axis of the 3-D one
check_eq(dot(b,mt),b.bmm(mt.unsqueeze(0).expand_as(bt)))
# 4-D case: equivalent to collapsing the two leading axes into one batch
exp = t.view(-1,3,2).bmm(tt.contiguous().view(-1,2,3)).view(5,4,3,3)
check_eq(dot(t,tt),exp)
# broadcast element-wise ops, in both argument orders
check_eq(add(m,v),m+v.unsqueeze(0).expand_as(m))
check_eq(add(v,m),m+v.unsqueeze(0).expand_as(m))
check_eq(add(m,t),t+m.unsqueeze(0).unsqueeze(0).expand_as(t))
check_eq(sub(m,v),m-v.unsqueeze(0).expand_as(m))
check_eq(mul(m,v),m*v.unsqueeze(0).expand_as(m))
check_eq(div(m,v),m/v.unsqueeze(0).expand_as(m))