Tip: using keras-compatible tensor dot product and broadcasting ops

I find theano’s dot() and its broadcasting of basic operators very convenient (ditto for keras, which is designed to be fully compatible with the theano API for these functions). They save a lot of unsqueeze()ing and expand_as()ing and make life a lot easier IMO. They also make it easier to port code from theano and keras to pytorch.

In summary, dot() handles tensor arguments of pretty much any size and does a dot product of the last axis of the first argument with the 2nd-last axis of the 2nd argument. For +, -, *, and / it broadcasts any missing leading axes or unit axes as needed to make the arguments compatible. In case anyone is interested, here is a pytorch version of theano’s dot() and broadcasted operators - hope some folks find it useful!
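To make that concrete, here is a small shape sketch (my own example, using the dot() and add() defined below):

b, mt = torch.randn(4, 3, 2), torch.randn(2, 3)
dot(b, mt).size()    # torch.Size([4, 3, 3]): last axis of b contracted with 2nd-last axis of mt

m, v = torch.randn(3, 2), torch.randn(2)
add(m, v).size()     # torch.Size([3, 2]): v is broadcast as a (1, 2) row across m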

import operator
import torch

# Prepend unit axes so x and y have the same number of dims, then expand
# size-1 axes to match (for every axis except the last start_dim axes).
def align(x, y, start_dim=2):
    xd, yd = x.dim(), y.dim()
    if xd > yd:
        for i in range(xd - yd): y = y.unsqueeze(0)
    elif yd > xd:
        for i in range(yd - xd): x = x.unsqueeze(0)
    xs = list(x.size())
    ys = list(y.size())
    nd = len(ys)
    for i in range(start_dim, nd):
        td = nd-i-1
        if   ys[td]==1: ys[td] = xs[td]
        elif xs[td]==1: xs[td] = ys[td]
    return x.expand(*xs), y.expand(*ys)

# theano-style dot(): contracts the last axis of x with the 2nd-last axis of y,
# broadcasting any leading (batch) axes via align().
def dot(x, y):
    x, y = align(x, y)
    assert(1<y.dim()<5)
    if y.dim() == 2:
        return x.mm(y)
    elif y.dim() == 3: 
        return x.bmm(y)
    else:
        xs,ys = x.size(), y.size()
        res = torch.zeros(*(xs[:-1] + (ys[-1],)))
        for i in range(xs[0]): res[i] = x[i].bmm(y[i])
        return res

# Broadcast x and y across every axis (start_dim=0), then apply the elementwise op f.
def aligned_op(x, y, f):
    x, y = align(x,y,0)
    return f(x, y)

def add(x, y): return aligned_op(x, y, operator.add)
def sub(x, y): return aligned_op(x, y, operator.sub)
def mul(x, y): return aligned_op(x, y, operator.mul)
def div(x, y): return aligned_op(x, y, operator.truediv)

And here are some tests / examples:

def Arr(*sz): return torch.randn(sz)

m = Arr(3, 2)
v = Arr(2)
b = Arr(4,3,2)
t = Arr(5,4,3,2)

mt = m.transpose(0,1)
bt = b.transpose(1,2)
tt = t.transpose(2,3)

def check_eq(x,y): assert(torch.equal(x,y))

check_eq(dot(m,mt),m.mm(mt))
check_eq(dot(v,mt), v.unsqueeze(0).mm(mt))
check_eq(dot(b,bt),b.bmm(bt))
check_eq(dot(b,mt),b.bmm(mt.unsqueeze(0).expand_as(bt)))

exp = t.view(-1,3,2).bmm(tt.contiguous().view(-1,2,3)).view(5,4,3,3)
check_eq(dot(t,tt),exp)

check_eq(add(m,v),m+v.unsqueeze(0).expand_as(m))
check_eq(add(v,m),m+v.unsqueeze(0).expand_as(m))
check_eq(add(m,t),t+m.unsqueeze(0).unsqueeze(0).expand_as(t))
check_eq(sub(m,v),m-v.unsqueeze(0).expand_as(m))
check_eq(mul(m,v),m*v.unsqueeze(0).expand_as(m))
check_eq(div(m,v),m/v.unsqueeze(0).expand_as(m))

I’ve made some minor changes to this code to make it a bit faster - here’s the updated version (tests from above will still work fine):

def unit_prefix(x, n=1):
    for i in range(n): x = x.unsqueeze(0)
    return x

def align(x, y, start_dim=2):
    xd, yd = x.dim(), y.dim()
    if xd > yd: y = unit_prefix(y, xd - yd)
    elif yd > xd: x = unit_prefix(x, yd - xd)

    xs, ys = list(x.size()), list(y.size())
    nd = len(ys)
    for i in range(start_dim, nd):
        td = nd-i-1
        if   ys[td]==1: ys[td] = xs[td]
        elif xs[td]==1: xs[td] = ys[td]
    return x.expand(*xs), y.expand(*ys)

def dot(x, y):
    assert(1<y.dim()<5)
    x, y = align(x, y)
    
    if y.dim() == 2: return x.mm(y)
    elif y.dim() == 3: return x.bmm(y)
    else:
        xs,ys = x.size(), y.size()
        res = torch.zeros(*(xs[:-1] + (ys[-1],)))
        for i in range(xs[0]): res[i].baddbmm_(x[i], y[i])  # accumulate x[i] @ y[i] into the zero-initialized slice in place
        return res
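
As a quick sanity check of the in-place update (my own addition, using the Arr and check_eq helpers from the tests above): with a zero-initialized destination, baddbmm_ adds the batched product straight into the result slice, so it matches the bmm-and-copy approach of the first version.

a, c = Arr(4, 3, 2), Arr(4, 2, 3)
out = torch.zeros(4, 3, 3)
out.baddbmm_(a, c)          # out = 1*out + a.bmm(c), and out started as zeros
check_eq(out, a.bmm(c))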

Why not merge it into pytorch?

The latest version of pytorch now supports proper broadcasting, so you don’t have to use my hacky version any more 🙂
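
For anyone landing here later, the native equivalents look roughly like this (a sketch reusing the test tensors and helpers defined above; +, -, *, / and torch.matmul / the @ operator broadcast these cases directly):

check_eq(m + v, add(m, v))                               # elementwise ops broadcast the missing leading axis
check_eq(m + t, add(m, t))                               # ...and prepend unit axes for higher-dim arguments
assert torch.allclose(torch.matmul(b, mt), dot(b, mt))   # matmul broadcasts batch dims, like dot()
assert torch.allclose(t @ tt, dot(t, tt))                # 4-D batched case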
