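It should work if you apply this patch, which removes the `weak_module` / `weak_script_method` import from `torch._jit_internal` and the two decorators that used it: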
diff --git a/model/swish.py b/model/swish.py
index 66adfa5..3f68678 100644
--- a/model/swish.py
+++ b/model/swish.py
@@ -1,10 +1,8 @@
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
-from torch._jit_internal import weak_module, weak_script_method
-@weak_module
class Swish(nn.Module):
def __init__(self, train_beta=False):
super(Swish, self).__init__()
@@ -13,7 +11,6 @@ class Swish(nn.Module):
else:
self.weight = 1.0
- @weak_script_method
def forward(self, input):
return input * torch.sigmoid(self.weight * input)