Feature Importance Visualization
Visualize feature importances from fitted tree-based models as horizontal bar charts.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def plot_feature_importance(model, feature_names, top_n=20, figsize=(10, 8)):
    """
    Plot feature importance from tree-based models.

    Draws a horizontal bar chart of the ``top_n`` most important features
    (most important at the top), shows it, and prints a short summary.

    Args:
        model: Trained model with a ``feature_importances_`` attribute holding
            one value per feature, in the same order as ``feature_names``.
        feature_names: List of feature names.
        top_n: Number of top features to display. If the model has fewer
            features than this, all of them are shown.
        figsize: Figure size passed to matplotlib.

    Returns:
        pandas.DataFrame with columns ``feature`` and ``importance`` for all
        features, sorted by importance in descending order.

    Raises:
        ValueError: If ``feature_names`` and ``model.feature_importances_``
            have different lengths.
    """
    importances = model.feature_importances_
    names = list(feature_names)
    if len(names) != len(importances):
        raise ValueError(
            f"feature_names has {len(names)} entries but the model reports "
            f"{len(importances)} feature importances"
        )

    importance = pd.DataFrame({
        'feature': names,
        'importance': importances,
    }).sort_values('importance', ascending=False)

    top_features = importance.head(top_n)
    # head() quietly returns fewer rows when top_n exceeds the feature count,
    # so label the output with the number of features actually shown.
    n_shown = len(top_features)

    fig, ax = plt.subplots(figsize=figsize)
    positions = list(range(n_shown))
    ax.barh(positions, top_features['importance'], color='steelblue', alpha=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels(list(top_features['feature']))
    ax.set_xlabel('Importance')
    ax.set_title(f'Top {n_shown} Feature Importances', fontsize=14, fontweight='bold')
    # Bars are drawn bottom-up; invert so the most important feature is on top.
    ax.invert_yaxis()
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    fig.tight_layout()
    plt.show()

    # feature_importances_ is not guaranteed to sum to 1 for every model, so
    # express the top-N share relative to the actual total.
    total = importance['importance'].sum()

    print("=" * 60)
    print("FEATURE IMPORTANCE STATISTICS")
    print("=" * 60)
    print(f"Total features: {len(importance)}")
    if total > 0:
        share = top_features['importance'].sum() / total
        print(f"Top {n_shown} features account for {share:.2%} of total importance")
    else:
        print(f"Top {n_shown} features account for n/a of total importance "
              f"(total importance is zero)")
    print(f"\nTop {n_shown} Features:")
    for name, value in zip(top_features['feature'], top_features['importance']):
        print(f" {name}: {value:.4f}")
    return importance
def plot_feature_importance_comparison(models_dict, feature_names, top_n=15, figsize=(12, 8)):
    """
    Compare feature importance across multiple models.

    Features are ranked by their importance averaged across all models. The
    ``top_n`` highest-ranked features are drawn as grouped horizontal bars,
    one bar per model, with the highest-ranked feature at the top.

    Args:
        models_dict: Dictionary of {model_name: model} pairs. Each model must
            expose ``feature_importances_`` in the same feature order as
            ``feature_names``.
        feature_names: List of feature names.
        top_n: Number of top features to compare. If there are fewer
            features than this, all of them are shown.
        figsize: Figure size passed to matplotlib.

    Returns:
        pandas.DataFrame indexed by feature name, with one importance column
        per model, covering all features in input order.

    Raises:
        ValueError: If ``models_dict`` is empty, or if a model's
            ``feature_importances_`` length does not match ``feature_names``
            (raised by pandas on column assignment).
    """
    if not models_dict:
        raise ValueError("models_dict must contain at least one model")

    # Get importance for each model, one column per model.
    importance_df = pd.DataFrame(index=feature_names)
    for model_name, model in models_dict.items():
        importance_df[model_name] = model.feature_importances_

    # Rank features by their importance averaged across models.
    avg_importance = importance_df.mean(axis=1).sort_values(ascending=False)
    top_features = avg_importance.head(top_n).index
    n_shown = len(top_features)

    # Draw into an explicit Axes: DataFrame.plot() without ``ax`` opens its own
    # figure, so a separate plt.figure() call would leave an empty figure behind.
    fig, ax = plt.subplots(figsize=figsize)
    importance_df.loc[top_features].plot(kind='barh', ax=ax)
    ax.set_xlabel('Importance')
    ax.set_title(f'Feature Importance Comparison (Top {n_shown})',
                 fontsize=14, fontweight='bold')
    # Rows are plotted bottom-up; invert so the top-ranked feature is on top.
    ax.invert_yaxis()
    ax.legend(title='Models', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    fig.tight_layout()
    plt.show()
    return importance_df
# Usage Example
# importance = plot_feature_importance(
# model, feature_names, top_n=15
# )