Getting serialization error when using Ray Tune

I am new to Ray Tune and am trying to use it to tune two hyperparameters: learning_rate and weight_decay. I have mostly followed the PyTorch tutorial for Ray Tune.
I get the following error message (my code is shared after it):

================================================================================
Checking Serializability of <class 'ray.tune.trainable.function_trainable.wrap_function.<locals>.ImplicitFunc'>
================================================================================
!!! FAIL serialization: cannot pickle 'Event' object
    Serializing '__init__' <function Trainable.__init__ at 0x7fbe1a9ed550>...
    Serializing '__repr__' <function wrap_function.<locals>.ImplicitFunc.__repr__ at 0x7fbdf1d94ee0>...
    Serializing '_close_logfiles' <function Trainable._close_logfiles at 0x7fbe1a9f13a0>...
    Serializing '_create_checkpoint_dir' <function FunctionTrainable._create_checkpoint_dir at 0x7fbe1a9878b0>...
    Serializing '_create_logger' <function Trainable._create_logger at 0x7fbe1a9f1280>...
    Serializing '_export_model' <function Trainable._export_model at 0x7fbe1a9f1c10>...
    Serializing '_implements_method' <function Trainable._implements_method at 0x7fbe1a9f1ca0>...
    Serializing '_maybe_load_from_cloud' <function Trainable._maybe_load_from_cloud at 0x7fbe1a9edd30>...
    Serializing '_maybe_save_to_cloud' <function Trainable._maybe_save_to_cloud at 0x7fbe1a9edca0>...
    Serializing '_open_logfiles' <function Trainable._open_logfiles at 0x7fbe1a9f1310>...
    Serializing '_report_thread_runner_error' <function FunctionTrainable._report_thread_runner_error at 0x7fbe1a987ca0>...
    Serializing '_restore_from_checkpoint_obj' <function FunctionTrainable._restore_from_checkpoint_obj at 0x7fbe1a987a60>...
    Serializing '_start' <function FunctionTrainable._start at 0x7fbe1a9875e0>...
    Serializing '_storage_path' <function Trainable._storage_path at 0x7fbe1a9ed670>...
    Serializing '_trainable_func' <function wrap_function.<locals>.ImplicitFunc._trainable_func at 0x7fbdf1da31f0>...
    !!! FAIL serialization: cannot pickle 'Event' object
    Detected 3 global variables. Checking serializability...
        Serializing 'partial' <class 'functools.partial'>...
        Serializing 'inspect' <module 'inspect' from '/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/inspect.py'>...
        Serializing 'RESULT_DUPLICATE' __duplicate__...
    Detected 3 nonlocal variables. Checking serializability...
        Serializing 'train_func' <function with_parameters.<locals>._inner at 0x7fbe100bd3a0>...
        !!! FAIL serialization: cannot pickle 'Event' object
        Detected 1 nonlocal variables. Checking serializability...
            Serializing 'inner' <function with_parameters.<locals>.inner at 0x7fbe100bd280>...
            !!! FAIL serialization: cannot pickle 'Event' object
    Serializing '_annotated' FunctionTrainable...
================================================================================
Variable: 

        FailTuple(inner [obj=<function with_parameters.<locals>.inner at 0x7fbe100bd280>, parent=<function with_parameters.<locals>._inner at 0x7fbe100bd3a0>])

was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. 
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. 
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
================================================================================
Traceback (most recent call last):
  File "/visinf/home/shamidi/Projects/rainbow-memory-Bayesian/main.py", line 347, in <module>
    main()
  File "/visinf/home/shamidi/Projects/rainbow-memory-Bayesian/main.py", line 223, in main
    result = tune.run(
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/ray/tune/tune.py", line 520, in run
    experiments[i] = Experiment(
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/ray/tune/experiment/experiment.py", line 163, in __init__
    self._run_identifier = Experiment.register_if_needed(run)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/ray/tune/experiment/experiment.py", line 365, in register_if_needed
    raise type(e)(str(e) + " " + extra_msg) from None
TypeError: cannot pickle 'Event' object Other options: 
-Try reproducing the issue by calling `pickle.dumps(trainable)`. 
-If the error is typing-related, try removing the type annotations and try again
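The last hint in the error, reproducing the failure with pickle.dumps(trainable), can be followed without Ray. The log shows what is being serialized: tune.with_parameters wraps method.find_hyperparametrs, and a bound method cannot be pickled without the object it is bound to, so a single unpicklable attribute anywhere on that object is enough to fail. Below is a minimal stdlib-only sketch of that mechanism, assuming plain pickle is representative (Ray actually uses cloudpickle). Holder, standalone_train and try_pickle are hypothetical names, and a threading.Lock stands in for the 'Event' object, whose real origin is not identified here.

```python
import pickle
import threading


class Holder:
    """Hypothetical stand-in for the object whose method is handed to Ray Tune."""

    def __init__(self):
        self.lr = 0.01
        # Any attribute that cannot be pickled (here a lock) makes every
        # bound method of this object unpicklable as well.
        self.lock = threading.Lock()

    def train(self, config):
        return config["lr"] * self.lr


def standalone_train(config):
    """Module-level function: pickled by reference, carries no instance state."""
    return config["lr"] * 0.01


def try_pickle(obj):
    """Return None if obj pickles, otherwise a short error description."""
    try:
        pickle.dumps(obj)
        return None
    except Exception as exc:  # TypeError, PicklingError, AttributeError, ...
        return f"{type(exc).__name__}: {exc}"


holder = Holder()
# A bound method is pickled as (instance, method name), so `holder` and its
# lock have to be pickled too, and that fails.
print(try_pickle(holder.train))
print(try_pickle(standalone_train))
```

The second call succeeds because a module-level function is stored by name only; the first fails for the same structural reason as in the log, even though the offending attribute differs.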

My code follows the steps below:

    # Excerpt from main() in main.py; the loop runs inside that function.
    # Ray imports assumed at the top of main.py (Ray 2.x module paths):
    #   from ray import tune
    #   from ray.tune import CLIReporter
    #   from ray.tune.schedulers import ASHAScheduler
    #   from ray.tune.search.hyperopt import HyperOptSearch
    for cur_iter in range(args.n_tasks):
        if args.mode == "joint" and cur_iter > 0:
            return

        print("\n" + "#" * 50)
        print(f"# Task {cur_iter} iteration")
        print("#" * 50 + "\n")
        logger.info("[2-1] Prepare a datalist for the current task")

        task_acc = 0.0
        eval_dict = dict()

        # get datalist
        cur_train_datalist = get_train_datalist(args, cur_iter)
        cur_valid_datalist = get_valid_datalist(args, args.exp_name, cur_iter)
        cur_test_datalist = get_test_datalist(args, args.exp_name, cur_iter)

        logger.info("[2-2] Set environment for the current task")

        method.set_current_dataset(cur_train_datalist, cur_test_datalist, cur_valid_datalist)

        method.before_task(cur_train_datalist, cur_iter, args.init_model, args.init_opt,
                           args.bayesian_model)

        # The way to handle streamed samples
        logger.info(f"[2-3] Start to train under {args.stream_env}")

        if args.stream_env == "offline" or args.mode == "joint" or args.mode == "gdumb":
            # Offline Train

            # -----------------------------------------------------------------------------------------------------------------
            # Ray Tune for the first task of the blurry case
            # -----------------------------------------------------------------------------------------------------------------
            if args.exp_name == "blurry10" and cur_iter == 0:
                # Search space: the keys must match what find_hyperparametrs reads from config.
                configs = {"lr": tune.loguniform(1e-4, 1e-1), "weight_decay": tune.uniform(1e-8, 1e-1)}
                hyperopt_search = HyperOptSearch(metric='accuracy', mode='max')
                # hyperopt_search = BayesOptSearch(metric='loss', mode='min', points_to_evaluate=[{"lamda": 1}, {"lamda": 25}])
                scheduler = ASHAScheduler(
                    metric="accuracy",
                    mode="max",
                    max_t=100,
                    grace_period=5,
                    reduction_factor=2)

                # parameter_columns must use the same keys as configs ("weight_decay", not "wd")
                reporter = CLIReporter(
                    parameter_columns=["lr", "weight_decay"],
                    metric_columns=["loss", "accuracy", "training_iteration"]
                    )

                result = tune.run(
                                tune.with_parameters(method.find_hyperparametrs),
                                resources_per_trial={"cpu": 1, "gpu": 1},
                                config=configs,
                                num_samples=1,
                                search_alg=hyperopt_search,
                                scheduler=scheduler,
                                # keep_checkpoints_num=2,
                                checkpoint_score_attr="accuracy",
                                progress_reporter=reporter
                                )
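The search space mixes two kinds of sampling: lr uses tune.loguniform, while weight_decay uses tune.uniform over 1e-8 to 1e-1. A uniform draw over that interval is linear, so only about (1e-3 - 1e-8) / (1e-1 - 1e-8), roughly 1%, of weight_decay values should fall below 1e-3, whereas a log-uniform draw puts about 5/7 of them there (5 of the 7 decades). The sketch below illustrates this with the standard library only; it does not call Ray, and sample_loguniform is a hypothetical helper that mirrors the idea behind tune.loguniform rather than its implementation.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed for a reproducible illustration
n = 10_000
uniform_wd = [rng.uniform(1e-8, 1e-1) for _ in range(n)]
loguniform_wd = [sample_loguniform(1e-8, 1e-1, rng) for _ in range(n)]

# Share of weight_decay samples below 1e-3 under each distribution.
frac_uniform = sum(wd < 1e-3 for wd in uniform_wd) / n
frac_log = sum(wd < 1e-3 for wd in loguniform_wd) / n
print(f"uniform: {frac_uniform:.3f}, loguniform: {frac_log:.3f}")
```

Separately, num_samples=1 asks Tune for a single trial, so HyperOptSearch and ASHAScheduler only ever see one configuration; an actual search needs a larger value.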

And here is set_current_dataset(), together with get_dataloader():

    def set_current_dataset(self, train_datalist, test_datalist, valid_datalist):
        random.shuffle(train_datalist)  # shuffles the list in place
        self.prev_streamed_list = self.streamed_list
        self.streamed_list = train_datalist
        self.test_list = test_datalist
        # add validation set
        self.valid_list = valid_datalist

        # For ray tune - test
        # Pass the already shuffled list itself: random.shuffle() returns None,
        # so train_list=random.shuffle(...) would leave train_loader as None.
        self.train_loader, self.test_loader, self.valid_loader = self.get_dataloader(
            self.batch_size, self.n_worker, train_list=self.streamed_list,
            test_list=self.test_list, valid_list=self.valid_list)

    def get_dataloader(self, batch_size, n_worker, train_list, test_list, valid_list):
        # Loader
        train_loader = None
        test_loader = None
        # add valid_loader 
        valid_loader = None

        if train_list is not None and len(train_list) > 0:
            train_dataset = ImageDataset(
                pd.DataFrame(train_list),
                dataset=self.dataset,
                transform=self.train_transform,
            )
            # drop last because of BatchNorm1D in IcarlNet
            train_loader = DataLoader(
                train_dataset,
                shuffle=True,
                batch_size=batch_size,
                num_workers=n_worker,
                drop_last=True,
                pin_memory=True,
            )

        if test_list is not None:
            test_dataset = ImageDataset(
                pd.DataFrame(test_list),
                dataset=self.dataset,
                transform=self.test_transform,
            )
            test_loader = DataLoader(
                test_dataset, shuffle=False, batch_size=batch_size, num_workers=n_worker, pin_memory=True
            )
       
        if valid_list is not None:
            valid_dataset = ImageDataset(
                pd.DataFrame(valid_list),
                dataset=self.dataset,
                transform=self.test_transform, # use the same transformation for the valid set as the test set
            )
            valid_loader = DataLoader(
                valid_dataset, shuffle=False, batch_size=batch_size, num_workers=n_worker, pin_memory=True
            )

        return train_loader, test_loader, valid_loader
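One detail of this data path is easy to verify in isolation: random.shuffle shuffles a list in place and returns None, so its return value must never be used as the list itself, for example as the train_list argument of get_dataloader, where None silently skips building the train loader. A stdlib-only check; build_train_loader is a hypothetical stand-in for the train branch of get_dataloader and returns a plain list instead of a DataLoader.

```python
import random


def build_train_loader(train_list):
    """Hypothetical stand-in for the train branch of get_dataloader."""
    if train_list is not None and len(train_list) > 0:
        return list(train_list)  # the real code wraps this in a DataLoader
    return None


streamed_list = [{"file_name": f"img_{i}.png", "label": i % 2} for i in range(6)]

returned = random.shuffle(streamed_list)  # shuffles in place
print(returned)  # None

print(build_train_loader(random.shuffle(streamed_list)))  # None: no train loader
print(len(build_train_loader(streamed_list)))  # 6: pass the list itself
```

Since the training DataLoader is created with shuffle=True, it reshuffles the samples at every epoch anyway, so an extra shuffle right before building it adds no randomness.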

And finally, the trainable (I am not sure this is the correct term) is as follows:

class RM(Finetune, tune.Trainable):
    def __init__(
        self, criterion, device, train_transform, test_transform, n_classes, **kwargs
    ):
        super().__init__(
            criterion, device, train_transform, test_transform, n_classes, **kwargs
        )
        
        self.batch_size = kwargs["batchsize"]
        self.n_worker = kwargs["n_worker"]
        self.exp_env = kwargs["stream_env"]
        self.bayesian = kwargs["bayesian_model"]
        self.pretrain = kwargs['pretrain']
        self.scheduler_name = kwargs["sched_name"]
        if kwargs["mem_manage"] == "default":
            self.mem_manage = "uncertainty"

    # --------------------------------------------------------------------------------------------------
    # For Ray Tune
    # --------------------------------------------------------------------------------------------------
    def find_hyperparametrs(self, config):
        # batch_size = self.batch_size
        n_worker = self.n_worker  # the attribute set in __init__ is n_worker, not n_workers
        cur_iter = 0

        self.optimizer = select_optimizer(self.opt_name, config['lr'], config['weight_decay'], self.model, self.sched_name)

        # TRAIN
        eval_dict = dict()
        self.model = self.model.to(self.device)
        
        for epoch in range(self.n_epochs):
           
            # initialize for each task
            # optimizer.param_groups is a python list, which contains a dictionary.
            if self.scheduler_name == "cos":
                if epoch <= 0:  # Warm start of 1 epoch
                    for param_group in self.optimizer.param_groups:
                        # param_group is the dict inside the list and is the only item in this list.
                        # Use the lr sampled by Tune for this trial, not the fixed self.lr.
                        if self.bayesian is True:
                            param_group["lr"] = config["lr"] * 0.1  # this was changed due to an inf error
                        else:
                            param_group["lr"] = config["lr"] * 0.1
                elif epoch == 1:  # Then set to the max lr of this trial
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = config["lr"]
                else:  # And go!
                    if self.scheduler is not None:
                        self.scheduler.step()
            else:
                if self.scheduler is not None:
                    self.scheduler.step()

            # Training
            train_loss, train_acc = self._train(train_loader=self.train_loader, memory_loader=None,
                                                optimizer=self.optimizer, criterion=self.criterion)
            
            # Validation (validating over all the test sets seen so far)
            eval_dict_valid = self.evaluation(
                self.valid_loader, criterion=self.criterion
            )

            # Communicate with Ray tune
            with tune.checkpoint_dir(epoch) as checkpoint_dir:  # directory created by Tune for this step
                path = os.path.join(checkpoint_dir, "checkpoint")  # no extra subfolder: it would not exist yet
                torch.save((self.model.state_dict(), self.optimizer.state_dict()), path)

            tune.report(
                loss=eval_dict_valid["avg_loss"], accuracy=eval_dict_valid["avg_acc"]
                )
            

            # Testing(testing over all the test sets seen so far)
            eval_dict = self.evaluation(
                test_loader=self.test_loader, criterion=self.criterion
            )
            
            # Report the results on the current epoch
            logger.info(
                f"Task {cur_iter} | Epoch {epoch+1}/{self.n_epochs} | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | "
                f"test_loss {eval_dict['avg_loss']:.4f} | test_acc {eval_dict['avg_acc']:.4f} | "
                f"valid_loss {eval_dict_valid['avg_loss']:.4f} | valid_acc {eval_dict_valid['avg_acc']:.4f} | "
                f"lr {self.optimizer.param_groups[0]['lr']:.4f}"
            )
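For the tuning to mean anything, the hand-written warm start has to start from the learning rate Tune sampled for the trial (config["lr"]) rather than from a fixed attribute such as self.lr; otherwise every trial runs its first two epochs at the same rate regardless of what was sampled. A small sketch of that schedule as a pure function; warm_start_lr is a hypothetical helper, and the 0.1 factor is taken from the code above.

```python
def warm_start_lr(epoch, base_lr, warmup_factor=0.1):
    """Learning rate that the cosine branch of find_hyperparametrs sets by hand.

    Returns None from epoch 2 on, where the code hands over to the
    learning-rate scheduler instead of setting the rate directly.
    """
    if epoch <= 0:
        return base_lr * warmup_factor  # one warm-up epoch at a tenth of the rate
    if epoch == 1:
        return base_lr  # then the full sampled rate
    return None


config = {"lr": 3e-3, "weight_decay": 1e-5}  # example values for one trial
for epoch in range(3):
    print(epoch, warm_start_lr(epoch, config["lr"]))
```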

    def update_model(self, x, y, criterion, optimizer):
        # check the label type and the output of the Bayesian model
        
        optimizer.zero_grad()
        do_cutmix = self.cutmix and np.random.rand(1) < 0.5
        if do_cutmix:
            x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)
            '''
            x = x.double()
            labels_a = labels_a.double()
            labels_b = labels_b.double()
            '''
            # take care of the output of the bayesian model and its probabilistic loss
            if self.bayesian:
                #self.model.double()
                logit_dict = self.model(x)

                loss = lam * criterion(logit_dict, labels_a)['total_loss'] + (1 - lam) * criterion(
                    logit_dict, labels_b)['total_loss']
                #loss = losses_dict['total_loss']
                logit = criterion(logit_dict, labels_a)['prediction']
                logit = logit.mean(dim=2)
            else:
                #self.model.double()
                logit = self.model(x)
                loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(
                    logit, labels_b
                )
        else:
            
            if self.bayesian:
                # measure forward pass time
                #t_start = time.time()
                #self.model.double()
                logit_dict = self.model(x)
                #t_end = time.time() - t_start
                # logger.info(f'forward pass time: {t_end:.2f} s')

                # criterion is the probabilistic loss class
                #t_s = time.time()
                losses_dict = criterion(logit_dict, y)
                #t_e = time.time() - t_s
                #logger.info(f'loss time: {t_e:.2f} s')
                
                loss = losses_dict['total_loss']
                logit = losses_dict['prediction'] # Shape: torch.Size([10, 10, 64]) --> (batch_size, num_classes, samples)
                # change the shape of the logit to be (batch_size, num_classes)
                logit = logit.mean(dim=2)
            else:
                #self.model.double()
                logit = self.model(x)
                loss = criterion(logit, y)
        
        # calculate the number of correct predictions per batch for the bayesian model as well here
        _, preds = logit.topk(self.topk, 1, True, True)

        loss.backward()
        ''' ToDo: is it necessary to clip the gradient? it was done in mnvi code
        Maybe they didn't need it but I'm not sure. For the Bayesian case, it is probably needed.
        '''
        if self.bayesian:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1, norm_type='inf')
        
        optimizer.step()
        return loss.item(), torch.sum(preds == y.unsqueeze(1)).item(), y.size(0)

    def _train(
        self, train_loader, memory_loader, optimizer, criterion
    ):
        
        total_loss, correct, num_data = 0.0, 0.0, 0.0

        self.model.train()
        if memory_loader is not None and train_loader is not None:
            data_iterator = zip(train_loader, cycle(memory_loader))
        elif memory_loader is not None:
            data_iterator = memory_loader
        elif train_loader is not None:
            data_iterator = train_loader
        else:
            raise NotImplementedError("None of the data loaders is valid")
        
        for i, data in enumerate(data_iterator):
            if len(data) == 2:
                stream_data, mem_data = data
                x = torch.cat([stream_data["image"], mem_data["image"]])
                y = torch.cat([stream_data["label"], mem_data["label"]])
            else:
                x = data["image"]
                y = data["label"]
            # set to double
            #x = x.double().to(self.device)
            #y = y.double().to(self.device)

            x = x.to(self.device)
            y = y.to(self.device)

            '''
            all_model, _ = self.measure_time(self.model, x)
            print('all_model', all_model)
            '''
            # measure each operation time of the forward pass for one batch
            # ---------------------------------------------------
           
            # ------------------------------------------------------
            # this is equivalent to the step code in the test repo
            l, c, d = self.update_model(x, y, criterion, optimizer)
            # Compute the moving averages - equivalent to MovingAverage in the test repo
            total_loss += l
            correct += c
            num_data += d

        if train_loader is not None:
            n_batches = len(train_loader)
        else:
            n_batches = len(memory_loader)

        return total_loss / n_batches, correct / num_data

    def allocate_batch_size(self, n_old_class, n_new_class):
        new_batch_size = int(
            self.batch_size * n_new_class / (n_old_class + n_new_class)
        )
        old_batch_size = self.batch_size - new_batch_size
        return new_batch_size, old_batch_size
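The data flow in _train (stream batches paired with memory batches, then losses averaged over batches and accuracy over samples) can be checked without PyTorch. A plain-Python sketch, assuming each "batch" is a list of (loss, is_correct) pairs; average_epoch is a hypothetical name, not part of the code above.

```python
from itertools import cycle


def average_epoch(train_batches, memory_batches=None):
    """Mimic the loop structure of _train with plain lists as 'loaders'.

    Each batch is a list of (loss, is_correct) pairs, one per sample.
    """
    if memory_batches is not None and train_batches is not None:
        # zip stops at the shorter iterable; cycle() repeats the memory
        # batches, so the epoch length follows train_batches.
        iterator = ((s + m) for s, m in zip(train_batches, cycle(memory_batches)))
    elif memory_batches is not None:
        iterator = iter(memory_batches)
    elif train_batches is not None:
        iterator = iter(train_batches)
    else:
        raise NotImplementedError("None of the data loaders is valid")

    total_loss, correct, num_data, n_batches = 0.0, 0, 0, 0
    for batch in iterator:
        total_loss += sum(loss for loss, _ in batch) / len(batch)  # mean loss of the batch
        correct += sum(ok for _, ok in batch)
        num_data += len(batch)
        n_batches += 1  # equals len(train_loader) when a train loader is given
    return total_loss / n_batches, correct / num_data


train = [[(1.0, True), (3.0, False)], [(2.0, True), (2.0, True)]]
memory = [[(0.0, True)]]
print(average_epoch(train, memory))
print(average_epoch(train))
```

One detail the sketch sidesteps: _train recognizes a paired (stream, memory) batch with len(data) == 2, which would also hold for a single collated batch dict with exactly two keys (e.g. only "image" and "label"); checking isinstance(data, tuple) would remove that ambiguity.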
            

I am not sure which object here is causing the problem or how to go about solving it. I moved the datasets out of the trainable function, but that didn't help. I also tried switching Ray from 2.4.0 to 2.0, which didn't help either. Any guidance is very much appreciated.
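One way to narrow down the culprit: the log points at the inner function wrapping method.find_hyperparametrs, and therefore at method itself, so each attribute of method can be pickled on its own. A stdlib-only sketch; find_unpicklable_attributes and FakeMethod are hypothetical names. Ray serializes with cloudpickle, which accepts some things plain pickle rejects (lambdas, locally defined functions), so treat the result as a list of suspects rather than a verdict. It also checks only top-level attributes.

```python
import pickle
import threading


def find_unpicklable_attributes(obj):
    """Try to pickle each entry of obj.__dict__ on its own.

    Returns {attribute_name: "ErrorType: message"} for the ones that fail.
    Only top-level attributes are checked; nested objects are not searched.
    """
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickle can raise several exception types
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


class FakeMethod:
    """Hypothetical object standing in for `method` in main.py."""

    def __init__(self):
        self.batch_size = 16
        self.streamed_list = [{"file_name": "a.png", "label": 0}]
        self.worker_lock = threading.Lock()  # deliberately unpicklable


print(find_unpicklable_attributes(FakeMethod()))
```

In the question's setting, calling find_unpicklable_attributes(method) just before tune.run(...) would list the top-level suspects; for any attribute that shows up, the same function can be applied to that attribute in turn.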