Thank you for your help again. That’s a worthy modification. I can check that out.
BTW: When I posted this thread, I was just hoping for something easy, like a stagger=True
kwarg that can be passed into Conv2D! Haha. But I understand this is not a common use case for computer vision.
Currently I’ve implemented a rather lengthy routine with 4 masks, which I’ll share below,… and the result is that the new method (method=‘conv2d’) is no faster than my previous (method=‘fast’) routine! Whether on CPU, CUDA, or MPS, for multiple grid sizes…they take the same time. (I ran %%timeit) That is a bit disappointing. USER error: I was actually calling the same method each time!
The new ‘conv2d’ method is about 3 times faster when using CUDA on large grids.
I’ll share my code below (but with the understanding that it may turn you or other readers off, given its… verbosity!). The methods I’m comparing are my old one, "fast",
vs. the new method, "conv2d":
Currently these weights are not trainable as I’m just using F.conv2d instead of a Conv2D layer. …I’m proceeding by “baby steps”:
## Note there's a lot more code than the discuss.pytorch.org web system is showing you;
# it gets cut off but is scrollable...
#|export
def get_cb_indices(u, start=0):
    """Return flat indices of one colour of the interior 'red'/'black' checkerboard.

    Only the last two dimensions of ``u.shape`` are used; the indices address a
    single row-major (rows x cols) grid that has been flattened to 1-D.
    Boundary rows and columns are excluded.

    Args:
        u: object with a ``shape`` whose last two entries are (rows, cols).
        start: colour selector. 0 picks interior sites where (row + col) is
            even (the set containing site (1, 1)); 1 picks the odd sites.
            Other values are reduced modulo 2.

    Returns:
        (indices, jstride): ``indices`` is a 1-D int64 CPU tensor of flat
        indices in ascending order; ``jstride`` is the number of columns,
        i.e. how far the flat index moves when the row index changes by one.
    """
    nrows, ncols = u.shape[-2], u.shape[-1]
    # max(..., 1) keeps arange valid (and empty) for grids with no interior.
    rows = torch.arange(1, max(nrows - 1, 1), dtype=torch.long).unsqueeze(1)
    cols = torch.arange(1, max(ncols - 1, 1), dtype=torch.long).unsqueeze(0)
    flat = (rows * ncols + cols).flatten()
    # Colour by (row + col) parity rather than by every-other flat position:
    # the latter degenerates into column stripes when the interior width is even.
    colour = ((rows + cols) % 2).flatten()
    return flat[colour == start % 2], ncols
def set_alternating_edges(conv_mask, start=1):
    """Write 1 into every other site along the four edges of ``conv_mask``.

    Sites ``start, start+2, ...`` (stopping before the last index) are set on
    the first and last column and on the first and last row of the trailing
    two dimensions. The tensor is modified in place and also returned.

    Note: for Dirichlet boundary conditions with u=0 the solution is zero
    along the edges anyway, so using this or not may/should have no effect
    on the solution.
    """
    every_other = slice(start, -1, 2)
    for border in (0, -1):
        conv_mask[..., every_other, border] = 1  # left / right column
        conv_mask[..., border, every_other] = 1  # top / bottom row
    return conv_mask
def conv_pass(u, sigma, f, hm2, m4hm2, filters, mask, conv_mask, debug=False):
    """Do one coloured (red or black) Newton update of ``u`` via ``F.conv2d``.

    The residual ``hm2 * conv2d(conv_mask * u, filters) + sigma * u**2 - f`` is
    evaluated everywhere and then zeroed outside ``mask``; each selected site
    is moved by ``resid / (m4hm2 + 2 * sigma * u)``.

    ``u`` is updated in place. Returns ``(u, sum of squared masked residuals)``.
    """
    # conv2d wants (batch, channel, rows, cols); promote a bare grid if needed.
    batched = u if u.dim() >= 4 else u[None, None]
    if debug:
        print("inputs = \n", batched)
    stencil_sum = F.conv2d(conv_mask * batched, filters, padding=1).squeeze()
    resid = (hm2 * stencil_sum + sigma * u**2 - f) * mask
    dres_du = m4hm2 + 2.0 * sigma * u
    correction = resid / dres_du
    u -= correction * mask  # newton step
    return u, (resid**2).sum()
