"""Train-test split with stratification.

Advanced data splitting with stratification, a validation set, and optional
feature scaling.
"""
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
def split_data(X, y, test_size=0.2, val_size=0.1, random_state=42,
               stratify=None, scale_features=False, verbose=True):
    """Split data into train, validation, and test sets with optional scaling.

    Args:
        X: Feature matrix (array-like or DataFrame).
        y: Target vector.
        test_size: Proportion of the full dataset held out for the test set.
        val_size: Proportion of the full dataset held out for validation
            (internally rescaled relative to the data remaining after the
            test split).
        random_state: Random seed for reproducible splits.
        stratify: Pass ``y`` here to stratify the splits (classification);
            leave as None to disable stratification.
        scale_features: If True, fit a StandardScaler on the training split
            only and apply it to all three splits.
        verbose: If True (default), print a summary of the split sizes.

    Returns:
        dict with keys ``'X_train'``, ``'X_val'``, ``'X_test'``,
        ``'y_train'``, ``'y_val'``, ``'y_test'``, and ``'scaler'``
        (``None`` when ``scale_features`` is False).

    Raises:
        ValueError: If ``test_size + val_size >= 1`` (nothing would be
            left for training).
    """
    # Fail fast with a clear message instead of an opaque sklearn error.
    if test_size + val_size >= 1:
        raise ValueError(
            f"test_size ({test_size}) + val_size ({val_size}) must be < 1"
        )

    # First split: carve the test set off the full dataset.
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify,
    )

    # Second split: train vs. validation. val_size is expressed as a
    # fraction of the *full* dataset, so rescale it to the remainder.
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp,
        test_size=val_size_adjusted,
        random_state=random_state,
        # Re-stratify on the remaining labels so class ratios are
        # preserved in both train and validation.
        stratify=y_temp if stratify is not None else None,
    )

    # Fit the scaler on the training split only, then transform val/test —
    # fitting on all data would leak test-set statistics into training.
    scaler = None
    if scale_features:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)

    if verbose:
        _print_split_summary(X, y, X_train, X_val, X_test,
                             y_train, y_val, y_test)

    return {
        'X_train': X_train, 'X_val': X_val, 'X_test': X_test,
        'y_train': y_train, 'y_val': y_val, 'y_test': y_test,
        'scaler': scaler,
    }


def _print_split_summary(X, y, X_train, X_val, X_test,
                         y_train, y_val, y_test):
    """Print the size of each split and, for classification, class counts."""
    total = len(X)
    print("=" * 60)
    print("DATA SPLIT SUMMARY")
    print("=" * 60)
    print(f"Training set: {X_train.shape[0]:,} samples "
          f"({X_train.shape[0] / total * 100:.1f}%)")
    print(f"Validation set: {X_val.shape[0]:,} samples "
          f"({X_val.shape[0] / total * 100:.1f}%)")
    print(f"Test set: {X_test.shape[0]:,} samples "
          f"({X_test.shape[0] / total * 100:.1f}%)")
    # Heuristic: a small number of distinct target values is treated as a
    # classification problem worth reporting class balance for.
    if isinstance(y, (pd.Series, np.ndarray)) and len(np.unique(y)) < 10:
        print("\nClass distribution:")
        for split_name, split_y in [("Train", y_train),
                                    ("Val", y_val),
                                    ("Test", y_test)]:
            unique, counts = np.unique(split_y, return_counts=True)
            print(f"  {split_name}: {dict(zip(unique, counts))}")
# Usage Example
# splits = split_data(
# X, y,
# test_size=0.2,
# val_size=0.1,
# stratify=y, # For classification
# scale_features=True
# )