Fx example that doesn't involve writing your own optimizer

Spent hours with ChatGPT and Google Bard without luck trying to get a working example. They kept making the same mistakes, circling back to earlier suggestions that didn't work.

If I read the PyTorch blog post “Optimizing Production PyTorch Models’ Performance with Graph Transformations”, I see things like “replacing ReLU with GELU”, “embedding tables”, and “horizontal fusion”.

I don’t want to figure out how to do my own optimization. I thought fx would do it for me.

Assuming I had a net with some Linear layers and activations, how would I optimize it to maximize training performance? I was trying to add this optimization to an MNIST training example.

Assuming the following first two steps are correct:

import torch.fx as fx
model = Model()
traced_model = fx.symbolic_trace(model)  # returns an fx.GraphModule

I can just wrap it with fx.GraphModule() and use it. However, I thought I could call built-in optimizations on the graph, like ‘dead_code_elimination’ or ‘DeadCodeElimination’, or other things suggested by ChatGPT like:

fx.passes.constant_propagation,
fx.passes.peephole_optimize,
fx.passes.remove_dropout_and_batchnorm

I can’t find any example that uses fx to optimize a graph beyond the simplest stuff that tracing and wrapping already does.
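For concreteness, here is the full flow I mean, as a minimal sketch (the tiny Model below just stands in for my MNIST net; Graph.eliminate_dead_code() and GraphModule.recompile() are the only graph-level methods I'm relying on here):

import torch
import torch.nn as nn
import torch.fx as fx

class Model(nn.Module):  # stand-in for my MNIST net
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        unused = x.sum()  # dead code that tracing will record as a node
        return self.fc2(torch.relu(self.fc1(x)))

model = Model()
traced_model = fx.symbolic_trace(model)

# "Wrap and use": a GraphModule is itself an nn.Module
gm = fx.GraphModule(model, traced_model.graph)

# The one built-in cleanup that does exist: dead code elimination,
# a method on fx.Graph rather than anything under fx.passes
gm.graph.eliminate_dead_code()
gm.recompile()  # regenerate forward() from the mutated graph

out = gm(torch.randn(1, 28 * 28))  # use it like the original model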


Now I’m even more confused. I just read the torch.fx overview. It consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation.
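To see all three pieces on a toy module (a sketch; the module is mine, not from the overview):

import torch
import torch.nn as nn
import torch.fx as fx

class Toy(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

traced = fx.symbolic_trace(Toy())  # 1. the symbolic tracer
print(traced.graph)                # 2. the IR: placeholder / call_function / output nodes
print(traced.code)                 # 3. the generated Python source for forward()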

Nowhere does it say fx provides any optimizations whatsoever, yet I could swear I’ve seen that hyped about it. The code generation is a transformation toolkit; with it you can do your own optimizations, if you happen to be some kind of GPU NN wizard.

Where are the common optimizations I would expect to find in a general-purpose graph optimizer targeted at neural networks?

The fx symbolic tracer isn’t used all that much in the PyTorch 2.0 world; what has survived from fx is the intermediate representation.
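Concretely, in 2.0 you usually don't call the tracer yourself; torch.compile captures fx graphs for you and hands them to inductor (a minimal sketch):

import torch

model = torch.nn.Linear(10, 10)
compiled = torch.compile(model)   # dynamo captures fx graphs, inductor optimizes them
y = compiled(torch.randn(2, 10))  # compilation happens on the first call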

FWIW, ChatGPT hasn’t been all that useful for me in understanding PyTorch internals; it frequently hallucinates.

But let’s say you want to do dead code elimination: you can see the config variable in _inductor (pytorch/config.py at main · pytorch/pytorch · GitHub) and see how it’s used in the passes (pytorch/torch/_inductor/fx_passes at main · pytorch/pytorch · GitHub).

There’s even a whole folder of passes here: pytorch/torch/_inductor/fx_passes at main · pytorch/pytorch · GitHub
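If you want to see the shape of such a pass, here's a toy one written at the fx level; the ReLU→GELU rewrite mirrors the example from the blog post mentioned above, not an actual inductor pass:

import torch
import torch.fx as fx

def relu_to_gelu(gm: fx.GraphModule) -> fx.GraphModule:
    # Walk the graph, rewrite matching nodes, then regenerate the code;
    # the passes under torch/_inductor/fx_passes follow the same pattern
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.relu:
            node.target = torch.nn.functional.gelu
    gm.graph.lint()  # sanity-check the mutated graph
    gm.recompile()   # regenerate forward() from it
    return gm

You'd apply it as relu_to_gelu(fx.symbolic_trace(model)).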

Thanks.
FYI, yesterday I learned how to train a model. So I grabbed some images to train on and started looking for a Java program I wrote long ago to scale images.
Then I just asked ChatGPT to give me a Python program to go through a directory of images to scale them to 512x512 and quickly converted the directory.
Then I still had an “aspect ratio” problem, so on a hunch I asked for another program to go through the directory and print the sizes. The larger dimension was now 512, but there was no padding for the smaller dimension.
So I asked it to pad the images and the program kept throwing an error I didn’t understand when pasting the image into the 512x512 target.
I worked with it to debug the problem and discovered that parameter two of paste was supposed to be just (left, top), not (left, top, right, bottom).
Job done!
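For anyone who wants the end result, a minimal sketch of that script (the directory names are hypothetical; it assumes the source images are larger than 512, since Pillow's thumbnail() only shrinks):

import os
from PIL import Image

SRC, DST, SIZE = "images", "scaled", 512  # hypothetical paths

os.makedirs(DST, exist_ok=True)
for name in os.listdir(SRC):
    img = Image.open(os.path.join(SRC, name)).convert("RGB")
    img.thumbnail((SIZE, SIZE))              # larger dimension becomes 512, aspect ratio kept
    canvas = Image.new("RGB", (SIZE, SIZE))  # black 512x512 target
    left, top = (SIZE - img.width) // 2, (SIZE - img.height) // 2
    canvas.paste(img, (left, top))           # the box is just (left, top), not a 4-tuple
    canvas.save(os.path.join(DST, name))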