I am working on disentangled representations, employing a final ‘decoder’ that is a differentiable renderer: a simple line-drawing algorithm.
The final stage of the network is a line renderer that takes a 4-dimensional latent representation, and the network is trained on synthetically generated images of edges. The part of the network that produces the latent works well: coupled with a typical autoencoder ‘decoder’ stage, it produces strong results, even with only 4 values in the latent representation.
The line renderer/decoder is Bresenham’s algorithm. It has been tested and works fine.
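One property of Bresenham's algorithm matters here regardless of how the image is assembled: it snaps the endpoints to integer pixels, and rounding is piecewise constant, so its derivative is zero almost everywhere. A minimal check in PyTorch, using toy values rather than anything from the model below:

```python
import torch

# Toy endpoint coordinates that we would like gradients for
x = torch.tensor([1.3, 2.7], requires_grad=True)

# torch.round is tracked by autograd, but its gradient is defined as zero
y = torch.round(x)
y.sum().backward()

print(x.grad)  # tensor([0., 0.]): small changes to x never change round(x)
```

Because the set of pixels Bresenham lights up only changes when a rounded endpoint jumps to a new integer, even a fully tracked implementation would pass zero gradient back to the endpoints.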
My problem is that the network no longer learns once the line renderer replaces the typical decoder.
I am unsure about a few things:
Are errors being propagated through the renderer? How can I tell?
Are gradients NOT being collected on operations in the renderer? How can I be sure?
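A quick way to test whether gradients can reach the renderer's input is to look at the output's `grad_fn` and then call `backward()` and inspect the input's `.grad`. The sketch below compares a hypothetical NumPy round-trip (mimicking how `line_draw` builds its image) with the same computation kept in torch; the function names are illustrative, not from any library.

```python
import torch


def numpy_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Mimics line_draw: leave torch, compute in NumPy, convert back,
    # then mark the result as requiring grad (this creates a new leaf)
    out = torch.from_numpy(t.detach().numpy() * 2.0)
    out.requires_grad = True
    return out


def torch_native(t: torch.Tensor) -> torch.Tensor:
    # The same computation kept inside torch, so autograd records it
    return t * 2.0


def gradient_reaches_input(fn) -> bool:
    inp = torch.randn(3, requires_grad=True)
    out = fn(inp)
    # An output derived from inp through tracked ops has a grad_fn;
    # a freshly created leaf tensor has grad_fn None
    if out.grad_fn is None:
        return False
    out.sum().backward()
    return inp.grad is not None


print(gradient_reaches_input(numpy_roundtrip))  # False
print(gradient_reaches_input(torch_native))     # True
```

In a full model, the same idea applies to the encoder: after `loss.backward()`, iterating over `encoder.named_parameters()` and finding `.grad` still `None` typically means no gradient path exists from the loss to those parameters.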
Here is the exact code for the final stage. It is split out into its own function in preparation for moving it to warp. I build the image in a NumPy array to avoid the leaf-tensor and in-place operation issues:
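```python
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor


def line_draw(inp: Tensor, args):
    # inp: (batch, 4) latent holding the endpoints (x0, y0, x1, y1) in pixel units
    img = np.zeros((args.batch_size, 3, args.patch_size, args.patch_size))
    for b, i in enumerate(inp):
        # .item() turns each coordinate into a plain Python float, so autograd
        # keeps no record of how img depends on inp
        x0 = round(i[0].item())
        y0 = round(i[1].item())
        x1 = round(i[2].item())
        y1 = round(i[3].item())

        # Bresenham's line algorithm (integer version, all octants)
        dx = abs(x1 - x0)
        sx = 1 if x0 < x1 else -1
        dy = -abs(y1 - y0)
        sy = 1 if y0 < y1 else -1
        error = dx + dy
        while True:
            img[b, :, y0, x0] = 1
            if x0 == x1 and y0 == y1:
                break
            e2 = 2 * error
            if e2 >= dy:
                if x0 == x1:
                    break
                error = error + dy
                x0 = x0 + sx
            if e2 <= dx:
                if y0 == y1:
                    break
                error = error + dx
                y0 = y0 + sy

    timg = torch.from_numpy(img).to(torch.float32)
    # Setting requires_grad here makes timg a new leaf, not connected to inp
    timg.requires_grad = True
    return timg


class VAEDecoder(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.args = args

    def forward(self, inp):
        # nn.Module.__call__ already dispatches to forward (and runs hooks)
        return line_draw(inp, self.args)
```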
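If the goal is an end-to-end trainable decoder, one commonly used alternative to Bresenham (a swapped-in technique, not the renderer above) is to draw the line as a smooth function of each pixel's distance to the segment, using only torch operations so autograd tracks every step. A minimal sketch, assuming the latent is laid out as (x0, y0, x1, y1) in pixel coordinates; `soft_line_render` and `sigma` are hypothetical names.

```python
import torch


def soft_line_render(endpoints: torch.Tensor, size: int, sigma: float = 1.0) -> torch.Tensor:
    """Render one segment per batch item as a smooth intensity image.

    endpoints: (B, 4) tensor of (x0, y0, x1, y1) in pixel units (assumed layout).
    Returns (B, 3, size, size); intensity falls off with distance to the segment.
    """
    B = endpoints.shape[0]
    p0 = endpoints[:, 0:2].reshape(B, 1, 1, 2)  # (x0, y0)
    p1 = endpoints[:, 2:4].reshape(B, 1, 1, 2)  # (x1, y1)

    # Pixel grid in (x, y) order: grid[row, col] = (col, row)
    coords = torch.arange(size, dtype=endpoints.dtype, device=endpoints.device)
    ys, xs = torch.meshgrid(coords, coords, indexing="ij")
    grid = torch.stack((xs, ys), dim=-1)  # (size, size, 2)

    # Project each pixel onto the segment, clamping to the endpoints
    d = p1 - p0                                           # (B, 1, 1, 2)
    denom = (d * d).sum(-1, keepdim=True).clamp_min(1e-8)
    t = (((grid - p0) * d).sum(-1, keepdim=True) / denom).clamp(0.0, 1.0)
    closest = p0 + t * d                                  # (B, size, size, 2)

    # Squared distance to the segment mapped to an intensity in (0, 1]
    dist2 = ((grid - closest) ** 2).sum(-1)               # (B, size, size)
    intensity = torch.exp(-dist2 / (2 * sigma ** 2))

    # Repeat across 3 channels to match the shape line_draw returns
    return intensity.unsqueeze(1).expand(B, 3, size, size)


# Usage: gradients now flow back to the endpoints
ep = torch.tensor([[2.0, 2.0, 12.0, 9.0]], requires_grad=True)
img = soft_line_render(ep, size=16)
img.sum().backward()
print(ep.grad)
```

`sigma` controls the trade-off: a small value gives a sharper, more Bresenham-like line, but pixels far from the line then contribute almost no gradient, so the endpoints receive little signal when they start far from the target edge.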
Any insight or help is appreciated.