Unable to import 'weak_module' from 'torch._jit_internal'

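It should work if you apply the patch below. It removes the import of `weak_module` and `weak_script_method` from `torch._jit_internal` and drops the `@weak_module` and `@weak_script_method` decorators. `Swish` then works as a plain `nn.Module`: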

diff --git a/model/swish.py b/model/swish.py
index 66adfa5..3f68678 100644
--- a/model/swish.py
+++ b/model/swish.py
@@ -1,10 +1,8 @@
 import torch
 import torch.nn as nn
 from torch.nn.parameter import Parameter
-from torch._jit_internal import weak_module, weak_script_method
 
 
-@weak_module
 class Swish(nn.Module):
     def __init__(self, train_beta=False):
         super(Swish, self).__init__()
@@ -13,7 +11,6 @@ class Swish(nn.Module):
         else:
             self.weight = 1.0
 
-    @weak_script_method
     def forward(self, input):
         return input * torch.sigmoid(self.weight * input)
1 Like