Problem loading multiple data files with IterableDataset

Hi all, I am new to PyTorch and I am having trouble loading multiple data files with IterableDataset.
I have about 10 .txt files; each is roughly 3 GB (~3 million rows) and contains both the feature and target columns. I want to use these files to train a simple NN. I defined my dataset as below:

# Imports used across the snippets below
import os
import glob
import pandas as pd
import torch
import torch.nn as nn
from typing import Dict, List
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

# Dataset for multiple smaller files
class MultiFileDataset(torch.utils.data.IterableDataset):
    def __init__(self, file_paths, feature_cols, target_cols, scaler, sample_frac=1):
        self.file_paths = file_paths
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        self.scaler = scaler 
        self.frac = sample_frac

    def __iter__(self):
        for file_path in self.file_paths:
            chunk = pd.read_csv(file_path, usecols=self.feature_cols + self.target_cols, sep=',', header=0)
            # Keep only rows where every target is positive, then drop missing values
            chunk = chunk[(chunk[self.target_cols] > 0).all(axis=1)]
            chunk.dropna(inplace=True)

            # Shuffle the rows within this file (sample_frac < 1 also subsamples the file)
            chunk = chunk.sample(frac=self.frac).reset_index(drop=True)
            
            X_chunk = chunk[self.feature_cols].values
            y_chunk = chunk[self.target_cols].values

            # Scale the features
            X_chunk_scaled = self.scaler.transform(X_chunk)

            # Convert to PyTorch tensors
            X_tensor = torch.tensor(X_chunk_scaled, dtype=torch.float32)
            y_tensor = torch.tensor(y_chunk, dtype=torch.float32)

            # Yield each sample
            for X, y in zip(X_tensor, y_tensor):
                yield X, y

def fit_scaler(file_paths, feature_cols, target_cols):
    scaler = StandardScaler()
    total_rows = 0
    for file_path in file_paths:
        chunk = pd.read_csv(file_path, usecols=feature_cols + target_cols, sep=',', header=0)
        chunk = chunk[(chunk[target_cols] > 0).all(axis=1)]
        chunk.dropna(inplace=True)
        X_chunk = chunk[feature_cols].values
        # Accumulate the partial fit
        scaler.partial_fit(X_chunk)
        total_rows += len(X_chunk)
    print(f"Scaler fitted on {total_rows} rows.")
    return scaler

# Define the feature and target columns
feature_cols = ['fov_num', 'asc_flag'] + [f'ch_{i+1}_limb' for i in range(22)]
target_cols = [f"t_{pressure_level[i]:.0f}hPa" for i in range(len(pressures))] + ["skt", "t2m"]
# File paths
parent_dir = os.path.dirname(os.getcwd())
folder_path = os.path.join(parent_dir, 'data')
file_pattern = 'atms_era5_limb_2023aug1_10_*.txt'
file_paths = glob.glob(os.path.join(folder_path, file_pattern))

# Fit the scaler once over all files, then create the dataset and DataLoader
scaler = fit_scaler(file_paths, feature_cols, target_cols)
multi_file_dataset = MultiFileDataset(file_paths=file_paths, feature_cols=feature_cols,
                                      target_cols=target_cols, scaler=scaler, sample_frac=0.7)
train_loader = DataLoader(multi_file_dataset, batch_size=100000)
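
(A quick way to sanity-check the loader is to pull a single batch and look at the shapes; 100000 is the batch size set above.)

# Quick sanity check on one batch
batch_X, batch_y = next(iter(train_loader))
print(batch_X.shape, batch_y.shape)  # expect (100000, len(feature_cols)) and (100000, len(target_cols))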


# training
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    hyperparams: Dict[str, int]
) -> List[float]:
    model.train()
    train_losses = []

    for epoch in range(hyperparams['num_epochs']):
        epoch_loss = 0.0
        sample_count = 0  # Track the number of samples
        for batch_X, batch_y in train_loader:             
            batch_X = batch_X.to(torch.device("cuda"))
            batch_y = batch_y.to(torch.device("cuda"))
          
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            rmse_loss = torch.sqrt(loss)
            batch_size = batch_X.size(0)  # Get the current batch size
            epoch_loss += rmse_loss.item() * batch_size
            sample_count += batch_size  # Accumulate the number of samples
            
        # Use the accumulated sample count for averaging
        avg_loss = epoch_loss / sample_count
        train_losses.append(avg_loss)
        print(f'Epoch [{epoch + 1}/{hyperparams["num_epochs"]}], RMSE: {avg_loss:.7f}')

    return train_losses
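
For context, hyperparams is just a plain dict. The values below are placeholders to show its structure rather than my exact settings:

# Placeholder hyperparameters (illustrative values, not my exact settings)
hyperparams = {
    'input_size': len(feature_cols),    # 24 input features
    'hidden_size1': 256,
    'hidden_size2': 128,
    'output_size': len(target_cols),
    'num_epochs': 20,
    'batch_size': 100000,
    'learning_rate': 1e-3,
    'beta1': 0.9,
    'beta2': 0.999,
    'epsilon': 1e-8,
    'weight_decay': 0.0,
}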

# Initialize the model with the determined input and output sizes
model = MLP(
    input_size=hyperparams['input_size'],    # = len(feature_cols)
    hidden_size1=hyperparams['hidden_size1'],
    hidden_size2=hyperparams['hidden_size2'],
    output_size=hyperparams['output_size']   # = len(target_cols)
)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model.to(torch.device("cuda"))

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer1 = torch.optim.Adam(
    model.parameters(),
    lr=hyperparams['learning_rate'],
    betas=(hyperparams['beta1'], hyperparams['beta2']),
    eps=hyperparams['epsilon'],
    weight_decay=hyperparams['weight_decay']
)
train_losses = train_model(model, train_loader, criterion, optimizer1, hyperparams)

However, the RMSE doesn't decrease after every epoch, and the problem is most likely in the data loading. So I changed the dataset definition to use torch.utils.data.Dataset instead:

class newMultiFileDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, feature_cols, target_cols, scaler):
        """
        Standard Dataset for multiple files with pre-scaling.
        """
        self.file_paths = file_paths
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        self.scaler = scaler
        self.data = []

        # Load all data into memory and scale once
        for file_path in file_paths:
            chunk = pd.read_csv(file_path)
            chunk = chunk[(chunk[self.target_cols] > 0).all(axis=1)]
            chunk.dropna(inplace=True)

            # Extract features and targets
            X_chunk = chunk[self.feature_cols].values
            y_chunk = chunk[self.target_cols].values

            # Scale features once during initialization
            X_chunk_scaled = self.scaler.transform(X_chunk)

            # Store the scaled data
            self.data.extend(list(zip(X_chunk_scaled, y_chunk)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features, target = self.data[idx]
        # Convert to tensors
        features_tensor = torch.tensor(features, dtype=torch.float32)
        target_tensor = torch.tensor(target, dtype=torch.float32)
        return features_tensor, target_tensor

multi_file_dataset = newMultiFileDataset(file_paths, feature_cols, target_cols, scaler)
train_loader = DataLoader(multi_file_dataset, batch_size=hyperparams['batch_size'], shuffle=True)

This time the RMSE decreases properly after each epoch, but I only tried the second approach with 2 files because I don't want to load all of the data into memory at once. I am wondering what is wrong with my first MultiFileDataset? And if I want to train on large data spread across multiple files, what is the best way to do so? Thank you in advance.
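
For reference, the direction I was considering for scaling to all 10 files is to keep an IterableDataset but split the file list across DataLoader workers with torch.utils.data.get_worker_info(), roughly as in the sketch below (ShardedMultiFileDataset is just an illustrative name, and I am not sure this is the right approach):

import pandas as pd
import torch
from torch.utils.data import DataLoader, get_worker_info

class ShardedMultiFileDataset(torch.utils.data.IterableDataset):
    """Sketch: each DataLoader worker reads a disjoint subset of the files."""
    def __init__(self, file_paths, feature_cols, target_cols, scaler):
        self.file_paths = file_paths
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        self.scaler = scaler

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            paths = self.file_paths  # single-process loading: iterate over every file
        else:
            # Each worker takes every num_workers-th file, so no sample is yielded twice
            paths = self.file_paths[worker_info.id::worker_info.num_workers]
        for file_path in paths:
            chunk = pd.read_csv(file_path, usecols=self.feature_cols + self.target_cols)
            chunk = chunk[(chunk[self.target_cols] > 0).all(axis=1)].dropna()
            chunk = chunk.sample(frac=1).reset_index(drop=True)  # shuffle within the file
            X = torch.tensor(self.scaler.transform(chunk[self.feature_cols].values), dtype=torch.float32)
            y = torch.tensor(chunk[self.target_cols].values, dtype=torch.float32)
            yield from zip(X, y)

# e.g. train_loader = DataLoader(ShardedMultiFileDataset(file_paths, feature_cols, target_cols, scaler),
#                                batch_size=100000, num_workers=4)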