Train-Test Split with Stratification

Advanced train-test splitting with stratification and validation set

Tags: train-test-split, data-splitting, sklearn, python

Train-Test Split with Stratification

Advanced data splitting with stratification, validation set, and feature scaling.

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np

def split_data(X, y, test_size=0.2, val_size=0.1, random_state=42,
               stratify=None, scale_features=False):
    """
    Split data into train, validation, and test sets with optional scaling.

    The split is performed in two stages: first the test set is carved off,
    then the validation set is taken from the remainder. `val_size` is
    expressed as a fraction of the *full* dataset; it is rescaled internally
    so the final proportions come out as requested.

    Args:
        X: Feature matrix (array-like or DataFrame).
        y: Target vector.
        test_size: Proportion of the full dataset for the test set (0 < x < 1).
        val_size: Proportion of the full dataset for the validation set
            (0 < x < 1). Must satisfy test_size + val_size < 1.
        random_state: Random seed used for both splits.
        stratify: Label array to stratify on (typically `y` for
            classification), or None for no stratification. Note: this is
            the array itself, not a boolean flag.
        scale_features: If True, fit a StandardScaler on the training set
            and apply it to all three splits (X arrays become ndarrays).

    Returns:
        dict with keys 'X_train', 'X_val', 'X_test', 'y_train', 'y_val',
        'y_test', and 'scaler' (the fitted StandardScaler, or None).

    Raises:
        ValueError: If test_size or val_size is outside (0, 1), or if their
            sum leaves no room for a training set.
    """
    # Validate sizes up front so callers get a clear error instead of a
    # confusing failure from sklearn deep inside the second split.
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be in (0, 1), got {test_size}")
    if not 0 < val_size < 1:
        raise ValueError(f"val_size must be in (0, 1), got {val_size}")
    if test_size + val_size >= 1:
        raise ValueError(
            f"test_size + val_size must be < 1 to leave room for training "
            f"data, got {test_size} + {val_size}"
        )

    # First split: train+val and test
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify
    )

    # Second split: train and validation.
    # val_size is a fraction of the FULL dataset, but this split only sees
    # the (1 - test_size) remainder, so rescale accordingly.
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp,
        test_size=val_size_adjusted,
        random_state=random_state,
        # Re-stratify on the remaining labels so the val split keeps the
        # same class balance as the first split did.
        stratify=y_temp if stratify is not None else None
    )

    # Scale features if requested. The scaler is fit on the training set
    # only, then applied to val/test, to avoid data leakage.
    scaler = None
    if scale_features:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)

    # Print split information
    print("=" * 60)
    print("DATA SPLIT SUMMARY")
    print("=" * 60)
    print(f"Training set:   {X_train.shape[0]:,} samples ({X_train.shape[0]/len(X)*100:.1f}%)")
    print(f"Validation set: {X_val.shape[0]:,} samples ({X_val.shape[0]/len(X)*100:.1f}%)")
    print(f"Test set:       {X_test.shape[0]:,} samples ({X_test.shape[0]/len(X)*100:.1f}%)")

    if isinstance(y, pd.Series) or isinstance(y, np.ndarray):
        # Heuristic: fewer than 10 distinct target values is treated as a
        # classification problem for the distribution report. This is only
        # cosmetic — it does not affect the splits themselves.
        if len(np.unique(y)) < 10:
            print("\nClass distribution:")
            for split_name, split_y in [("Train", y_train),
                                        ("Val", y_val),
                                        ("Test", y_test)]:
                unique, counts = np.unique(split_y, return_counts=True)
                print(f"  {split_name}: {dict(zip(unique, counts))}")

    return {
        'X_train': X_train, 'X_val': X_val, 'X_test': X_test,
        'y_train': y_train, 'y_val': y_val, 'y_test': y_test,
        'scaler': scaler
    }

# Usage Example
# splits = split_data(
#     X, y,
#     test_size=0.2,
#     val_size=0.1,
#     stratify=y,  # For classification
#     scale_features=True
# )