{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "8qPBlJDHwucz" }, "source": [ "Análisis de errores en conjunto de validación\n", "=============================================" ] }, { "cell_type": "markdown", "metadata": { "id": "PM02vWO4wuc5" }, "source": [ "## Introducción\n", "\n", "El análisis de errores es el proceso para identificar, observar y diagnosticar predicciones erróneas de un modelo de aprendizaje automático, ayudandonos a comprender las áreas con fortalezas o debilidades de un modelo. Cuando decimos que \"la precisión del modelo es del 90%\", puede que no sea uniforme en todos los subgrupos de datos o puede haber algunas condiciones en los datos de entrada en las que el modelo falla más. Por lo tanto, es importante someter las métricas a una revisión más profunda para poder mejorarlo.\n", "\n", "En este ejemplo veremos como utilizar la herramienta de Error Analysis provista en el Responsable AI Toolbox. Para mas detalles sobre esta herramienta visite: https://github.com/microsoft/responsible-ai-toolbox." ] }, { "cell_type": "markdown", "metadata": { "id": "PjoSE6qawuc6" }, "source": [ "## Utilizando el análisis de errores en el problema censo de la UCI" ] }, { "cell_type": "markdown", "metadata": { "id": "5oEclTmhwuc7" }, "source": [ "> Nota: Este ejemplo fué adaptado de https://github.com/microsoft/responsible-ai-toolbox/blob/main/notebooks/individual-dashboards/erroranalysis-dashboard/erroranalysis-interpretability-dashboard-census.ipynb." ] }, { "cell_type": "markdown", "metadata": { "id": "CWmaFc37wuc8" }, "source": [ "### Instalación\n", "\n", "Utilizaremos las librerías `interpret-community`, `raiwidgets` y `error-analysis`." ] }, { "cell_type": "markdown", "metadata": { "id": "A5m-yUvIwuc9" }, "source": [ "Para ejecutar este ejemplo, necesitaremos instalar las librerias `interpret-community`, `raiwidgets` y `error-analysis` y `lightgbm`:" ] }, { "cell_type": "code", "source": [ "!pip install rai-core-flask raiutils --quiet\n", "!pip install statsmodels==0.14.5 numpy==1.26.2 erroranalysis==0.5.5 responsibleai==0.36.0 interpret-community==0.32.0 interpret-core==0.6.6 lightgbm raiwidgets==0.36.0 scikit-learn==1.5.1 shap==0.46.0 fairlearn interpret ml_wrappers econml sparse semver dice_ml --no-deps --quiet\n" ], "metadata": { "id": "eJZPOFzKoxTT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "SCYoWY5twudC" }, "source": [ "> **IMPORTANTE:** Reinicie el kernel desde el menu `Runtime` (Google Colab) o `Kernel` (Jupyter) luego de realizar la instalación. Es posible que se mencionen errores de resolución de paquetes durante la instalación. Puede ignorarlos." ] }, { "cell_type": "markdown", "metadata": { "id": "CyWscZCQwudE" }, "source": [ "### Sobre el conjunto de datos del censo UCI\n", "\n", "El conjunto de datos del censo de la UCI es un conjunto de datos en el que cada registro representa a una persona. Cada registro contiene 14 columnas que describen a una una sola persona, de la base de datos del censo de Estados Unidos de 1994. Esto incluye información como la edad, el estado civil y el nivel educativo. La tarea es determinar si una persona tiene un ingreso alto (definido como ganar más de $50 mil al año). Esta tarea, dado el tipo de datos que utiliza, se usa a menudo en el estudio de equidad, en parte debido a los atributos comprensibles del conjunto de datos, incluidos algunos que contienen tipos sensibles como la edad y el género, y en parte también porque comprende una tarea claramente del mundo real." ] }, { "cell_type": "markdown", "metadata": { "id": "OnHZuGLSwudG" }, "source": [ "Descargamos el conjunto de datos" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "PcnArrfQwudH" }, "outputs": [], "source": [ "!wget https://santiagxf.blob.core.windows.net/public/datasets/uci_census.zip \\\n", " --quiet --no-clobber\n", "!mkdir -p datasets/uci_census\n", "!unzip -qq uci_census.zip -d datasets/uci_census" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZO-pj3WwudI" }, "source": [ "Lo importamos" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "sn4kODTMwudI" }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "train = pd.read_csv('datasets/uci_census/data/adult-train.csv')\n", "test = pd.read_csv('datasets/uci_census/data/adult-test.csv')" ] }, { "cell_type": "markdown", "metadata": { "id": "MMG5UhjewudK" }, "source": [ "### Entrenando un modelo para explorar" ] }, { "cell_type": "markdown", "metadata": { "id": "gdNPtv-KwudK" }, "source": [ "Preparando nuestros conjuntos de datos" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "EhLsyUphwudL" }, "outputs": [], "source": [ "X_train = train.drop(['income'], axis=1)\n", "y_train = train['income'].to_numpy()\n", "X_test = test.drop(['income'], axis=1)\n", "y_test = test['income'].to_numpy()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "oZiFns7SwudL" }, "outputs": [], "source": [ "classes = train['income'].unique().tolist()\n", "features = X_train.columns.values.tolist()\n", "categorical_features = X_train.dtypes[X_train.dtypes == 'object'].index.tolist()" ] }, { "cell_type": "markdown", "metadata": { "id": "6HzsWIV1wudM" }, "source": [ "Realizaremos un pequeño preprocesamiento antes de entrenar el modelo:\n", "\n", "- Imputaremos los valores faltantes de las caracteristicas numéricas con la media\n", "- Imputaremos los valores faltantes de las caracteristicas categóricas con el valor `?`\n", "- Escalaremos los valores numericos utilizando un `StandardScaler`\n", "- Codificaremos las variables categóricas utilizando `OneHotEncoder`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "374RsZu0wudM" }, "outputs": [], "source": [ "from typing import Tuple, List\n", "\n", "import sklearn\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", "\n", "def prepare(X: pd.DataFrame) -> Tuple[np.ndarray, sklearn.compose.ColumnTransformer]:\n", " pipe_cfg = {\n", " 'num_cols': X.dtypes[X.dtypes == 'int64'].index.values.tolist(),\n", " 'cat_cols': X.dtypes[X.dtypes == 'object'].index.values.tolist(),\n", " }\n", "\n", " num_pipe = Pipeline([\n", " ('num_imputer', SimpleImputer(strategy='median')),\n", " ('num_scaler', StandardScaler())\n", " ])\n", "\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False))\n", " ])\n", "\n", " transformations = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n", " ('cat_pipe', cat_pipe, pipe_cfg['cat_cols'])\n", " ])\n", " X = transformations.fit_transform(X)\n", "\n", " return X, transformations\n", "\n", "\n", "X_train_transformed, transformations = prepare(X_train)\n", "X_test_transformed = transformations.transform(X_test)" ] }, { "cell_type": "markdown", "metadata": { "id": "81atQ7WUwudN" }, "source": [ "Entrenamos un modelo basado en `lightgbm`" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "u2MVdIhKwudN", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7663d606-fa3f-4e39-b7fc-0aa60e3b837d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[LightGBM] [Info] Number of positive: 7841, number of negative: 24720\n", "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.009690 seconds.\n", "You can set `force_row_wise=true` to remove the overhead.\n", "And if memory is not enough, you can set `force_col_wise=true`.\n", "[LightGBM] [Info] Total Bins 781\n", "[LightGBM] [Info] Number of data points in the train set: 32561, number of used features: 95\n", "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.240810 -> initscore=-1.148246\n", "[LightGBM] [Info] Start training from score -1.148246\n" ] } ], "source": [ "from lightgbm import LGBMClassifier\n", "\n", "clf = LGBMClassifier(n_estimators=5)\n", "model = clf.fit(X_train_transformed, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "F62qFkc2wudO" }, "source": [ "Ejecutamos el modelo" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "ZLqpnS_rwudO" }, "outputs": [], "source": [ "predictions = model.predict(X_test_transformed)" ] }, { "cell_type": "markdown", "source": [ "Podemos revisar la performance del modelo:" ], "metadata": { "id": "Lwr57cbqJRuN" } }, { "cell_type": "code", "source": [ "from sklearn.metrics import classification_report\n", "\n", "print(classification_report(y_test, predictions))" ], "metadata": { "id": "kMNf-9RkJTme", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e4c5efd6-802e-467c-c67c-06f2794c8bf4" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " precision recall f1-score support\n", "\n", " <=50K 0.81 1.00 0.90 12435\n", " >50K 0.98 0.25 0.40 3846\n", "\n", " accuracy 0.82 16281\n", " macro avg 0.90 0.63 0.65 16281\n", "weighted avg 0.85 0.82 0.78 16281\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "> Tip: En este ejemplo, hemos entrenado un modelo con un numero de estimadores pequeño a proposito para obtener un modelo con sezgo." ], "metadata": { "id": "uAiI9Fg-9vgs" } }, { "cell_type": "markdown", "metadata": { "id": "weAeZyDcwudO" }, "source": [ "### Análisis de errores" ] }, { "cell_type": "markdown", "metadata": { "id": "yjxkgFAFwudP" }, "source": [ "El análisis de errores nos permite explorar como se distributye el error de las predicciones de nuestro modelo. Una de las ventajas mas interesantes de esta herramienta es que es agnostica del modelo, es decir, que lo podemos aplicar para cualquier tipo de modelo ya que solamente necesitamos proveer los conjuntos de datos de evaluación con las predicciones que realizó el modelo." ] }, { "cell_type": "markdown", "source": [ "> Opcionalmente, podemos aumentar la cantidad de muestras en el conjunto de datos de validación utilizando la técnica de oversampling. Esto nos permite que la herramienta de análisis de errores tenga mas instancias para trabajar y por ende generar más opciones para visualización.\n", ">\n", "> ```python\n", "> test = test.sample(1000, replace=True)\n", "> ```\n", ">\n", "> En tal caso, recuerde ejecutar las prediciones sobre el conjunto de datos con *oversampling*:\n", ">\n", "> ```python\n", "> X_test = test.drop(['income'], axis=1)\n", "> y_test = test['income'].to_numpy()\n", "> X_test_transformed = transformations.transform(X_test)\n", "> predictions = model.predict(X_test_transformed)\n", "> ```" ], "metadata": { "id": "NsH3LDqpJtyz" } }, { "cell_type": "markdown", "source": [ "Para abrir la herramienta, debemos construir un `ErrorAnalysisDashbord`:" ], "metadata": { "id": "n9diai9wKWBk" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RJ2VV6kVwudP" }, "outputs": [], "source": [ "from raiwidgets import ErrorAnalysisDashboard\n", "from interpret_community.common.constants import ModelTask\n", "\n", "ErrorAnalysisDashboard(dataset=X_test,\n", " true_y=y_test,\n", " categorical_features=categorical_features,\n", " features=features,\n", " pred_y=predictions,\n", " model_task=ModelTask.Classification)" ] }, { "cell_type": "markdown", "metadata": { "id": "f8hkuo7swudP" }, "source": [ "*Notas:*\n", "\n", " - *`dataset` es el conjunto de datos de evaluación, sin preprocesamiento y sin la variable objetivo.*\n", " - *`true_y` es el valor verdadero (ground-truth) de la variable a predecir*\n", " - *`pred_y` es el valor de las predicciones del modelo que estamos evaluando*.\n", " - *`categorical_features`: son los nombres de las columnas que tienen variables de tipo categorica.*\n", " - *`features` son los los nombres de todas las columnas que utiliza el modelo, categoricas y numeras incluidas.*\n", " - *`model_task` es el tipo de modelo que construirmos, donde los valores possibles son `ModelTask.Classification` o `ModelTask.Regression`.*" ] }, { "cell_type": "markdown", "source": [ "> __¿Ve el tablero en blanco?__: Si ejecuta en Google Colab, asegurese de permitir conexiones sobre http (contenido mixto).\n", ">\n", ">\n", "> ![](https://user-images.githubusercontent.com/15119906/110703410-091f9480-81f4-11eb-9858-cf45a9da71cf.png)\n", ">\n", "> ![](https://user-images.githubusercontent.com/15119906/110703555-39ffc980-81f4-11eb-88db-19abb0caf2ce.png)\n", ">\n", "> __¿Sigue sin ver el tablero?__: Asegurese de que su conjunto de datos está bien configurado, que no tiene valores faltantes en las columas y que configuró correctamente las variables categoricas." ], "metadata": { "id": "e35stQM12U-Q" } }, { "cell_type": "markdown", "metadata": { "id": "HfmfZLNowudP" }, "source": [ "Una vez ejecutado el comando anterior, debería poder ver la herramienta de exloración de errores:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "t1ppwKPrwudQ" }, "source": [ "#### Interpretación\n", "\n", "Podemos utilizar este gráfico para explorar la forma en que el modelo comente los errores. Para encontrar patrones, podemos ocmenzar buscando aquellos nodos en el arbol que tienen un color rojo más fuerte, lo que indica que esa combinación de atributos tiene un error alto al clasificarlos. El nivel de llenado del nodo indica que tan representativa es esa combinación de atributos en el conjunto de datos completo\n", "\n", "Esto quiere decir que si nos focalizamos en aquellos nodos con color más oscuro y nivel más alto, estamos atacando aquellas áreas donde tenemos más chances de mejorar la performance del modelo. Por ejemplo, en la imagen anterior vemos que cuando la relación es `Husband` o `Wife` y la cantidad de años de educación es mayor a 11.5 pero el capital es menor a $ 5035.50, estas instancias tienen una taza de error del 65% y representan el 35% de todos los errores que comete el modelo.\n", "\n", "Deberiamos investigar porque nuestro modelo no puede mapear a este tipo de instancias correctamente. Quizás haya un probema en la calidad de datos, una mala recolección, o quizás las características no fueron preprocesadas correctamente." ] }, { "cell_type": "markdown", "metadata": { "id": "2-2m9ga2wudQ" }, "source": [ "#### Instancias con dificultades\n", "\n", "Una característica interesante de esta libreria es la capacidad de generar mapas de calor con aquellas combinaciones de atributos donde nuestro modelo tiene problemas. Esto nos permite ver rapidamente donde el modelo tiene inconvenientes en predecir correctamente y desde allí, analizár si el modelo es aceptable cometiendo estos errores o no:\n", "\n", "\n", "\n", "En el ejemplo más arriba estamos comparando los predictores `relationship` y `education-num`. Como vemos, el modelo tiene grandes problemas con aquellas personas de más de 14 años de educación y que son mujeres casadas. Solo 1 persona fué clasificada correctamente representando una taza de error del 94%." ] }, { "cell_type": "markdown", "metadata": { "id": "MT-GIKb9wudQ" }, "source": [ "### Explicaciones\n", "\n", "> IMPORTANTE: Si ejecuto el tablero anterior, deberá reiniciar la sesión de Colab. Esto se debe solo a restricciones de Colab. En ambientes productivos no debería realizar esta tarea." ] }, { "cell_type": "markdown", "metadata": { "id": "6A3XRTjEwudQ" }, "source": [ "Las explicaciones del modelo nos pueden ser útiles a la hora de explorar la importancia de cada uno de los atributos y como son utilizados por el modelo. Para generar las explicaciones del modelo, deberemos constuir un pipeline donde tengamos el preprocesamiento y el modelo propiamente dicho en un mismo objeto ya que las técnicas de explicaciones contemplan tanto las instancias de preprocesamiento como de modelado:" ] }, { "cell_type": "markdown", "metadata": { "id": "5zgRDRwEwudQ" }, "source": [ "> **IMPORTANTE:** Note que esta técnica requiere proveer el modelo original que genera la predicciones, y por lo tanto no podrá utilizarlo en escenarios donde el modelo fué entrenado en otra herramienta." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TMAE_OZBwudR" }, "outputs": [], "source": [ "model_pipeline = Pipeline(steps=[('preprocessing', transformations),\n", " ('model', model)])" ] }, { "cell_type": "markdown", "metadata": { "id": "HTbdBbqIwudR" }, "source": [ "Configuramos un objeto para general las explicaciones del modelo basado en el conjunto de datos en el que se entreno:\n", "\n", "> **IMPORTANTE:** Note que esta técnica require la creación de un modelo que sea interpretable. Es decir, en lugar de realizar el análisis en el modelo original (el cual podría tener una complejidad arbitraria), el análsis se hace sobre un modelo que pueda ser facilmente interpretable. Para generar este segundo modelo, se utiliza la técnica de Global Model Surrogate la cual consiste en entrenar un modelo **alumno** que trata de **imitar** al modelo **profesor**. Esta técnica claramente no es exacta y no tenemos ninguna garantía de que los errores que comete el modelo **alumno** son los mismos que los que comete el alumno **profesor**." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QTb_QcznwudR", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "80b4d310-c994-4898-a159-e6e6c60183a1" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.075441 seconds.\n", "You can set `force_col_wise=true` to remove the overhead.\n", "[LightGBM] [Info] Total Bins 781\n", "[LightGBM] [Info] Number of data points in the train set: 32561, number of used features: 95\n", "[LightGBM] [Info] Start training from score -1.218004\n" ] } ], "source": [ "from interpret_community.common.constants import ShapValuesOutput, ModelTask\n", "from interpret_community.mimic import MimicExplainer\n", "from interpret_community.mimic.models import LGBMExplainableModel\n", "\n", "\n", "explainer = MimicExplainer(model=model,\n", " initialization_examples=X_train,\n", " explainable_model=LGBMExplainableModel,\n", " augment_data=True,\n", " max_num_of_augmentations=10,\n", " features=features,\n", " classes=classes,\n", " model_task=ModelTask.Classification,\n", " transformations=transformations)" ] }, { "cell_type": "markdown", "metadata": { "id": "dG_fVjkpwudS" }, "source": [ "Generamos las explicaciones del modelo en nuestro conjunto de validación" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tmOAH_VjwudT", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "749e0098-bc70-40b2-9f8b-fb5c8c5d42a0" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n" ] } ], "source": [ "global_explanation = explainer.explain_global(X_test)" ] }, { "cell_type": "code", "source": [ "pd.DataFrame(\n", " data={\n", " \"importancia\": global_explanation.get_ranked_global_values(),\n", " \"feature\": global_explanation.get_ranked_global_names()\n", " },\n", " index=global_explanation.global_importance_rank\n", ")" ], "metadata": { "id": "WqVPOc0MVQey", "colab": { "base_uri": "https://localhost:8080/", "height": 488 }, "outputId": "e299f1b3-b162-4edb-bfc2-2a3c0a7a7e6a" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " importancia feature\n", "5 0.360204 marital-status\n", "4 0.170727 education-num\n", "10 0.095830 capital-gain\n", "0 0.072773 age\n", "12 0.050053 hours-per-week\n", "11 0.028772 capital-loss\n", "6 0.017707 occupation\n", "1 0.005299 workclass\n", "2 0.001785 fnlwgt\n", "7 0.001371 relationship\n", "3 0.000489 education\n", "9 0.000050 gender\n", "8 0.000004 race\n", "13 0.000004 native-country" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
importanciafeature
50.360204marital-status
40.170727education-num
100.095830capital-gain
00.072773age
120.050053hours-per-week
110.028772capital-loss
60.017707occupation
10.005299workclass
20.001785fnlwgt
70.001371relationship
30.000489education
90.000050gender
80.000004race
130.000004native-country
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "summary": "{\n \"name\": \")\",\n \"rows\": 14,\n \"fields\": [\n {\n \"column\": \"importancia\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.10029144188831822,\n \"min\": 3.6212851354707033e-06,\n \"max\": 0.36020407376757096,\n \"num_unique_values\": 14,\n \"samples\": [\n 0.0013711002186458843,\n 4.999638518712062e-05,\n 0.36020407376757096\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"feature\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 14,\n \"samples\": [\n \"relationship\",\n \"gender\",\n \"marital-status\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "markdown", "source": [ "Podemos verificar que tan equivalente es el modelo surrogado:" ], "metadata": { "id": "-gKBKksWXCaE" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rtl7ajPkwudT", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "a2de6945-49b8-463e-cab7-303f0c460744" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n", "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "X does not have valid feature names, but SimpleImputer was fitted with feature names\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n", "Changing the sparsity structure of a csr_matrix is expensive. lil and dok are more efficient.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "0.9998464420625902" ] }, "metadata": {}, "execution_count": 14 } ], "source": [ "explainer.get_surrogate_model_replication_measure(training_data=X_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "DS13UFTFwudT" }, "source": [ "Mostramos el tablero:" ] }, { "cell_type": "code", "source": [ "from raiwidgets import ErrorAnalysisDashboard\n", "from interpret_community.common.constants import ModelTask\n", "\n", "ErrorAnalysisDashboard(explanation=global_explanation,\n", " dataset=X_test,\n", " true_y=y_test,\n", " categorical_features=categorical_features,\n", " features=features,\n", " pred_y=predictions,\n", " model_task=ModelTask.Classification)" ], "metadata": { "id": "YAUJzX2yTenD" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "2TzcDdGqwudT" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "1eB3RiDlwudU" }, "source": [ "### Feature Impotance\n", "\n", "En la parte superior del gráfico vemos la importancia de cada uno de los predictores que se utilizaron en el modelo. En este caso, nos indica que `marital-status` y `education` son dos de los predictores más importantes para nuestro modelo. El concepto de aque de \"importancia\" esta atado a que, cuando estos valores cambian, la performance del modelo decrece significativamente.\n", "\n", "### Dependency plots\n", "\n", "En la parte inferior vemos como varia la importancia del predictor `ocupation` dependiendo de cada uno de todos los varios que obtiene. Este tipo de gráficos se llaman \"Dependency plots\" y nos permiten ver como la importancia de la clase que queremos predecir cambia a medida que cambian los valores de los predictores.\n", "\n", "Por ejemplo, notemos como la importancia de la variable `occupation` es muy baja cuando cuando el valor es `priv-house-serv`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.12 ('sphinx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "vscode": { "interpreter": { "hash": "c0c26a04c01997af4d3a54c44ba2029caf4208eaf3de13f3aa81bddca06af044" } }, "colab": { "provenance": [], "toc_visible": true } }, "nbformat": 4, "nbformat_minor": 0 }