State_dict() and CUDA OOM

Hi,

I’m training a model on 4x A100 80GB GPUs and it trains fine, but it occupies nearly all of the VRAM. Now, when I want to save the model, I cannot directly use torch.save(model) because of the ProcessGroup, so I use torch.save(model.state_dict()); but this crashes with an OOM, because it allocates new VRAM for the state_dict, which seems rather wasteful!
Do you know of any way to save the state_dict without allocating any new VRAM? Thanks a lot!

Christophe

What is your host RAM used for? Maybe you could delete unneeded data before moving the state_dict to the host.
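
For example, a minimal sketch (assuming a non-sharded model and a placeholder filename) that copies every tensor to host RAM before saving, so torch.save never needs extra GPU memory:

import torch

# Copy each tensor of the state_dict to host RAM; no new GPU allocation is needed.
cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state, "checkpoint.pt")  # "checkpoint.pt" is just a placeholder path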

Thank you! It’s the usual weights + gradients + activations (with gradient checkpointing) + optimizer state (but I use SGD), and I want to checkpoint the model during training. There isn’t enough VRAM to hold the weights twice in bf16 on a single GPU (I’m using sharding for training), so I’m not sure that’s possible… That’s why I’d like to copy the weights directly from the GPU to disk, or through the CPU, without duplicating them on the GPU… Thanks a lot for your ideas, which are always super useful :slight_smile:

OK, I see your point: move the weights to CPU RAM, then call state_dict() and save it… But I guess calling state_dict() will also duplicate the weights, and I’m afraid the CPU RAM cannot host two copies of the weights; I’ll have to do some maths to know… My point was that it seems pretty wasteful to duplicate the weights just to save them, and I don’t really understand why it works like that.
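
A rough sketch of that maths (assuming model is the module holding the full weights), which just sums the parameter sizes:

# Host RAM needed for one full copy of the weights.
n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"{n_bytes / 1024**3:.2f} GiB per copy of the weights")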

I don’t fully understand where the duplication would come from. Calling model.state_dict() will return references to the parameters and buffers, so no copy will be created unless you explicitly call copy.deepcopy on it. Moving the model from the GPU to the CPU will also just move the data without creating duplicates.
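
A minimal check of this, using a toy nn.Linear instead of your model:

import torch.nn as nn

lin = nn.Linear(4, 2)
sd = lin.state_dict()
# Same underlying storage: state_dict() returns references, not copies.
print(sd["weight"].data_ptr() == lin.weight.data_ptr())  # True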

Thanks for the explanation; but if state_dict() does not allocate any memory, I don’t understand why I get an OOM error when I call it and not otherwise. Here is the error trace showing where it allocates memory, in case it’s useful. Thanks a lot!

P.S. On second thoughts, it might be that when I ask to save the model, FSDP unshards it, i.e. tries to gather all the other GPUs’ shards onto the main one, and this OOMs because the gradients etc. are still in VRAM… So this would mean that I should only save the model after releasing all gradients, etc.?!

β”‚   209 β”‚   β”‚   if rank==0:                                                    β”‚
β”‚   210 β”‚   β”‚   β”‚   dt = time.time() - t0                                      β”‚
β”‚   211 β”‚   β”‚   β”‚   if dt>1800: # 30 minutes                                   β”‚
β”‚ ❱ 212 β”‚   β”‚   β”‚   β”‚   states = model.state_dict()                            β”‚
β”‚   213 β”‚   β”‚   β”‚   β”‚   torch.save(states, wd2+"fsdpmod."+str(it0)+".pt")      β”‚
β”‚   214 β”‚   β”‚   β”‚   β”‚   t0 = time.time()                                       β”‚
β”‚   215 β”‚   β”‚   β”‚   β”‚   it0 += 1                                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1818 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚ ❱ 1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚   1819 β”‚   β”‚   for hook in self._state_dict_hooks.values():                  β”‚
β”‚   1820 β”‚   β”‚   β”‚   hook_result = hook(self, destination, prefix, local_metad β”‚
β”‚   1821 β”‚   β”‚   β”‚   if hook_result is not None:                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1815 in state_dict        β”‚
β”‚                                                                              β”‚
β”‚   1812 β”‚   β”‚   if hasattr(destination, "_metadata"):                         β”‚
β”‚   1813 β”‚   β”‚   β”‚   destination._metadata[prefix[:-1]] = local_metadata       β”‚
β”‚   1814 β”‚   β”‚                                                                 β”‚
β”‚ ❱ 1815 β”‚   β”‚   self._save_to_state_dict(destination, prefix, keep_vars)      β”‚
β”‚   1816 β”‚   β”‚   for name, module in self._modules.items():                    β”‚
β”‚   1817 β”‚   β”‚   β”‚   if module is not None:                                    β”‚
β”‚   1818 β”‚   β”‚   β”‚   β”‚   module.state_dict(destination=destination, prefix=pre β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/nn/modules/module.py:1722 in                   β”‚
β”‚ _save_to_state_dict                                                          β”‚
β”‚                                                                              β”‚
β”‚   1719 β”‚   β”‚   β”‚   β”‚   module                                                β”‚
β”‚   1720 β”‚   β”‚   """                                                           β”‚
β”‚   1721 β”‚   β”‚   for hook in self._state_dict_pre_hooks.values():              β”‚
β”‚ ❱ 1722 β”‚   β”‚   β”‚   hook(self, prefix, keep_vars)                             β”‚
β”‚   1723 β”‚   β”‚                                                                 β”‚
β”‚   1724 β”‚   β”‚   for name, param in self._parameters.items():                  β”‚
β”‚   1725 β”‚   β”‚   β”‚   if param is not None:                                     β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/utils/_contextlib.py:115 in decorate_context   β”‚
β”‚                                                                              β”‚
β”‚   112 β”‚   @functools.wraps(func)                                             β”‚
β”‚   113 β”‚   def decorate_context(*args, **kwargs):                             β”‚
β”‚   114 β”‚   β”‚   with ctx_factory():                                            β”‚
β”‚ ❱ 115 β”‚   β”‚   β”‚   return func(*args, **kwargs)                               β”‚
β”‚   116 β”‚                                                                      β”‚
β”‚   117 β”‚   return decorate_context                                            β”‚
β”‚   118                                                                        β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:669 in   β”‚
β”‚ _pre_state_dict_hook                                                         β”‚
β”‚                                                                              β”‚
β”‚   666 β”‚   β”‚   StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,    β”‚
β”‚   667 β”‚   β”‚   StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook β”‚
β”‚   668 β”‚   }                                                                  β”‚
β”‚ ❱ 669 β”‚   _pre_state_dict_hook_fn[fsdp_state._state_dict_type](              β”‚
β”‚   670 β”‚   β”‚   fsdp_state,                                                    β”‚
β”‚   671 β”‚   β”‚   module,                                                        β”‚
β”‚   672 β”‚   β”‚   *args,                                                         β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:271 in   β”‚
β”‚ _full_pre_state_dict_hook                                                    β”‚
β”‚                                                                              β”‚
β”‚   268 β”‚   in ``nn.Module``.                                                  β”‚
β”‚   269 β”‚   """                                                                β”‚
β”‚   270 β”‚   _common_pre_state_dict_hook(module, fsdp_state)                    β”‚
β”‚ ❱ 271 β”‚   _common_unshard_pre_state_dict_hook(                               β”‚
β”‚   272 β”‚   β”‚   module,                                                        β”‚
β”‚   273 β”‚   β”‚   fsdp_state,                                                    β”‚
β”‚   274 β”‚   β”‚   offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,   β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:143 in   β”‚
β”‚ _common_unshard_pre_state_dict_hook                                          β”‚
β”‚                                                                              β”‚
β”‚   140 β”‚   Performs the pre-state_dict tasks shared by all state_dict types t β”‚
β”‚   141 β”‚   ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STAT β”‚
β”‚   142 β”‚   """                                                                β”‚
β”‚ ❱ 143 β”‚   _enter_unshard_params_ctx(                                         β”‚
β”‚   144 β”‚   β”‚   module,                                                        β”‚
β”‚   145 β”‚   β”‚   fsdp_state,                                                    β”‚
β”‚   146 β”‚   β”‚   writeback=False,                                               β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:109 in   β”‚
β”‚ _enter_unshard_params_ctx                                                    β”‚
β”‚                                                                              β”‚
β”‚   106 β”‚   β”‚   offload_to_cpu=offload_to_cpu,                                 β”‚
β”‚   107 β”‚   β”‚   with_grads=with_grads,                                         β”‚
β”‚   108 β”‚   )                                                                  β”‚
β”‚ ❱ 109 β”‚   fsdp_state._unshard_params_ctx[module].__enter__()                 β”‚
β”‚   110                                                                        β”‚
β”‚   111                                                                        β”‚
β”‚   112 @no_type_check                                                         β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/contextlib.py:135 in __enter__                                     β”‚
β”‚                                                                              β”‚
β”‚   132 β”‚   β”‚   # they are only needed for recreation, which is not possible a β”‚
β”‚   133 β”‚   β”‚   del self.args, self.kwds, self.func                            β”‚
β”‚   134 β”‚   β”‚   try:                                                           β”‚
β”‚ ❱ 135 β”‚   β”‚   β”‚   return next(self.gen)                                      β”‚
β”‚   136 β”‚   β”‚   except StopIteration:                                          β”‚
β”‚   137 β”‚   β”‚   β”‚   raise RuntimeError("generator didn't yield") from None     β”‚
β”‚   138                                                                        β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py:198   β”‚
β”‚ in _unshard_fsdp_state_params                                                β”‚
β”‚                                                                              β”‚
β”‚   195 β”‚   # No need to call `wait_stream()` since we unshard in the computat β”‚
β”‚   196 β”‚   # stream directly                                                  β”‚
β”‚   197 β”‚   computation_stream = torch.cuda.current_stream()                   β”‚
β”‚ ❱ 198 β”‚   _unshard(state, handles, computation_stream, computation_stream)   β”‚
β”‚   199 β”‚   if with_grads:                                                     β”‚
β”‚   200 β”‚   β”‚   _unshard_grads(handles)                                        β”‚
β”‚   201                                                                        β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py:329 in      β”‚
β”‚ _unshard                                                                     β”‚
β”‚                                                                              β”‚
β”‚    326 β”‚   β”‚   β”‚   event.synchronize()                                       β”‚
β”‚    327 β”‚   with torch.cuda.stream(unshard_stream):                           β”‚
β”‚    328 β”‚   β”‚   for handle in handles:                                        β”‚
β”‚ ❱  329 β”‚   β”‚   β”‚   handle.unshard()                                          β”‚
β”‚    330 β”‚   β”‚   β”‚   handle.post_unshard()                                     β”‚
β”‚    331                                                                       β”‚
β”‚    332                                                                       β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/flat_param.py:918 in unshard  β”‚
β”‚                                                                              β”‚
β”‚    915 β”‚   β”‚   β”‚   )                                                         β”‚
β”‚    916 β”‚   β”‚   β”‚   self._use_unsharded_flat_param(unsharded_flat_param)      β”‚
β”‚    917 β”‚   β”‚   β”‚   return                                                    β”‚
β”‚ ❱  918 β”‚   β”‚   unsharded_flat_param = self._alloc_padded_unsharded_flat_para β”‚
β”‚    919 β”‚   β”‚   padded_unsharded_flat_param = self._all_gather_flat_param(uns β”‚
β”‚    920 β”‚   β”‚   self._use_unsharded_flat_param(padded_unsharded_flat_param)   β”‚
β”‚    921                                                                       β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/flat_param.py:944 in          β”‚
β”‚ _alloc_padded_unsharded_flat_param                                           β”‚
β”‚                                                                              β”‚
β”‚    941 β”‚   β”‚   flat_param = self.flat_param                                  β”‚
β”‚    942 β”‚   β”‚   unsharded_flat_param = self._get_padded_unsharded_flat_param( β”‚
β”‚    943 β”‚   β”‚   self._check_storage_freed(unsharded_flat_param)               β”‚
β”‚ ❱  944 β”‚   β”‚   _alloc_storage(unsharded_flat_param, flat_param._padded_unsha β”‚
β”‚    945 β”‚   β”‚   return unsharded_flat_param                                   β”‚
β”‚    946 β”‚                                                                     β”‚
β”‚    947 β”‚   def _get_padded_unsharded_flat_param(self) -> torch.Tensor:       β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/utils/_contextlib.py:115 in decorate_context   β”‚
β”‚                                                                              β”‚
β”‚   112 β”‚   @functools.wraps(func)                                             β”‚
β”‚   113 β”‚   def decorate_context(*args, **kwargs):                             β”‚
β”‚   114 β”‚   β”‚   with ctx_factory():                                            β”‚
β”‚ ❱ 115 β”‚   β”‚   β”‚   return func(*args, **kwargs)                               β”‚
β”‚   116 β”‚                                                                      β”‚
β”‚   117 β”‚   return decorate_context                                            β”‚
β”‚   118                                                                        β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/distributed/fsdp/_utils.py:79 in               β”‚
β”‚ _alloc_storage                                                               β”‚
β”‚                                                                              β”‚
β”‚    76 β”‚   β”‚   β”‚   tensor_storage_size == 0,                                  β”‚
β”‚    77 β”‚   β”‚   β”‚   f"Tensor storage should have been resized to be 0 but got  β”‚
β”‚    78 β”‚   β”‚   )                                                              β”‚
β”‚ ❱  79 β”‚   β”‚   tensor._typed_storage()._resize_(size.numel())                 β”‚
β”‚    80 β”‚   return not already_allocated                                       β”‚
β”‚    81                                                                        β”‚
β”‚    82                                                                        β”‚
β”‚                                                                              β”‚
β”‚ /gpfslocalsup/pub/anaconda-py3/2023.03/envs/pytorch-gpu-2.0.0+py3.10.9/lib/p β”‚
β”‚ ython3.10/site-packages/torch/storage.py:764 in _resize_                     β”‚
β”‚                                                                              β”‚
β”‚    761 β”‚                                                                     β”‚
β”‚    762 β”‚   # For internal use only, to avoid deprecation warning             β”‚
β”‚    763 β”‚   def _resize_(self, size):                                         β”‚
β”‚ ❱  764 β”‚   β”‚   self._untyped_storage.resize_(size * self._element_size())    β”‚
β”‚    765 β”‚                                                                     β”‚
β”‚    766 β”‚   @classmethod                                                      β”‚
β”‚    767 β”‚   def _free_weak_ref(cls, *args, **kwargs):                         β”‚
╰──────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 236.00 MiB (GPU 0; 79.15
GiB total capacity; 76.94 GiB already allocated; 121.19 MiB free; 77.82 GiB 
reserved in total by PyTorch) If reserved memory is >> allocated memory try 
setting max_split_size_mb to avoid fragmentation.  See documentation for Memory 
Management and PYTORCH_CUDA_ALLOC_CONF

Yes, state_dict() doesn’t allocate new memory, as also seen here:

import copy

import torch
from torchvision import models

model = models.resnet152().cuda()
print("{:.2f}MB".format(torch.cuda.memory_allocated()/1024**2))
# 230.27MB

# state_dict() only returns references, so the allocated memory barely changes.
sd = model.state_dict()
print("{:.2f}MB".format(torch.cuda.memory_allocated()/1024**2))
# 230.28MB

# An explicit deepcopy duplicates the parameters and doubles the usage.
clone = copy.deepcopy(sd)
print("{:.2f}MB".format(torch.cuda.memory_allocated()/1024**2))
# 462.24MB

That might be the case, as I didn’t know you were using FSDP.
I believe the recommendation is to enable offload_to_cpu (via FullStateDictConfig) in torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type, which should avoid the CUDA OOM.
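
As a sketch (reusing the names from your snippet: model is the FSDP-wrapped module, and rank, wd2 and it0 come from your training loop), it could look roughly like this:

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather a full state_dict, offloading it to CPU and materializing it only on rank 0,
# so the unsharded weights never all sit in VRAM at the same time.
save_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
FSDP.set_state_dict_type(model, StateDictType.FULL_STATE_DICT, save_cfg)

states = model.state_dict()  # collective under FSDP: call it on every rank
if rank == 0:
    torch.save(states, wd2 + "fsdpmod." + str(it0) + ".pt")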

Thanks a lot for the link, that’s exactly what I need! Thanks! :smiley:
