Let’s say I want to build a custom metric. In that case, I could simply write a function and use it in my training loop.
Could I instead define the metric as a class inherited from nn.Module and implement the metric in its forward method?
I know I can, but is that beneficial in any way, or is a plain function definition fine?
If your metric is stateless (i.e. you don’t need to store arguments or other parameters), then you could just define a function. Otherwise, a custom class derived from nn.Module could store some arguments, e.g. the reduction, as an internal attribute.
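As a rough illustration, here is a sketch of the same metric written both ways. The names `accuracy` and `Accuracy` are hypothetical, and the `reduction` argument is only modeled loosely on the `reduction` argument of PyTorch's built-in loss modules (e.g. `nn.CrossEntropyLoss`). The function takes everything it needs as inputs, while the module keeps its configuration as an attribute set once in `__init__`.

```python
import torch
import torch.nn as nn


def accuracy(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Stateless metric (hypothetical): mean fraction of correct predictions."""
    return (preds.argmax(dim=1) == targets).float().mean()


class Accuracy(nn.Module):
    """The same metric as a module (hypothetical); the reduction is stored as state."""

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction}")
        # Configuration kept as an internal attribute instead of passed on every call
        self.reduction = reduction

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # 1.0 where the predicted class matches the target, else 0.0
        correct = (preds.argmax(dim=1) == targets).float()
        if self.reduction == "mean":
            return correct.mean()
        if self.reduction == "sum":
            return correct.sum()
        return correct  # "none": per-sample values
```

In the training loop you would call `accuracy(logits, targets)` directly, or create `metric = Accuracy(reduction="sum")` once and then call `metric(logits, targets)`. For a metric with no configuration, the two are equivalent and the plain function is the simpler choice.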
Also, have a look at this post I wrote some time ago, where I discuss the advantages of both approaches.