The estimator should be an object implementing a 'fit' method.
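For reference, here is a minimal sketch of what "implementing 'fit'" means in practice. It assumes scikit-learn-style duck typing. `MajorityClassifier` and `require_fit` are hypothetical names for illustration only. `require_fit` imitates the kind of `hasattr` check that produces this error; it is not scikit-learn's own code.

```python
# Hypothetical minimal estimator: the key requirement shown here is that the
# object exposes a `fit` method (plus `predict` so it can be used for scoring).
class MajorityClassifier:
    """Predicts the most frequent label seen during fit."""

    def fit(self, X, y):
        counts = {}
        for label in y:
            counts[label] = counts.get(label, 0) + 1
        # Fitted state uses a trailing underscore, following the
        # scikit-learn naming convention for learned attributes.
        self.majority_ = max(counts, key=counts.get)
        return self  # returning self allows chaining, e.g. est.fit(X, y).predict(X)

    def predict(self, X):
        return [self.majority_ for _ in X]


def require_fit(estimator):
    """Hypothetical helper: a duck-typed check similar in spirit to the one
    behind the 'implementing fit method' error."""
    if not hasattr(estimator, "fit"):
        raise TypeError(
            "estimator should be an estimator implementing 'fit' method, "
            f"{estimator!r} was passed"
        )
    return estimator


model = require_fit(MajorityClassifier()).fit([[0], [1], [2]], ["a", "b", "a"])
print(model.predict([[5], [6]]))  # ['a', 'a']
```

Passing something without `fit` (for example a plain string) to `require_fit` raises the `TypeError`. Utilities such as `cross_val_score` also clone the estimator, which needs `get_params`. For that case, subclassing `sklearn.base.BaseEstimator` is the usual route rather than hand-rolling everything.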

Change this to criterion = torch.nn.CrossEntropyLoss.
Similar issue here.