Feature Importance Visualization

Visualize and analyze feature importance from tree-based models

Tags: feature-importance, visualization, machine-learning, python

Feature Importance Visualization

Visualize feature importance from tree-based models with interactive plots.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_feature_importance(model, feature_names, top_n=20, figsize=(10, 8)):
    """
    Plot the highest-ranked feature importances of a fitted model.

    Draws a horizontal bar chart of the most important features, prints a
    short summary to stdout and returns the complete ranking.

    Args:
        model: Trained model with a feature_importances_ attribute holding
            one value per feature.
        feature_names: Sequence of feature names, in the same order as
            model.feature_importances_.
        top_n: Maximum number of features to display (must be >= 1).
        figsize: Figure size passed to matplotlib as (width, height).

    Returns:
        pandas.DataFrame with columns 'feature' and 'importance' covering
        every feature, sorted by descending importance.

    Raises:
        ValueError: If top_n is less than 1, or if the number of feature
            names differs from the number of importance values.
    """
    # DataFrame.head() interprets a negative count as "all but the last k
    # rows", which would silently plot the wrong subset.
    if top_n < 1:
        raise ValueError(f"top_n must be a positive integer, got {top_n!r}")

    raw_importance = model.feature_importances_
    if len(feature_names) != len(raw_importance):
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the model "
            f"reports {len(raw_importance)} importance values"
        )

    # Full ranking, most important feature first
    importance = pd.DataFrame({
        'feature': feature_names,
        'importance': raw_importance,
    }).sort_values('importance', ascending=False)

    # Select top N features; fewer rows come back when the model has
    # fewer than top_n features, so label everything with the real count.
    top_features = importance.head(top_n)
    shown = len(top_features)
    positions = list(range(shown))

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(positions, top_features['importance'].tolist(),
            color='steelblue', alpha=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_features['feature'].tolist())
    ax.set_xlabel('Importance')
    ax.set_title(f'Top {shown} Feature Importances',
                 fontsize=14, fontweight='bold')
    ax.invert_yaxis()  # largest importance at the top of the chart
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    fig.tight_layout()
    plt.show()

    # Importances are not guaranteed to sum to 1 (some libraries report raw
    # split counts or gains), so express the top-N share relative to the total.
    total_importance = importance['importance'].sum()
    if total_importance:
        top_share = top_features['importance'].sum() / total_importance
    else:
        top_share = 0.0

    # Print statistics
    print("=" * 60)
    print("FEATURE IMPORTANCE STATISTICS")
    print("=" * 60)
    print(f"Total features: {len(importance)}")
    print(f"Top {shown} features account for {top_share:.2%} of total importance")
    print(f"\nTop {shown} Features:")
    for name, value in zip(top_features['feature'], top_features['importance']):
        print(f"  {name}: {value:.4f}")

    return importance

def plot_feature_importance_comparison(models_dict, feature_names, top_n=15):
    """
    Compare feature importance across multiple models.

    Builds one importance column per model, ranks features by their mean
    importance across the models, and draws a grouped horizontal bar chart
    of the highest-ranked features.

    Args:
        models_dict: Dictionary of {model_name: model} pairs; every model
            must have a feature_importances_ attribute with one value per
            entry in feature_names.
        feature_names: List of feature names (used as the table index)
        top_n: Maximum number of top features to compare (must be >= 1)

    Returns:
        pandas.DataFrame indexed by feature name with one column per model,
        containing every feature (not only the plotted ones), in the
        original feature order.

    Raises:
        ValueError: If top_n is less than 1, if models_dict is empty, or
            (raised by pandas) if a model's importance vector does not
            match the number of feature names.
    """
    # DataFrame.head() interprets a negative count as "all but the last k
    # rows", which would silently plot the wrong subset.
    if top_n < 1:
        raise ValueError(f"top_n must be a positive integer, got {top_n!r}")
    if not models_dict:
        raise ValueError("models_dict must contain at least one model")

    # Get importance for each model
    importance_df = pd.DataFrame(index=feature_names)
    for model_name, model in models_dict.items():
        importance_df[model_name] = model.feature_importances_

    # Get top features based on average importance
    avg_importance = importance_df.mean(axis=1).sort_values(ascending=False)
    top_features = avg_importance.head(top_n).index
    shown = len(top_features)

    # Draw onto an explicit Axes: calling plt.figure() and then
    # DataFrame.plot(figsize=...) opens two figures, one of them blank.
    fig, ax = plt.subplots(figsize=(12, 8))
    importance_df.loc[top_features].plot(kind='barh', ax=ax)
    ax.set_xlabel('Importance')
    ax.set_title(f'Feature Importance Comparison (Top {shown})',
                 fontsize=14, fontweight='bold')
    ax.invert_yaxis()  # highest average importance at the top
    ax.legend(title='Models', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    fig.tight_layout()
    plt.show()

    return importance_df

# Usage Example
# importance = plot_feature_importance(
#     model, feature_names, top_n=15
# )