def _rbgs_pointwise(u, f, sigma, hm2, m4hm2):
    "one red-black sweep, site by site (slow but sure); updates u in place, returns sum of squared residuals"
    sq_resid = 0
    for rb_pass in range(2):  # red-black gauss seidel
        for j in range(1, u.shape[-1] - 1):
            # NOTE(review): with this offset rb_pass 0 visits sites with odd (i + j),
            # whereas the row-wise sweep starts with even (i + j) -- confirm
            # whether the two orderings are meant to match.
            ibump = (rb_pass + j) % 2  # alternates 0 and 1
            for i in range(1 + ibump, u.shape[-2] - 1, 2):
                stencil = u[i + 1, j] + u[i - 1, j] + u[i, j + 1] + u[i, j - 1] - 4 * u[i, j]
                resid_gs = hm2 * stencil + sigma * u[i, j] ** 2 - f[i, j]
                dres_duij = m4hm2 + 2.0 * sigma * u[i, j]
                u[i, j] = u[i, j] - resid_gs / dres_duij  # newton step
                sq_resid += resid_gs ** 2
    return sq_resid


def _rbgs_rowwise(u, f, sigma, hm2, m4hm2):
    "one red-black sweep vectorized along each row; updates u in place, returns sum of squared residuals"
    sq_resid = 0
    for rb_pass in range(2):  # red-black gauss seidel
        for i in range(1, u.shape[-2] - 1):  # every row; columns skip every other via slicing
            jstart = 1 + (1 + rb_pass + i) % 2  # alternates 1 and 2
            uij = u[..., i, jstart:-1:2]  # a view: in-place ops below write into u
            fij = f[..., i, jstart:-1:2]
            neighbours = (u[..., i + 1, jstart:-1:2] + u[..., i - 1, jstart:-1:2]
                          + u[..., i, jstart + 1::2] + u[..., i, jstart - 1:-2:2])
            resid_gs = hm2 * (neighbours - 4 * uij) + (sigma * uij ** 2) - fij
            dres_duij = m4hm2 + 2.0 * sigma * uij
            uij += -resid_gs / dres_duij  # newton step
            sq_resid += (resid_gs ** 2).sum()
    return sq_resid


def _rbgs_flat(u, f, sigma, hm2, m4hm2, debug):
    "one red-black sweep vectorized over the whole grid via flat indices; updates u in place"
    ufl = u.view(u.shape[-2] * u.shape[-1])  # flat view sharing storage with u
    ffl = f.view(f.shape[-2] * f.shape[-1])
    sq_resid = 0
    for rb_pass in range(2):  # red-black gauss seidel
        idx, js = get_cb_indices(u, start=rb_pass)
        resid_gs = (hm2 * (ufl[idx + 1] + ufl[idx - 1] + ufl[idx + js] + ufl[idx - js] - 4 * ufl[idx])
                    + (sigma * ufl[idx] ** 2) - ffl[idx])
        if debug: print(f"rb_pass = {rb_pass}, resid_gs =\n", resid_gs)
        ufl[idx] -= resid_gs / (m4hm2 + 2.0 * sigma * ufl[idx])  # newton step
        if debug: print(f"after rb_pass = {rb_pass}, u =\n", u)
        sq_resid += (resid_gs ** 2).sum()
    return sq_resid


def smooth_error(uin, h, f, sigma, method='conv2d', debug=False,
                 red_mask=None, red_conv_mask=None, black_mask=None, black_conv_mask=None, filters=None):
    """Apply one red-black Gauss-Seidel sweep with a Newton update at each site.

    The per-site residual is ``(sum of 4 neighbours - 4*u)/h**2 + sigma*u**2 - f``
    and each visited site is moved by ``-resid / (-4/h**2 + 2*sigma*u)``.
    ``uin`` is not modified; the sweep works on a clone.

    Args:
        uin: current grid values (a torch tensor; indexed as a 2-D grid).
        h: grid spacing.
        f: right-hand side, indexed like ``uin``.
        sigma: coefficient of the ``u**2`` term.
        method: 'slow' (python loops), 'medium' (vectorized per row),
            'fast' (flat checkerboard indices) or 'conv2d' (masked ``F.conv2d``).
        debug: print intermediate tensors.
        red_mask, black_mask: for 'conv2d', tensors selecting the sites each
            pass updates.
        red_conv_mask, black_conv_mask: for 'conv2d', multiplied into ``u``
            before the convolution.
        filters: for 'conv2d', the conv weight of shape (1, 1, 3, 3). When
            omitted, a plus-shaped kernel of ones is used; this yields the
            5-point stencil provided the conv masks hold -4 at the sites being
            updated and 1 at their neighbours.

    Returns:
        (u, resid_norm): the updated grid and, as a numpy scalar, the square
        root of the summed squared residuals (each taken just before its site
        was updated) divided by the number of interior sites.

    Raises:
        ValueError: unknown ``method``, or a mask is missing for 'conv2d'.
    """
    if method not in ('slow', 'medium', 'fast', 'conv2d'):
        # Fail early and clearly; previously this only printed a message and
        # then crashed later on an unrelated TypeError.
        raise ValueError(f"invalid method = {method!r}; expected 'slow', 'medium', 'fast' or 'conv2d'")

    u = uin.clone()  # every method below updates u in place, so protect the caller's tensor
    hm2 = 1.0 / (h * h)
    m4hm2 = -4.0 * hm2

    if method == 'slow':
        resid_norm = _rbgs_pointwise(u, f, sigma, hm2, m4hm2)
    elif method == 'medium':
        resid_norm = _rbgs_rowwise(u, f, sigma, hm2, m4hm2)
    elif method == 'fast':
        resid_norm = _rbgs_flat(u, f, sigma, hm2, m4hm2, debug)
    else:  # 'conv2d'
        masks = (("red_mask", red_mask), ("red_conv_mask", red_conv_mask),
                 ("black_mask", black_mask), ("black_conv_mask", black_conv_mask))
        missing = [name for name, value in masks if value is None]
        if missing:
            raise ValueError("method='conv2d' requires masks; missing: " + ", ".join(missing))
        if filters is None:
            filters = torch.tensor([[0.0, 1.0, 0.0],
                                    [1.0, 1.0, 1.0],
                                    [0.0, 1.0, 0.0]], dtype=u.dtype, device=u.device).view(1, 1, 3, 3)
        if debug:
            print("red_mask =\n", red_mask)
            print("black_mask =\n", black_mask)
            print("red_conv_mask =\n", red_conv_mask)
            print("black_conv_mask =\n", black_conv_mask)
        u, resid_norm = conv_pass(u, sigma, f, hm2, m4hm2, filters, red_mask, red_conv_mask, debug=debug)
        u, black_resid_norm = conv_pass(u, sigma, f, hm2, m4hm2, filters, black_mask, black_conv_mask, debug=debug)
        resid_norm += black_resid_norm
        if debug: print("after rb_pass = 1, u =\n", u)

    n_interior = (u.shape[-2] - 2) * (u.shape[-1] - 2)
    resid_norm = torch.sqrt(resid_norm / n_interior).cpu().numpy()
    if debug: print("end: u = \n", u)
    return u, resid_norm
Here’s a bit of testing for a 7x7 run (note that a few variables are undefined in this snippet, such as sigma (=0) and “f”, but these could be anything):
# the pytorch conv2d way
# NOTE: utest, hx, rhs and sigma are assumed to be defined earlier in the session.
# The masks are built from utest itself so they always match the grid being smoothed.
red_idx, _ = get_cb_indices(utest, start=0)
black_idx, _ = get_cb_indices(utest, start=1)

# 0/1 masks selecting which sites each pass updates
red_mask = torch.zeros(utest.shape, device=utest.device, dtype=int)
red_mask.view(-1)[red_idx] = 1
black_mask = torch.zeros(utest.shape, device=utest.device, dtype=int)
black_mask.view(-1)[black_idx] = 1

