Conv2D: Any way to do "red-black"/"checkerboard" ordering?

Is there any way to get Conv2D to do a so-called “red-black” or “checkerboard” ordering? That is, a stride of 2 along rows and a stride of 1 along columns, except that each time you go to a new row you shift by 1. This would, for example, hit all the black squares of a checkerboard, and then you could repeat it with a slight shift to hit all the red squares:

checkerboard image

Why? I’m converting an old Fortran elliptic solver code that uses the red-black Gauss-Seidel method to PyTorch, and despite all my attempts at writing things in a “pythonic”/“vectorized” format, my PyTorch-without-Conv2d version is still a lot slower than the Fortran, even with GPU enabled (which the Fortran code can’t do).

Note that, as I described above, this is not just a stride of 2 everywhere. Doing it with a stride of 2 everywhere would require four conv passes instead of two. Alternatively, copying storage (so that every point is updated from old values only) would result in the Jacobi method, which converges more slowly than Gauss-Seidel.
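To make the ordering concrete, here is a minimal sketch that builds the two colour masks from index parity. The helper name `checkerboard_masks` is made up for illustration and is not a PyTorch API. It also shows that each colour is the union of two uniform stride-2 slicings, which is why a plain stride of 2 would need four passes to cover both colours:

```python
import torch

def checkerboard_masks(rows, cols):
    """Boolean masks for the two colours; 'red' is where (i + j) is even."""
    i = torch.arange(rows).unsqueeze(1)   # column vector of row indices
    j = torch.arange(cols).unsqueeze(0)   # row vector of column indices
    red = (i + j) % 2 == 0
    return red, ~red

red, black = checkerboard_masks(4, 6)

# One colour is the union of two uniform stride-2 slicings ...
red_from_slices = torch.zeros(4, 6, dtype=torch.bool)
red_from_slices[0::2, 0::2] = True
red_from_slices[1::2, 1::2] = True

# ... and the other colour is the union of the remaining two.
black_from_slices = torch.zeros(4, 6, dtype=torch.bool)
black_from_slices[0::2, 1::2] = True
black_from_slices[1::2, 0::2] = True

print(red.int())
```

Which colour gets called “red” is arbitrary; only the alternation matters.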

My current scheme involves getting a flattened view of the 2D tensor, and using the “red-black indices” in 1d (with some dilations to cut across rows) to drive the vectorized operations:

import torch

def get_cb_interior_indices(u, start=0):
    "utility func. gets indices of 'red'/'black' checkerboard (cb) values of 2d array"
    indices = torch.arange(u.shape[-2]*u.shape[-1], dtype=int).reshape(u.shape)
    interior = indices[1:-1,1:-1]
    jstride = u.shape[-1]  # row length: stepping to the next row changes the flattened index by this much
    return interior.flatten()[start::2], jstride

u = torch.arange(49, dtype=torch.float).reshape(7,7)
print(u)
idx, jstride = get_cb_interior_indices(u)
print("idx =",idx,", jstride =",jstride)
idx, jstride = get_cb_interior_indices(u, start=1)
print("idx =",idx,", jstride =",jstride)

h = 1.0/(u.shape[0] - 1)
hm2, m4hm2 = 1/(h*h), -4/(h*2)
for rb_pass in range(2):   # red-black gauss seidel
    idx, jstride = get_cb_interior_indices(u, start=rb_pass)
    ufl = u.view(u.shape[-2]*u.shape[-1])
    resid_gs = (1/h/h)*( ufl[idx+1] + ufl[idx-1] + ufl[idx+jstride] + ufl[idx-jstride] - 4*ufl[idx])  + ufl[idx]**2 
    ufl[idx] += -resid_gs / ( m4hm2 )  # newton step, note that this is in-place i.e. it overwrites 

print(u)

Which yields an output of

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
        [ 7.,  8.,  9., 10., 11., 12., 13.],
        [14., 15., 16., 17., 18., 19., 20.],
        [21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34.],
        [35., 36., 37., 38., 39., 40., 41.],
        [42., 43., 44., 45., 46., 47., 48.]])
idx = tensor([ 8, 10, 12, 16, 18, 22, 24, 26, 30, 32, 36, 38, 40]) , jstride = 7
idx = tensor([ 9, 11, 15, 17, 19, 23, 25, 29, 31, 33, 37, 39]) , jstride = 7
tensor([[0.0000e+00, 1.0000e+00, 2.0000e+00, 3.0000e+00, 4.0000e+00, 5.0000e+00,
         6.0000e+00],
        [7.0000e+00, 1.3333e+01, 1.2075e+02, 1.8333e+01, 1.6308e+02, 2.4000e+01,
         1.3000e+01],
        [1.4000e+01, 2.3475e+02, 3.7333e+01, 3.5508e+02, 4.5000e+01, 3.3508e+02,
         2.0000e+01],
        [2.1000e+01, 6.2333e+01, 6.2108e+02, 7.2000e+01, 7.2708e+02, 8.2333e+01,
         2.7000e+01],
        [2.8000e+01, 7.6908e+02, 1.0500e+02, 1.0971e+03, 1.1733e+02, 9.4875e+02,
         3.4000e+01],
        [3.5000e+01, 1.4400e+02, 1.0611e+03, 1.5833e+02, 1.1828e+03, 1.7333e+02,
         4.1000e+01],
        [4.2000e+01, 4.3000e+01, 4.4000e+01, 4.5000e+01, 4.6000e+01, 4.7000e+01,
         4.8000e+01]])

…Which is fine and it’s not super-slow, but I’d prefer to implement it by using a Conv2D since that’s probably optimized like crazy – and a Conv2D layer could be made to have trainable weights.
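One property of the flattened-index scheme above is worth checking before reusing it on other grid sizes: taking every second element of the flattened interior only gives a checkerboard when the interior has an odd number of columns (as in the 7x7 example, whose interior is 5 wide). The sketch below tests this; the helper names are hypothetical and only mirror the indexing used above:

```python
import torch

def flat_interior_indices(rows, cols, start=0):
    """Every second flattened index of the interior, as in the scheme above."""
    idx = torch.arange(rows * cols).reshape(rows, cols)
    return idx[1:-1, 1:-1].flatten()[start::2]

def is_single_colour(rows, cols, start=0):
    """True if all selected points share the same (i + j) parity."""
    idx = flat_interior_indices(rows, cols, start)
    i = torch.div(idx, cols, rounding_mode="floor")  # row of each flat index
    j = idx % cols                                   # column of each flat index
    parity = (i + j) % 2
    return bool((parity == parity[0]).all())

print(is_single_colour(7, 7))  # True: interior is 5 wide (odd)
print(is_single_colour(6, 6))  # False: interior is 4 wide (even), giving column stripes
```

With an even interior width the selection degenerates into vertical stripes, so such grids would need a different index construction (for example, one based on `(i + j) % 2`).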

I anticipate some replies that do not address the question directly, and instead try to offer completely different solutions, so I will respond here to what might be a few common replies:

  • “Could you just call the Fortran from Python?” No, ultimately I want to allow the convolution kernel to be trainable.
  • “Why are you trying to do this?” I think I’ve described in pretty heavy detail the motivation behind my question, but feel free to read an old paper of mine about it. An alternative response could be: because I want to.
  • “Could you use a Conv1D, similar to the flattened view you implemented?” Yea, maybe. I’m concerned about how to enforce the “sparsity” with respect to the j-strides: I don’t want a massive kernel that’s thousands of elements long containing mostly zeros. If sparse 1D convs are supported well, this might be the answer. I’d love to hear more about them if you can point me to any resources!
  • “You have an error or inefficiency in the above code. Let’s rewrite your (raw) Python code in this way…” …Yea thanks, but again, I hope to use the machinery of PyTorch layers so the weights can be trainable.
  • “Python will never be as fast as Fortran,…” Yea ok, fine. Again: Trainable weights.
  • “If you make the weights trainable you will destroy [various symmetry properties]…” Yes. I still want to try this anyway.
  • “Elliptic solves and image convolutions are related but should not be confused with each other…” …Yea, well, I’m interested in trying to use the machinery of PyTorch anyway.

Thanks in advance for your assistance!

One thing that you could try is to run a Conv2D with a stride of 1, making sure that the (odd-sized) kernel has a checkerboard sparsity pattern. You need two convolutions, one for the red squares and one for the black squares. After you have run the convolutions, you would have to merge the two results to re-form the checkerboard.

This approach does twice the work of a true “red-black” stride, since you are throwing away half of the results of each convolution, but considering how efficient the Conv2D implementation is, you may well end up ahead.

The last part of the puzzle would be to make sure that the kernel sparsity pattern is preserved during training. The easiest way, I suppose, would be to multiply the kernel weights by the sparsity pattern prior to convolving. Sort of a deterministic dropout.
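A minimal sketch of that last idea, assuming a single-channel layer. The class name `MaskedConv2d` and the plus-shaped example pattern are illustrative choices, not something from PyTorch or from the posts above. The pattern is stored as a buffer and multiplied into the weights on every forward pass, so masked entries never receive gradient:

```python
import torch
import torch.nn.functional as F

class MaskedConv2d(torch.nn.Module):
    """Single-channel conv whose kernel is multiplied by a fixed 0/1 pattern."""

    def __init__(self, pattern):
        super().__init__()
        pattern = pattern.to(torch.float32)
        # A buffer moves with .to(device) but is not a trainable parameter.
        self.register_buffer("pattern", pattern[None, None])
        self.weight = torch.nn.Parameter(torch.randn(1, 1, *pattern.shape))

    def forward(self, x):
        # Masking on every call keeps the zeros at zero, whatever the optimizer does.
        kernel = self.weight * self.pattern
        return F.conv2d(x, kernel, padding=self.pattern.shape[-1] // 2)

five_point = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
layer = MaskedConv2d(five_point)
out = layer(torch.ones(1, 1, 5, 5))
out.sum().backward()
print(layer.weight.grad[0, 0])  # zero wherever the pattern is zero
```

Any other 0/1 pattern of odd size can be passed in the same way.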

Thank you @GCBallesteros for your reply! Yes, I think I could do this. To make sure I’m understanding you: for example, given that my current kernel is 3x3, I could perhaps create a set of four 4x4 stencils that are shifted a bit…, all implemented with a stride of 2 and initialized according to:

Red 1:                             Red 2:
| 0 | 1 | 0 | 0 |                  | 0 | 0 | 0 | 0 |
| 1 |-4 | 1 | 0 |                  | 0 | 0 | 1 | 0 |
| 0 | 1 | 0 | 0 |                  | 0 | 1 |-4 | 1 |
| 0 | 0 | 0 | 0 |                  | 0 | 0 | 1 | 0 |

Black 1:                           Black 2:
| 0 | 0 | 1 | 0 |                  | 0 | 0 | 0 | 0 |
| 0 | 1 |-4 | 1 |                  | 0 | 1 | 0 | 0 |
| 0 | 0 | 1 | 0 |                  | 1 |-4 | 1 | 0 |
| 0 | 0 | 0 | 0 |                  | 0 | 1 | 0 | 0 |

…Though as you say, reassembling the results could get tricky, but it’s at least doable. Although: it would be really nice if I could “share” the weights somehow. Part of my reason for trying to implement multigrid in PyTorch is to reduce the parameter count a bit. Let’s see if any other suggestions come in…
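On sharing the weights: one possible sketch, assuming the four 4x4 stencils are meant to be shifted copies of a single 3x3 kernel, is to keep one 3x3 parameter and build the stencils from it with `F.pad` on each forward pass. The names `shared` and `shifted_stencils` are hypothetical. Because padding is differentiable, all four stencils send their gradients back to the same nine numbers:

```python
import torch
import torch.nn.functional as F

# The only trainable tensor: one 3x3 kernel shared by all four stencils.
shared = torch.nn.Parameter(
    torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]]))

def shifted_stencils(kernel):
    """Embed one 3x3 kernel into four 4x4 stencils.

    F.pad takes (left, right, top, bottom) for the last two dimensions.
    """
    red1 = F.pad(kernel, (0, 1, 0, 1))    # kernel in the top-left corner
    red2 = F.pad(kernel, (1, 0, 1, 0))    # kernel in the bottom-right corner
    black1 = F.pad(kernel, (1, 0, 0, 1))  # kernel in the top-right corner
    black2 = F.pad(kernel, (0, 1, 1, 0))  # kernel in the bottom-left corner
    return torch.stack([red1, red2, black1, black2]).unsqueeze(1)  # (4, 1, 4, 4)

weight = shifted_stencils(shared)
x = torch.rand(1, 1, 8, 8)
out = F.conv2d(x, weight, stride=2)   # four output channels, one per stencil
out.sum().backward()
print(out.shape, shared.grad.shape)
```

Reassembling the four output channels into one grid is left out here; the sketch only shows the parameter sharing.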

An alternative approach, following a similar line of thought, is to bring the sparsity pattern to the image instead of to the kernel.

# Assume you have a binary red_mask and a black_mask. If the input is fixed sized this could
# be precomputed. We also have 2 stride one Conv2D named red_conv and black_conv

output = red_conv(red_mask * input) * red_mask + black_conv(black_mask * input) * black_mask

Yet another option is to use indexing to extract the red checkerboard and the black checkerboard. Then you can run a stride-1 Conv2D on each of the split images, which are half the size and all red or all black. Then you re-form the checkerboard. This is even better because it doesn’t waste any computation and the sparsity pattern is implicitly taken care of. The other answer in this post may still end up being faster, but only benchmarking can tell.

Ohhhh yea. This (the first one) is probably the thing to try first! I like it very much. Thank you. I’ll report back later on a bit of benchmarking.

The second idea I could do, but I’d need to insert some “dummy” values because the alternating rows would not contain the same number of elements. … oh, or I guess I could have two “images”, one 2d array for even-numbered rows of red points, and another smaller 2d array for odd-numbered rows of red points,…

My pleasure :blush:. Let us know how it went

The second idea I could do, but I’d need to insert some “dummy” values because the alternating rows would not contain the same number of elements

Indeed, I hadn’t thought about the case where you may have an odd number of rows/columns. My first thought here would be to add an extra row/col of padding, then perform the convolutions → reshape → discard padding.

Oh wait…, so one problem with this is that my conv kernel needs to be a 5-point stencil: filters=torch.tensor([[0,1,0],[1,-4,1],[0,1,0]]).unsqueeze(0).unsqueeze(0)
The middle point needs a -4 on it, but multiplying the input by the mask means I get a zero on the points I’m actually trying to update.

Perhaps I can use a uniform kernel of 1’s, and then make the mask include the -4 for all the other-colored (red/black points) that would otherwise be zeros.

…whew, no I’ll need two sets of masks for each red/black pass: one with 1’s and 0’s, and another with alternating 1’s, -4’s in the interior, and alternating 1’s, and 0’s along the edges. :-/ I’m doing it, but…wow, this is a fair amount of work!

I think I was misunderstanding the requirements. I thought the red/black parts of the computation were independent of each other, but really what is needed is to move a first-neighbours stencil in a red/black strided fashion. When the stencil is centered on a black square, it’s actually also using numbers from the red squares and vice versa, right?

EDIT

One option is:

Use an unfold to generate all the patches. The output has shape (batch=1?, n_channels(1?) x prod_i{kernel_size_i}, patch_index). The patches are arranged linearly along the last index of the output, so you can take a stride of two over the last dimension and apply the stencil (which would have to be arranged linearly). Finally, sum over the kernel dimension and reshape back to the image (this is the sum that a convolution does implicitly). This is conceptually the same as doing the convolutions but less efficient than Conv2D.
In this case you would have to pad your images so that they always have an odd number of rows and columns, so that the aforementioned stride 2 over the linear arrangement of patches always goes red/black/red/black…
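As a quick check of the unfold description, the following sketch (variable names are illustrative) unfolds a small single-channel image, multiplies each patch by the flattened kernel, sums over the kernel dimension, and compares the result with `F.conv2d`. The image is given 7 columns so that every second patch lands on the same colour:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
u = torch.rand(1, 1, 5, 7)                       # (batch, channels, rows, cols)
kernel = torch.tensor([[0., -1., 0.], [-1., 4., -1.], [0., -1., 0.]])

# One column per output pixel, one row per kernel element.
patches = F.unfold(u, kernel_size=3, padding=1)  # shape (1, 9, 5 * 7)

# Multiply by the flattened kernel and sum over the kernel dimension.
weighted = patches * kernel.flatten()[None, :, None]
via_unfold = weighted.sum(dim=1).reshape(u.shape)

# The same operation done by the built-in convolution.
via_conv = F.conv2d(u, kernel[None, None], padding=1)
print(torch.allclose(via_unfold, via_conv, atol=1e-6))

# With an odd number of columns, every second patch lies on one colour.
red_only = via_unfold.flatten()[::2]
```

Here the kernel is applied to every patch; restricting the multiplication to `patches[:, :, ::2]` gives the red-only variant described above.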

I also noticed that

output = red_conv(red_mask * input) * red_mask + black_conv(black_mask * input) * black_mask

can be modified to achieve what you want. For clarity, assume that red_mask is ones on the red squares and zeros elsewhere, and vice versa for black_mask. Then the convolutions would be:

output_red_squares = red_conv(black_mask * input) * red_mask + 4 * red_mask * input
output_black_squares = black_conv(red_mask * input) * black_mask + 4 * black_mask * input
output = output_red_squares + output_black_squares

The idea is that the convolutions account for the neighbours’ contributions to the stencil, and the last addition accounts for the central pixel (which is zeroed in the convolutional kernel).

edit: missed a last multiplication times the input
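The decomposition can be checked numerically. The sketch below assumes both `red_conv` and `black_conv` use the same neighbours-only kernel (centre set to zero) and that each result is kept only on the squares of the colour being computed; under those assumptions the sum of the two halves matches a single convolution with the full 5-point kernel:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
u = torch.rand(1, 1, 7, 7)

i = torch.arange(7).unsqueeze(1)
j = torch.arange(7).unsqueeze(0)
red_mask = ((i + j) % 2 == 0).to(u.dtype)   # ones on red squares
black_mask = 1 - red_mask                   # ones on black squares

# Neighbours-only kernel (zero centre) and the full 5-point kernel.
neighbours = torch.tensor([[0., -1., 0.], [-1., 0., -1.], [0., -1., 0.]])[None, None]
full = torch.tensor([[0., -1., 0.], [-1., 4., -1.], [0., -1., 0.]])[None, None]

# Red outputs: neighbours come from black squares, centre term added separately.
red_out = F.conv2d(black_mask * u, neighbours, padding=1) * red_mask + 4 * red_mask * u
black_out = F.conv2d(red_mask * u, neighbours, padding=1) * black_mask + 4 * black_mask * u
split = red_out + black_out

reference = F.conv2d(u, full, padding=1)
print(torch.allclose(split, reference, atol=1e-6))
```

Note that both halves here read the same input, which corresponds to a Jacobi-style sweep. For Gauss-Seidel, the black half would be computed from the grid after the red update has been applied.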


Thank you for your help again. That’s a worthy modification. I can check that out.

BTW: When I posted this thread, I was just hoping for something easy like a stagger=True kwarg that can be passed into Conv2D! haha :rofl: But I understand this is not a common use case for computer vision.

Currently I’ve implemented a rather lengthy routine with 4 masks, which I’ll share below,… and the result is that the new method (method=‘conv2d’) is no faster than my previous (method=‘fast’) routine! Whether on CPU, CUDA, or MPS, for multiple grid sizes…they take the same time. (I ran %%timeit) That is a bit disappointing. USER error: I was actually calling the same method each time!

The new ‘conv2d’ method is about 3 times faster when using CUDA on large grids. :slight_smile:

I’ll share my code below (but with the understanding that it may turn you or other readers off, given its…verbosity!). The methods I’m comparing are my old one, "fast", vs the new method, "conv2d":

Currently these weights are not trainable as I’m just using F.conv2d instead of a Conv2D layer. …I’m proceeding by “baby steps”:

## Note there's a lot more code than the discuss.pytorch.org web system is showing you; 
# it gets cut off but is scrollable...

#|export
import torch
import torch.nn.functional as F

def get_cb_indices(u, start=0):
    "utility func. gets indices of 'red'/'black' checkerboard (cb) values of 2d array"
    indices = torch.arange(u.shape[-2]*u.shape[-1], dtype=int).reshape(u.shape)
    interior = indices[1:-1,1:-1]
    jstride = u.shape[-1]  # row length: stepping to the next row changes the flattened index by this much
    return interior.flatten()[start::2], jstride

def set_alternating_edges(conv_mask, start=1):
    """Alternate 1's and zeros along edges:  
    Note however that for Dirichlet BC's with u=0, solution is 0 along edges anyway
    So using this or not may/should have no effect on solution."""
    conv_mask[...,start:-1:2,0] = 1  
    conv_mask[...,start:-1:2,-1] = 1
    conv_mask[...,0,start:-1:2] = 1     
    conv_mask[...,-1,start:-1:2] = 1
    return conv_mask


def conv_pass(u, sigma, f, hm2, m4hm2, filters, mask, conv_mask, debug=False):
    "perform one convolution pass (red or black)"
    inputs = u.unsqueeze(0).unsqueeze(0) if len(u.shape) < 4 else u 
    if debug: print("inputs = \n",inputs)
    resid = (  hm2*F.conv2d( conv_mask*inputs, filters, padding=1).squeeze()  + sigma * u**2 - f  )* mask 
    correction =  resid /  ( m4hm2 + 2.0 * sigma * u )
    u -=  correction* mask # newton step
    return u, (resid**2).sum()
    
    

def smooth_error(uin, h, f, sigma, method='conv2d', debug=False, 
                 red_mask=None, red_conv_mask=None, black_mask=None, black_conv_mask=None, filters=None):
    "smoothes error via red-black gauss-seidel. old school without pytorch"
    u = uin.clone() # unnecessary but kept just for repeatability
    #print("u.shape = ",u.shape, u.dtype) 
    resid_norm = 0
    hm2   = 1.0/(h*h)
    m4hm2 = -4.0 * hm2
    if method=='slow':  # slow but sure
        for rb_pass in range(2):  # red-black gauss seidel
            for j in range(1, u.shape[-1]-1):
                ibump = (rb_pass + j) % 2    # alternates 0 and 1
                for i in range(1+ibump, u.shape[-2]-1, 2): 
                    resid_gs = hm2*( u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] - 4*u[i,j]) \
                        + sigma * u[i,j]**2 - f[i,j]
                    dres_duij = m4hm2 + 2.0 * sigma * u[i,j]
                    correction = - resid_gs / dres_duij
                    #print("i,j, resid_gs, dres_duij, correction =",i,j, resid_gs, dres_duij, correction)
                    u[i,j] = u[i,j] + correction   
                    resid_norm += resid_gs**2
    elif method=='medium':        # vectorized across j but not i; still written for readability
        for rb_pass in range(2):  # red-black gauss seidel
            for i in range(1, u.shape[-2]-1):       # hit all values of i, j's will skip every other via slicing
                jstart = 1 + (1+rb_pass + i) % 2    # alternates 1 and 2; initialized to agree with slow method
                uij   = u[...,i, jstart:-1:2]
                fij   = f[...,i, jstart:-1:2]
                uip1j = u[...,i+1, jstart:-1:2]
                uim1j = u[...,i-1, jstart:-1:2]
                uijp1 = u[...,i,   jstart+1::2]
                uijm1 = u[...,i,   jstart-1:-2:2]
                resid_gs = hm2*( uip1j + uim1j + uijp1 + uijm1 - 4*uij )  + (sigma * uij**2) - fij
                dres_duij = m4hm2 + 2.0 * sigma * uij
                uij  +=   - resid_gs / dres_duij  # newton step 
                resid_norm += (resid_gs**2).sum()
    elif method=='fast':   # vectorized across i and j
        ufl, ffl = u.view(u.shape[-2]*u.shape[-1]), f.view(f.shape[-2]*f.shape[-1])
        resid_norm = 0
        for rb_pass in range(2):   # red-black gauss seidel
            idx, js = get_cb_indices(u, start=rb_pass)
            resid_gs = hm2*( ufl[idx+1] + ufl[idx-1] + ufl[idx+js] + ufl[idx-js] - 4*ufl[idx])  + (sigma * ufl[idx]**2) - ffl[idx]
            if debug: print(f"rb_pass = {rb_pass}, resid_gs =\n",resid_gs)

            ufl[idx] -= resid_gs / ( m4hm2 + 2.0 * sigma * ufl[idx] )  # newton step 
            if debug: print(f"after rb_pass = {rb_pass}, u =\n",u)
            resid_norm += (resid_gs**2).sum()
    elif method=='conv2d':
        assert red_mask is not None,"must pass in a mask now"
        if debug:
            print("red_mask =\n",red_mask)
            print("black_mask =\n",black_mask)
            print("red_conv_mask =\n",red_conv_mask)
            print("black_conv_mask =\n",black_conv_mask)

        u, resid_norm = conv_pass(u, sigma, f, hm2, m4hm2, filters, red_mask, red_conv_mask, debug=debug)
        u, black_resid_norm = conv_pass(u, sigma, f, hm2, m4hm2, filters, black_mask, black_conv_mask, debug=debug)
        resid_norm += black_resid_norm
        
        if debug: print(f"after rb_pass = {1}, u =\n",u)
    else:
        # fail early instead of falling through to torch.sqrt on an integer
        raise ValueError(f"invalid method = {method}")
        
    resid_norm = torch.sqrt( resid_norm / ((u.shape[-2]-2)*(u.shape[-1]-2))  ).cpu().numpy()
    if debug: print("end: u = \n",u)

    return u, resid_norm
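The post does not show the value of `filters` that was passed to `conv_pass`, so the following is only a sketch under an explicit assumption: a plus-shaped kernel of ones, combined with a conv mask that puts -4 on the points being updated and 1 on everything else (a simplification of the masks described in this thread; it relies on the boundary values of u being zero). Under that assumption, the masked convolution reproduces the 5-point Laplacian on the red interior points:

```python
import torch
import torch.nn.functional as F

n = 7
torch.manual_seed(0)
u = torch.zeros(n, n)
u[1:-1, 1:-1] = torch.rand(n - 2, n - 2)           # zero Dirichlet boundary

i = torch.arange(n).unsqueeze(1)
j = torch.arange(n).unsqueeze(0)
interior = torch.zeros(n, n)
interior[1:-1, 1:-1] = 1.0
red_mask = ((i + j) % 2 == 0).float() * interior   # red interior points only

# -4 on the red points being updated, 1 everywhere else.
red_conv_mask = 1.0 - 5.0 * red_mask

# ASSUMPTION: a plus-shaped kernel of ones; the thread does not show `filters`.
plus = torch.tensor([[0., 1., 0.], [1., 1., 1.], [0., 1., 0.]])[None, None]
lap_conv = F.conv2d((red_conv_mask * u)[None, None], plus, padding=1)[0, 0] * red_mask

# Reference: the 5-point stencil written with plain slicing.
lap_ref = torch.zeros(n, n)
lap_ref[1:-1, 1:-1] = (u[2:, 1:-1] + u[:-2, 1:-1] + u[1:-1, 2:] + u[1:-1, :-2]
                       - 4 * u[1:-1, 1:-1])
lap_ref = lap_ref * red_mask

print(torch.allclose(lap_conv, lap_ref, atol=1e-5))
```

The plus shape matters: a full 3x3 kernel of ones would also pick up the diagonal neighbours, which have the same colour as the centre and therefore carry the -4 factor.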

Here’s a bit of testing for a 7x7 run. (Note that a few variables are undefined in this, such as sigma (=0) and “f”, but these could be anything.)

# the pytorch conv2d way

red_mask = torch.zeros(u.shape, device=u.device, dtype=int)
red_idx, jstride = get_cb_indices(u, start=0)
red_mask.view(-1)[red_idx] = 1 
black_mask = torch.zeros(u.shape, device=u.device, dtype=int)
black_idx, jstride = get_cb_indices(u, start=1)
black_mask.view(-1)[black_idx] = 1

red_conv_mask = red_mask.clone()
red_conv_mask.view(-1)[red_idx] = -4
red_conv_mask.view(-1)[black_idx] = 1
red_conv_mask = set_alternating_edges(red_conv_mask, start=1)

black_conv_mask = black_mask.clone()
black_conv_mask.view(-1)[black_idx] = -4
black_conv_mask.view(-1)[red_idx] = 1
black_conv_mask = set_alternating_edges(black_conv_mask, start=2)


unew, resnorm = smooth_error(utest.clone(), hx, rhs, sigma, method='conv2d', 
        red_mask=red_mask, black_mask=black_mask, red_conv_mask=red_conv_mask, black_conv_mask=black_conv_mask,
        debug=True)
print(resnorm, unew)
unew, resnorm = smooth_error(unew, hx, rhs, sigma, method='conv2d', 
        red_mask=red_mask, black_mask=black_mask, red_conv_mask=red_conv_mask, black_conv_mask=black_conv_mask,
        debug=True)
print(resnorm, unew)

Output is:

red_mask =
 tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0]])
black_mask =
 tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])
red_conv_mask =
 tensor([[ 0,  1,  0,  1,  0,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1,  0,  1,  0,  1,  0]])
black_conv_mask =
 tensor([[ 0,  0,  1,  0,  1,  0,  0],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 0,  0,  1,  0,  1,  0,  0]])
inputs = 
 tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.9288, 0.7520, 0.1657, 0.4513, 0.0875, 0.0000],
          [0.0000, 0.4457, 0.2441, 0.8293, 0.7338, 0.7791, 0.0000],
          [0.0000, 0.9396, 0.8786, 0.0616, 0.7343, 0.8295, 0.0000],
          [0.0000, 0.8389, 0.8395, 0.8926, 0.4192, 0.9531, 0.0000],
          [0.0000, 0.1511, 0.1693, 0.5495, 0.3309, 0.5770, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask,  red resid =
 tensor([[ -0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000],
        [ -0.0000, -85.6999,  -0.0000,  59.1877,  -0.0000,  36.6385,  -0.0000],
        [  0.0000,  -0.0000,  84.2550,  -0.0000,   9.7277,  -0.0000,   0.0000],
        [ -0.0000, -47.5618,  -0.0000, 130.9272,  -0.0000, -20.7869,  -0.0000],
        [  0.0000,  -0.0000,  -6.0300,  -0.0000,  59.2300,  -0.0000,   0.0000],
        [ -0.0000,  19.4751,  -0.0000, -19.1107,  -0.0000, -31.9364,  -0.0000],
        [  0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000]])
after rb_pass = 0, u =
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3337, 0.7520, 0.5767, 0.4513, 0.3419, 0.0000],
        [0.0000, 0.4457, 0.8292, 0.8293, 0.8013, 0.7791, 0.0000],
        [0.0000, 0.6093, 0.8786, 0.9708, 0.7343, 0.6852, 0.0000],
        [0.0000, 0.8389, 0.7976, 0.8926, 0.8305, 0.9531, 0.0000],
        [0.0000, 0.2863, 0.1693, 0.4167, 0.3309, 0.3553, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
black pass: inputs = 
 tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.3337, 0.7520, 0.5767, 0.4513, 0.3419, 0.0000],
          [0.0000, 0.4457, 0.8292, 0.8293, 0.8013, 0.7791, 0.0000],
          [0.0000, 0.6093, 0.8786, 0.9708, 0.7343, 0.6852, 0.0000],
          [0.0000, 0.8389, 0.7976, 0.8926, 0.8305, 0.9531, 0.0000],
          [0.0000, 0.2863, 0.1693, 0.4167, 0.3309, 0.3553, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, black resid =
 tensor([[  0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,   0.0000],
        [  0.0000,  -0.0000, -37.1177,  -0.0000,   5.4710,  -0.0000,   0.0000],
        [ -0.0000,   8.1706,  -0.0000,  12.0859,  -0.0000, -37.8257,  -0.0000],
        [  0.0000,  -0.0000,   6.0323,  -0.0000,  29.7111,  -0.0000,   0.0000],
        [ -0.0000, -51.2916,  -0.0000,  -2.8799,  -0.0000, -61.3385,  -0.0000],
        [  0.0000,  -0.0000,  38.1930,  -0.0000,  18.5927,  -0.0000,   0.0000],
        [  0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,   0.0000]])
after rb_pass = 1, u =
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
        [0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
        [0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
        [0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
        [0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
end: u = 
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
        [0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
        [0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
        [0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
        [0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
47.62565 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
        [0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
        [0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
        [0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
        [0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
red_mask =
 tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0]])
black_mask =
 tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])
red_conv_mask =
 tensor([[ 0,  1,  0,  1,  0,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1,  0,  1,  0,  1,  0]])
black_conv_mask =
 tensor([[ 0,  0,  1,  0,  1,  0,  0],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 1, -4,  1, -4,  1, -4,  1],
        [ 0,  1, -4,  1, -4,  1,  0],
        [ 0,  0,  1,  0,  1,  0,  0]])
inputs = 
 tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
          [0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
          [0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
          [0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
          [0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask,  red resid =
 tensor([[  0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000],
        [ -0.0000,  -7.2368,  -0.0000,  -4.8902,  -0.0000,  -8.0887,  -0.0000],
        [  0.0000,  -0.0000,  -2.7072,  -0.0000,   2.3605,  -0.0000,   0.0000],
        [ -0.0000,  -9.2721,  -0.0000,  11.2374,  -0.0000, -17.3633,  -0.0000],
        [  0.0000,  -0.0000,  -2.4866,  -0.0000,  -3.9787,  -0.0000,   0.0000],
        [ -0.0000,  -3.2747,  -0.0000,  13.4765,  -0.0000, -10.6864,  -0.0000],
        [  0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000,  -0.0000,   0.0000]])
after rb_pass = 0, u =
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2834, 0.4942, 0.5427, 0.4893, 0.2857, 0.0000],
        [0.0000, 0.5024, 0.8104, 0.9132, 0.8177, 0.5165, 0.0000],
        [0.0000, 0.5449, 0.9204, 1.0488, 0.9407, 0.5646, 0.0000],
        [0.0000, 0.4827, 0.7804, 0.8726, 0.8029, 0.5271, 0.0000],
        [0.0000, 0.2636, 0.4345, 0.5103, 0.4600, 0.2810, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
black pass: inputs = 
 tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.2834, 0.4942, 0.5427, 0.4893, 0.2857, 0.0000],
          [0.0000, 0.5024, 0.8104, 0.9132, 0.8177, 0.5165, 0.0000],
          [0.0000, 0.5449, 0.9204, 1.0488, 0.9407, 0.5646, 0.0000],
          [0.0000, 0.4827, 0.7804, 0.8726, 0.8029, 0.5271, 0.0000],
          [0.0000, 0.2636, 0.4345, 0.5103, 0.4600, 0.2810, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, black resid =
 tensor([[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0000, -3.7085, -0.0000, -2.6546, -0.0000,  0.0000],
        [-0.0000, -4.8040, -0.0000,  1.5001, -0.0000, -5.7729, -0.0000],
        [ 0.0000, -0.0000, -0.8071, -0.0000, -1.9361, -0.0000,  0.0000],
        [-0.0000, -3.7584, -0.0000,  4.5622, -0.0000, -8.0071, -0.0000],
        [ 0.0000, -0.0000,  1.9289, -0.0000, -0.2971, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000]])
after rb_pass = 1, u =
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
        [0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
        [0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
        [0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
        [0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
end: u = 
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
        [0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
        [0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
        [0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
        [0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
6.880748 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
        [0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
        [0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
        [0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
        [0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

Note that I made an edit to the above: the ‘conv2d’ method is about 3x faster, on CUDA, for large grids!
Now to try a Conv2D layer…
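For reference, a minimal sketch of what that step might look like: a single-channel `torch.nn.Conv2d` without bias, with its weight initialised to the 5-point stencil. This is an illustration under those assumptions, not the code that was eventually used in this thread:

```python
import torch

conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
stencil = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])
with torch.no_grad():
    conv.weight.copy_(stencil[None, None])   # weight has shape (out, in, 3, 3)

# Sanity check: the discrete Laplacian of a linear ramp is zero in the interior.
ramp = torch.arange(7.0).repeat(7, 1)[None, None]
interior = conv(ramp)[0, 0, 1:-1, 1:-1]
print(interior.abs().max().item())

# The weights are ordinary parameters, so gradients reach them.
loss = conv(torch.rand(1, 1, 7, 7)).pow(2).mean()
loss.backward()
print(conv.weight.grad.shape)
```

The layer can then replace the `F.conv2d(..., filters, ...)` call, with `conv.weight` picked up by an optimizer like any other parameter.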

Here is the unfold method if you want to benchmark it. I think it’s really clean and easy to understand once you understand how unfold works.

import torch
import matplotlib.pyplot as plt
from torch.nn.functional import unfold, fold

# change the second 33 to an even number to see what happens
# when the number of cols is not odd
u = torch.ones([33, 33], dtype=torch.float32)
kernel = torch.tensor([[0,-1,0],[-1,4,-1],[0,-1,0]]).flatten()

# generate the patches
# we assume a square kernel
unfolded_patches = unfold(u.unsqueeze(0).unsqueeze(0), (3, 3), padding=(3-1)//2)

# some prints to check how things work
print(unfolded_patches.shape)
print(u.shape)
print(33 * 33)

# Note the stride of 2 to only go over one color, and the
# broadcasting enabled by the None's.
# THIS ONLY WORKS FOR ODD NUMBERS OF COLUMNS!
# Change ::2 to 1::2 to change color.
unfolded_patches[:,:,::2] = unfolded_patches[:,:,::2] * kernel[None,:,None]

# the sum below is the summation part of a convolution
collapse_convolution = unfolded_patches.sum(axis=1).squeeze()

# finally reshape
reshaped_output = collapse_convolution.reshape(u.shape)

# just some plots to check that everything works as expected
plt.figure()
plt.imshow(reshaped_output)
plt.colorbar()

That’s enough procrastination from me for a day xD

Thanks very much. I’ll give that a shot tomorrow!

I found some time this afternoon to package the unfold method neatly. It runs a 4-megapixel image on a 2080Ti card in 136 µs per convolution. The case with an even number of columns is slightly slower, but such is life. It can be generalized to non-square kernels without too much trouble, but the padding in the unfold would need to be adjusted.

The kernel parameters are also trainable if you set requires_grad=True on them.

import torch
from torch.nn.functional import unfold
from math import sqrt
import matplotlib.pyplot as plt

class RedBlackConv(torch.nn.Module):
    def __init__(self, red_kernel, black_kernel):
        super().__init__()
        self.red_kernel = torch.nn.Parameter(self.validate_kernel(red_kernel), requires_grad=False)
        self.black_kernel = torch.nn.Parameter(self.validate_kernel(black_kernel), requires_grad=False)
        
        if self.red_kernel.shape != self.black_kernel.shape:
            raise ValueError("Both kernels must be the same shape.")
        
        self.kernel_size = int(sqrt(self.red_kernel.shape[0]))
            

    def validate_kernel(self, kernel):
        if kernel.ndim != 2:
            raise ValueError("Kernel should be 2D")
        elif (kernel.shape[0] % 2 != 1) or (kernel.shape[1] % 2 != 1):
            raise ValueError("Kernel should have dimensions (odd, odd)")
        elif kernel.shape[0] != kernel.shape[1]:
            raise ValueError("Only square kernels are currently supported")
        
        return kernel.flatten()
    
    def forward(self, x):
        # Flattened patch indices only alternate red/black when the
        # number of columns is odd.
        is_even = x.shape[1] % 2 == 0
            
        if is_even:
            # We don't have an odd number of cols the input needs to get
            # padded with an extra column
            all_zeros_column = torch.zeros(x.shape[0], device=x.device)
            x = torch.column_stack([x, all_zeros_column])
        
        patches = unfold(
            x.unsqueeze(0).unsqueeze(0),
            (self.kernel_size, self.kernel_size),
            padding=(self.kernel_size-1)//2,
        )
        
        # Do red squares
        patches[:, :, ::2] = patches[:, :, ::2] * self.red_kernel[None, :, None]
        
        # Do black squares
        patches[:, :, 1::2] = patches[:, :, 1::2] * self.black_kernel[None, :, None]
        
        output = patches.sum(axis=1).squeeze().reshape(x.shape)
        
        # Make sure to remove the extra column if input had an even number
        # of columns
        if is_even:
            output = output[:,:-1]
        
        return output
        

# time to test it!!!
        
# The kernels below should produce zeros on the red squares
# and 4 on the black ones if the input is all ones
red_black_conv = RedBlackConv(
    torch.tensor(
        [
            [0,-1,0],
            [-1,4,-1],
            [0,-1,0],
        ],
        device="cuda",
    ),
    torch.tensor(
        [
            [0,1,0],
            [1,0,1],
            [0,1,0],
        ],
        device="cuda",
    ),
)

x = torch.ones([33, 32], dtype=torch.float32, device="cuda")
output = red_black_conv(x).detach().cpu().numpy()

plt.figure()
plt.imshow(output)
plt.colorbar()
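On making the kernels trainable: only floating-point tensors can require gradients, so the integer kernels used in the test above would first need to be created as floats. The sketch below is a functional variant of the same unfold idea, written under stated assumptions (2-D input with an odd number of columns, square odd-sized kernels). The function name `red_black_apply` is made up, and it selects the kernel per patch with `torch.where` rather than writing into the patches in place:

```python
import torch
import torch.nn.functional as F

def red_black_apply(x, red_kernel, black_kernel):
    """Apply red_kernel on even flattened positions, black_kernel on odd ones.

    Assumes a 2-D input with an odd number of columns and square,
    odd-sized kernels of equal shape.
    """
    k = red_kernel.shape[0]
    patches = F.unfold(x[None, None], kernel_size=k, padding=k // 2)  # (1, k*k, L)
    is_red = torch.arange(patches.shape[-1], device=x.device) % 2 == 0
    # Choose a kernel per patch without in-place writes.
    weights = torch.where(is_red[None, None, :],
                          red_kernel.flatten()[None, :, None],
                          black_kernel.flatten()[None, :, None])
    return (patches * weights).sum(dim=1).reshape(x.shape)

# Kernels must be floating point to carry gradients.
red = torch.tensor([[0., -1., 0.], [-1., 4., -1.], [0., -1., 0.]], requires_grad=True)
black = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]], requires_grad=True)

out = red_black_apply(torch.ones(5, 5), red, black)
out.sum().backward()
print(out[2, 2].item(), out[2, 1].item())  # centre (red) and its left neighbour (black)
```

With an all-ones input, the centre point (red) evaluates to 0 and its black neighbour to 4, and both kernels receive gradients after `backward()`.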

Wonderful! Thank you. I will try this out today…