To compile a Beta-distribution sampler with TorchScript tracing (`torch.jit.trace`), I am using this code:
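```python
from abc import ABC

import torch
from torch import Tensor


class Algo2(torch.nn.Module, ABC):
    def __init__(self):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.aa = [1] * 10
        self.bb = [2] * 10

    def get_beta_samples(self, algoid=0) -> Tensor:
        one = torch.ones(1)
        samples = []
        for i, j in zip(one + self.aa[algoid], one + self.bb[algoid]):
            # A Beta(i, j) sample is the first component of a Dirichlet([i, j]) sample
            samples.append(torch._sample_dirichlet(torch.tensor([i, j]))[:1])
        return torch.cat(samples, 0)


model1 = Algo2()  # model1 is an instance of Algo2
X1 = torch.tensor(6)
trace_model = torch.jit.trace_module(model1, {'get_beta_samples': X1})
```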
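`get_beta_samples` relies on the fact that, for a two-category Dirichlet(a, b), the first coordinate is Beta(a, b)-distributed. A minimal sketch of that sampler on its own, as a scripted function (`sample_beta` is a hypothetical name), assuming your PyTorch version exposes the private op `torch._sample_dirichlet` to TorchScript (being private, it may change between releases):

```python
import torch
from torch import Tensor


@torch.jit.script
def sample_beta(a: Tensor, b: Tensor) -> Tensor:
    """Draw one Beta(a, b) sample per element; a and b must have the same shape."""
    # Pair the concentrations along a new last dimension: shape (..., 2).
    conc = torch.stack([a, b], dim=-1)
    # The first coordinate of a Dirichlet(a, b) draw is Beta(a, b)-distributed.
    return torch._sample_dirichlet(conc).select(-1, 0)


# Concentrations (2, 3) match 1 + aa[algoid] and 1 + bb[algoid] from the code above.
print(sample_beta(torch.tensor([2.0]), torch.tensor([3.0])))
```

Over many draws of Beta(2, 3), the sample mean should be close to 2 / (2 + 3) = 0.4.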
The output is always the same. However, tracing emits this warning:
TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
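The warning points at `torch.tensor([i, j])`: while tracing, `torch.tensor` copies the current values of its arguments, so its result is stored in the graph as a constant rather than as an operation on the inputs. A minimal sketch comparing it with `torch.stack`, which is recorded as a traced op (the two function names are hypothetical):

```python
import torch


def build_with_tensor(x: torch.Tensor) -> torch.Tensor:
    # torch.tensor copies the *values* of x, so the tracer records a constant.
    return torch.tensor([x, x + 1]) * 1.0


def build_with_stack(x: torch.Tensor) -> torch.Tensor:
    # torch.stack is a regular op, so the result stays connected to x.
    return torch.stack([x, x + 1])


example = torch.tensor(2.0)
traced_tensor = torch.jit.trace(build_with_tensor, (example,))  # emits the TracerWarning
traced_stack = torch.jit.trace(build_with_stack, (example,))

new_input = torch.tensor(5.0)
print(traced_tensor(new_input))  # values frozen from the example input
print(traced_stack(new_input))   # values recomputed from new_input
```

Indexing the Python lists `self.aa` and `self.bb` with a tensor has the same effect: the index is converted to a Python int during tracing and the selected value is frozen too.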
How can I fix this?
Note that I cannot use `torch.distributions.Beta`, because it does not work with TorchScript.
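One possible fix, sketched under the assumption that `algoid` should select the concentration pair at call time: keep `aa` and `bb` as registered buffers instead of Python lists, select from them with `torch.index_select`, and build the concentration tensor with `torch.stack`, so nothing is frozen into the trace. The optional `aa`/`bb` constructor arguments are an addition for illustration. `check_trace=False` is passed because the sampler is random, so the tracer's re-run check would compare two different draws.

```python
from typing import List, Optional

import torch
from torch import Tensor


class Algo2(torch.nn.Module):
    def __init__(self, aa: Optional[List[float]] = None, bb: Optional[List[float]] = None):
        super().__init__()
        # Buffers (tensors) instead of Python lists, so tensor indexing is traced.
        self.register_buffer("aa", torch.tensor(aa if aa is not None else [1.0] * 10))
        self.register_buffer("bb", torch.tensor(bb if bb is not None else [2.0] * 10))

    def get_beta_samples(self, algoid: Tensor) -> Tensor:
        idx = algoid.reshape(1).long()                 # 0-d index -> shape (1,)
        a = 1.0 + torch.index_select(self.aa, 0, idx)  # shape (1,)
        b = 1.0 + torch.index_select(self.bb, 0, idx)  # shape (1,)
        # torch.stack keeps a and b in the graph, unlike torch.tensor([a, b]).
        conc = torch.stack([a, b], dim=-1)             # shape (1, 2)
        # First Dirichlet coordinate ~ Beta(a, b); result has shape (1,).
        return torch._sample_dirichlet(conc)[:, 0]


model1 = Algo2()
trace_model = torch.jit.trace_module(
    model1, {"get_beta_samples": (torch.tensor(6),)}, check_trace=False
)
print(trace_model.get_beta_samples(torch.tensor(2)))  # index chosen at call time
```

Because the index now flows through `torch.index_select`, calling the traced method with a different `algoid` selects a different concentration pair instead of reusing the one seen during tracing.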