# conv masks: -4 on the sites being updated, 1 on the other colour and on alternating edge sites
red_conv_mask = red_mask.clone()
red_conv_mask.view(-1)[red_idx] = -4
red_conv_mask.view(-1)[black_idx] = 1
red_conv_mask = set_alternating_edges(red_conv_mask, start=1)
black_conv_mask = black_mask.clone()
black_conv_mask.view(-1)[black_idx] = -4
black_conv_mask.view(-1)[red_idx] = 1
black_conv_mask = set_alternating_edges(black_conv_mask, start=2)

# plus-shaped kernel of ones: together with the conv masks it gives (4 neighbours - 4*u)
filters = torch.tensor([[0.0, 1.0, 0.0],
                        [1.0, 1.0, 1.0],
                        [0.0, 1.0, 0.0]], dtype=utest.dtype, device=utest.device).view(1, 1, 3, 3)

# two successive smoothing sweeps, printing the residual norm and grid after each
unew = utest.clone()
for _ in range(2):
    unew, resnorm = smooth_error(unew, hx, rhs, sigma, method='conv2d',
                                 red_mask=red_mask, black_mask=black_mask,
                                 red_conv_mask=red_conv_mask, black_conv_mask=black_conv_mask,
                                 filters=filters, debug=True)
    print(resnorm, unew)
