I want to store a small set of useful-to-have constants in my model:
ZERO = torch.tensor( float(0.0) )
ONE = torch.tensor( float(1.0) )
INF = torch.tensor( float('inf') )
NAN = torch.tensor( float('nan') )
TRUE = torch.tensor( bool(True) )
These are often needed when performing operations like torch.where(mask, tensor, ZERO).
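For instance, a minimal sketch of that pattern (the names mask and x are illustrative):

```python
import torch

# A scalar constant; in practice it should live on the same device
# (and usually the same dtype) as the tensors it is combined with.
ZERO = torch.tensor(0.0)

mask = torch.tensor([True, False, True])
x = torch.tensor([1.0, 2.0, 3.0])

# Replace masked-out entries with the ZERO constant (scalar broadcasts).
result = torch.where(mask, x, ZERO)
print(result)  # tensor([1., 0., 3.])
```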
What is the best way to store these kinds of constants in my model? I have the following desiderata:
1. They are cast to the correct device/dtype by model.to
2. They are collected in a namespace, e.g. model.constants
3. I do not need to copy-paste the whole block of code into every model
4. The list is extendable on a per-model basis
5. It is fully compatible with JIT
Currently, I am doing

class model(nn.Module):
    ZERO: torch.Tensor
    def __init__(self):
        super().__init__()
        self.register_buffer('ZERO', torch.tensor(0.0))

which satisfies (1, 4, 5) but violates (2, 3). Any better ideas?
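For reference, a self-contained, runnable version of this buffer-based approach (class and argument names are illustrative), showing that the constant is cast along with the module by .to:

```python
import torch
from torch import nn

class Model(nn.Module):
    # Type annotation so TorchScript knows the buffer's type.
    ZERO: torch.Tensor

    def __init__(self):
        super().__init__()
        # Buffers move/cast with the module under .to(), .cuda(), etc.,
        # but are not trainable parameters.
        self.register_buffer('ZERO', torch.tensor(0.0))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return torch.where(mask, x, self.ZERO)

m = Model().to(torch.float64)
print(m.ZERO.dtype)  # torch.float64
```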