Output is:
red_mask =
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0]])
black_mask =
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]])
red_conv_mask =
tensor([[ 0, 1, 0, 1, 0, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, 0, 1, 0, 1, 0]])
black_conv_mask =
tensor([[ 0, 0, 1, 0, 1, 0, 0],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 0, 0, 1, 0, 1, 0, 0]])
inputs =
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.9288, 0.7520, 0.1657, 0.4513, 0.0875, 0.0000],
[0.0000, 0.4457, 0.2441, 0.8293, 0.7338, 0.7791, 0.0000],
[0.0000, 0.9396, 0.8786, 0.0616, 0.7343, 0.8295, 0.0000],
[0.0000, 0.8389, 0.8395, 0.8926, 0.4192, 0.9531, 0.0000],
[0.0000, 0.1511, 0.1693, 0.5495, 0.3309, 0.5770, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, red resid =
tensor([[ -0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
[ -0.0000, -85.6999, -0.0000, 59.1877, -0.0000, 36.6385, -0.0000],
[ 0.0000, -0.0000, 84.2550, -0.0000, 9.7277, -0.0000, 0.0000],
[ -0.0000, -47.5618, -0.0000, 130.9272, -0.0000, -20.7869, -0.0000],
[ 0.0000, -0.0000, -6.0300, -0.0000, 59.2300, -0.0000, 0.0000],
[ -0.0000, 19.4751, -0.0000, -19.1107, -0.0000, -31.9364, -0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000]])
after rb_pass = 0, u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.7520, 0.5767, 0.4513, 0.3419, 0.0000],
[0.0000, 0.4457, 0.8292, 0.8293, 0.8013, 0.7791, 0.0000],
[0.0000, 0.6093, 0.8786, 0.9708, 0.7343, 0.6852, 0.0000],
[0.0000, 0.8389, 0.7976, 0.8926, 0.8305, 0.9531, 0.0000],
[0.0000, 0.2863, 0.1693, 0.4167, 0.3309, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
black pass: inputs =
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.7520, 0.5767, 0.4513, 0.3419, 0.0000],
[0.0000, 0.4457, 0.8292, 0.8293, 0.8013, 0.7791, 0.0000],
[0.0000, 0.6093, 0.8786, 0.9708, 0.7343, 0.6852, 0.0000],
[0.0000, 0.8389, 0.7976, 0.8926, 0.8305, 0.9531, 0.0000],
[0.0000, 0.2863, 0.1693, 0.4167, 0.3309, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, black resid =
tensor([[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -37.1177, -0.0000, 5.4710, -0.0000, 0.0000],
[ -0.0000, 8.1706, -0.0000, 12.0859, -0.0000, -37.8257, -0.0000],
[ 0.0000, -0.0000, 6.0323, -0.0000, 29.7111, -0.0000, 0.0000],
[ -0.0000, -51.2916, -0.0000, -2.8799, -0.0000, -61.3385, -0.0000],
[ 0.0000, -0.0000, 38.1930, -0.0000, 18.5927, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000]])
after rb_pass = 1, u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
[0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
[0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
[0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
[0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
end: u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
[0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
[0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
[0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
[0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
47.62565 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
[0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
[0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
[0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
[0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
red_mask =
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0]])
black_mask =
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]])
red_conv_mask =
tensor([[ 0, 1, 0, 1, 0, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, 0, 1, 0, 1, 0]])
black_conv_mask =
tensor([[ 0, 0, 1, 0, 1, 0, 0],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 1, -4, 1, -4, 1, -4, 1],
[ 0, 1, -4, 1, -4, 1, 0],
[ 0, 0, 1, 0, 1, 0, 0]])
inputs =
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.3337, 0.4942, 0.5767, 0.4893, 0.3419, 0.0000],
[0.0000, 0.5024, 0.8292, 0.9132, 0.8013, 0.5165, 0.0000],
[0.0000, 0.6093, 0.9204, 0.9708, 0.9407, 0.6852, 0.0000],
[0.0000, 0.4827, 0.7976, 0.8726, 0.8305, 0.5271, 0.0000],
[0.0000, 0.2863, 0.4345, 0.4167, 0.4600, 0.3553, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, red resid =
tensor([[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
[ -0.0000, -7.2368, -0.0000, -4.8902, -0.0000, -8.0887, -0.0000],
[ 0.0000, -0.0000, -2.7072, -0.0000, 2.3605, -0.0000, 0.0000],
[ -0.0000, -9.2721, -0.0000, 11.2374, -0.0000, -17.3633, -0.0000],
[ 0.0000, -0.0000, -2.4866, -0.0000, -3.9787, -0.0000, 0.0000],
[ -0.0000, -3.2747, -0.0000, 13.4765, -0.0000, -10.6864, -0.0000],
[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000]])
after rb_pass = 0, u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.2834, 0.4942, 0.5427, 0.4893, 0.2857, 0.0000],
[0.0000, 0.5024, 0.8104, 0.9132, 0.8177, 0.5165, 0.0000],
[0.0000, 0.5449, 0.9204, 1.0488, 0.9407, 0.5646, 0.0000],
[0.0000, 0.4827, 0.7804, 0.8726, 0.8029, 0.5271, 0.0000],
[0.0000, 0.2636, 0.4345, 0.5103, 0.4600, 0.2810, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
black pass: inputs =
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.2834, 0.4942, 0.5427, 0.4893, 0.2857, 0.0000],
[0.0000, 0.5024, 0.8104, 0.9132, 0.8177, 0.5165, 0.0000],
[0.0000, 0.5449, 0.9204, 1.0488, 0.9407, 0.5646, 0.0000],
[0.0000, 0.4827, 0.7804, 0.8726, 0.8029, 0.5271, 0.0000],
[0.0000, 0.2636, 0.4345, 0.5103, 0.4600, 0.2810, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
after multiplying by mask, black resid =
tensor([[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -3.7085, -0.0000, -2.6546, -0.0000, 0.0000],
[-0.0000, -4.8040, -0.0000, 1.5001, -0.0000, -5.7729, -0.0000],
[ 0.0000, -0.0000, -0.8071, -0.0000, -1.9361, -0.0000, 0.0000],
[-0.0000, -3.7584, -0.0000, 4.5622, -0.0000, -8.0071, -0.0000],
[ 0.0000, -0.0000, 1.9289, -0.0000, -0.2971, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000]])
after rb_pass = 1, u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
[0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
[0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
[0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
[0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
end: u =
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
[0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
[0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
[0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
[0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
6.880748 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.2834, 0.4685, 0.5427, 0.4709, 0.2857, 0.0000],
[0.0000, 0.4690, 0.8104, 0.9236, 0.8177, 0.4764, 0.0000],
[0.0000, 0.5449, 0.9148, 1.0488, 0.9272, 0.5646, 0.0000],
[0.0000, 0.4566, 0.7804, 0.9043, 0.8029, 0.4715, 0.0000],
[0.0000, 0.2636, 0.4479, 0.5103, 0.4579, 0.2810, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])