XAI en Producción: Haz que tus Modelos Black-Box Expliquen sus Decisiones con Menos de 200 Líneas de Código

La Inteligencia Artificial Explicable (XAI) se ha vuelto omnipresente en conferencias tecnológicas, descripciones de puestos y propuestas de proyectos. Sin embargo, para muchos, XAI evoca imágenes de mapas de calor, curvas de ajuste local de LIME o gráficos de cascada de valores SHAP, que, aunque profetionales, resultan inútiles cuando un líder de producto pregunta: "¿Por qué se rechazó la solicitud de crédito de este usuario?". Los mapas de calor no se traducen fácilmente en informes semanales ni explican las decisiones a los departamentos legal, de cumplimiento o incluso a los clientes.

Con siete años de experiencia en la implementación de XAI, desde scripts de Python personalizados para calcular la contribución de características hasta la integración de bibliotecas de código abierto para canalizaciones de explicación completas, y en el último año, liderando la entrega de 11 módulos de XAI auditables y rastreables en finanzas, manufactura y gobierno, estoy convencido de que: XAI no es una característica adicional, sino un estándar de entrega de modelos; no es una herramienta de exhibición para científicos de datos, sino un traductor entre modelos, negocios, reguladores y usuarios. Esta nota se centra en un solo objetivo: cómo, en proyectos del mundo real, convertir un modelo "black-box" en uno que pueda explicar sus "porqués" utilizando menos de 200 líneas de código central. Es para científicos de datos novatos que, tras ejecutar su primer modelo, se enfrentan a preguntas desconcertantes; para líderes de ingeniería que luchan con las aprobaciones de cumplimiento para la implementación de modelos; y para gerentes de producto de datos que interactúan constantemente con "informes de explicabilidad" que a menudo son incomprensibles.

Por Qué SHAP o LIME por Sí Solos No Son Suficientes: Una Lección de Producción

Muchos equipos, al abordar XAI por primera vez, recurren a la biblioteca shap y ejecutan TreeExplainer, generando barras de gráficos como entrega. He visto demasiados casos: un banco urbano implementa un modelo de puntuación de crédito, y SHAP destaca "ingresos", "ratio de endeudamiento" y "número de impagos históricos" como las características más importantes. El departamento de negocios aprueba. Tres meses después, durante una auditoría regulatoria, se solicita una explicación para la "decisión de rechazo del cliente de alto riesgo A". El equipo muestra el gráfico SHAP, señalando la entrada "número de impagos históricos = 3", diciendo: "Esto hizo que la puntuación fuera baja". El regulador pregunta: "Entonces, ¿por qué se aprobó al cliente B si sus impagos históricos también fueron 3? ¿Y por qué se rechazó al cliente C con solo 2 impagos? ¿Cómo '3' activa el umbral? ¿Cuál es la lógica interna del modelo?". El silencio reina. El problema no es la imprecisión de SHAP, sino que solo responde a "qué característica es importante", evitando "cómo el modelo utiliza esa característica para juzgar".

Esta es la debilidad fatal de las herramientas de explicación puramente globales/locales en entornos de producción: no modelan la lógica de decisión, solo el efecto de la perturbación de las características. SHAP te dice "cuánto aumentaría la puntuación si cambiara los impagos históricos del cliente A de 3 a 0", pero no te dice "si el modelo considera 'impagos=3' como una línea roja dura". LIME es aún más propenso a errores, ya que utiliza un modelo simple para aproximar la respuesta del black-box localmente, lo que esencialmente es una "aproximación local". Si una muestra cae en el borde de la distribución de entrenamiento, el error de aproximación puede ser catastrófico.

Al reconstruir nuestra solución, dividimos los requisitos de XAI en tres niveles:

  • Nivel 1: Trazabilidad: Debe ser posible rastrear hasta la línea de código específica, el nodo del modelo y el cálculo que produjo la salida final. Esto requiere que el explicador sea determinista y que todas las variables intermedias sean serializables.
  • Nivel 2: Verificabilidad: Los resultados de la explicación deben poder verificarse mediante lógica independiente. Por ejemplo, si la explicación dice "rechazado debido a impagos > 2", entonces construir artificialmente una muestra con impagos=3 y alimentarla al modelo debe resultar consistentemente en un rechazo; luego, cambiar los impagos a 2 debe resultar consistentemente en una aprobación. Este proceso de verificación debe ser automatizable.
  • Nivel 3: Comunicabilidad: La salida de la explicación debe poder traducirse al lenguaje empresarial. No "valor SHAP = -0.42", sino "este cliente fue rechazado debido a 3 registros de impago en los últimos 6 meses, activando la regla de control de riesgos R7-2 (ver Sección 4.2 del Manual de Aprobación de Crédito)".

Basándonos en estos tres niveles, descartamos las herramientas "listas para usar" únicas y construimos un marco de explicación ligero, compuesto principalmente por tres partes:

  1. Extractor de Reglas: Para modelos basados en árboles (XGBoost/LightGBM), utiliza sklearn.tree.export_text para analizar inversamente las rutas de división de cada árbol, fusionando reglas comunes para generar un conjunto de reglas comerciales en formato IF-THEN.
  2. Analizador de Perturbaciones: Evitando el muestreo aleatorio de LIME, diseña perturbaciones dirigidas basadas en la semántica empresarial. Por ejemplo, perturba la "edad" solo en ±5 años y los "ingresos" en ±20%, asegurando que cada perturbación esté dentro de un rango empresarial razonable.
  3. Rastreador de Atribución: Durante la propagación hacia adelante del modelo, utiliza ganchos de torch.autograd (PyTorch) o tf.GradientTape (TF) para registrar la contribución del gradiente de cada característica a los logits finales. Combinado con los valores reales de las características, calcula una puntuación de atribución ponderada, evitando el sesgo subjetivo de la selección de "valores base" en SHAP.

Esta combinación no es para presumir, sino porque: el extractor de reglas satisface la "verificabilidad", el analizador de perturbaciones satisface la "trazabilidad" y el rastreador de atribución satisface la "comunicabilidad". Las salidas de los tres pueden cruzarse para verificar: si las reglas dicen "rechazar si impagos > 2", pero el rastreador de atribución muestra que la puntuación del cliente se ve afectada principalmente por la "antigüedad laboral", entonces las reglas no cubren todos los casos y se necesitan más datos para reentrenar.
Nota: No te obsesiones con las herramientas "SOTA" (State-Of-The-Art). Probé Captum (la biblioteca oficial de XAI de PyTorch) en un proyecto de inspección de calidad industrial. Su atribución para mapas de características CNN fue impresionante, pero los trabajadores de la línea de producción no entendían los mapas de calor. Al final, eliminamos toda la visualización y nos quedamos con una frase: "La ubicación del defecto está en el área de 12 cm x 8 cm en la esquina inferior derecha de la imagen. En comparación con la pieza estándar, el valor de la escala de grises aquí se desvía de la media en 3.7σ, superando el umbral del proceso de 3.2σ". Esta frase se incrustó directamente en el PDF del informe de inspección, y el supervisor de la línea de producción la entendió de un vistazo.

Detalles Clave: 7 Puntos Prácticos desde la Carga del Modelo hasta la Generación de Explicaciones

Los puntos más problemáticos en la implementación de XAI a menudo se esconden en los pasos aparentemente más simples. Los siguientes 7 puntos son normas estrictas que hemos adoptado después de tropezar al menos tres veces y que ahora están incluidas en nuestro SOP del equipo. Cada uno viene con fragmentos de código reales y notas para evitar problemas.

3.1 El Formato de Guardado del Modelo Determina el Límite de la Explicación: Por Qué .pkl Es Inferior a .onnx

Muchos equipos usan joblib.dump(model, 'model.pkl') para guardar modelos, facilitando la carga posterior. Sin embargo, los explicadores de XAI necesitan acceder a la estructura interna del modelo, como los umbrales de división de los árboles, las matrices de pesos de las redes neuronales e incluso los tipos de funciones de activación. .pkl es una serialización específica de Python, propensa a errores entre versiones y entornos, y no puede ser llamada por sistemas no Python (como motores de control de riesgos escritos en Java). Exigimos: todos los modelos implementados deben exportarse en formato ONNX.

# ✅ Correcto: Exportar a ONNX, la estructura está completamente expuesta
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Asumiendo que 'model' es un XGBoostClassifier entrenado
initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]
onx = convert_sklearn(model, initial_types=initial_type)
with open("model.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# ❌ Incorrecto: El archivo .pkl no puede ser analizado por el extractor de reglas para obtener la estructura interna del árbol
# joblib.dump(model, 'model.pkl') # El explicador lo lee y lanza AttributeError: 'Booster' object has no attribute 'tree_'

La ventaja de ONNX es que define el modelo como un grafo computacional (graph), donde cada nodo indica explícitamente el tipo de operador (TreeEnsembleClassifier), los nombres de entrada/salida y todos los parámetros (como nodes_falsenodeids, nodes_featureids). Nuestro extractor de reglas analiza directamente el grafo ONNX, recorriendo los nodos capa por capa y traduciendo las condiciones de división de cada nodo TreeEnsembleClassifier en sentencias IF. La ventaja de esto es el desacoplamiento de la lógica de explicación del entorno de implementación del modelo: los modelos ONNX se pueden ejecutar en Python, C++, Java o incluso en navegadores, y el explicador solo necesita leer el mismo archivo ONNX para producir resultados consistentes.

Nota: XGBoost/LightGBM requieren conversión con skl2onnx; no se puede usar directamente xgb.to_onnx() (la API anterior no admite la exportación de la estructura completa del árbol). Hemos probado que skl2onnx v1.7+ tiene el soporte más estable para LightGBM, y v1.6 tiene problemas de compatibilidad con objective='binary:logistic' de XGBoost, que requiere actualización.

3.2 El Preprocesamiento de Características Debe Ser Inmutable: Por Qué StandardScaler No Debe Hacer fit en Tiempo de Explicación

La precisión de la explicación depende en gran medida de la "limpieza" de las características de entrada. Sin embargo, muchos desarrolladores vuelven a ejecutar fit_transform en un StandardScaler durante la fase de explicación, lo que provoca que los parámetros de escalado de características (media, varianza) difieran de los utilizados durante el entrenamiento, distorsionando completamente los resultados de atribución.

# ❌ Error Fatal: Volver a hacer fit en tiempo de explicación rompe la consistencia
scaler = StandardScaler()
X_explain_scaled = scaler.fit_transform(X_single_sample) # ¡fit usa solo 1 muestra! media=valor de la muestra, varianza=0

# ✅ Correcto: Guardar el scaler durante el entrenamiento, solo hacer transform en la explicación
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train) # fit usa todo el conjunto de entrenamiento
# ... entrenar modelo ...
joblib.dump(scaler, 'scaler.pkl') # Inmutabilizar el preprocesador

# Fase de explicación:
scaler = joblib.load('scaler.pkl')
X_single_scaled = scaler.transform(X_single_sample.reshape(1, -1)) # Solo hacer transform

Además, exigimos que todos los pasos de preprocesamiento (imputación de valores faltantes, codificación categórica, binning) se empaqueten en un Pipeline y se exporten juntos con sklearn2onnx. De esta manera, el archivo ONNX no solo contiene el modelo, sino también toda la cadena de preprocesamiento. El explicador llama directamente al runtime de ONNX para ejecutar la inferencia de extremo a extremo, eliminando por completo la inconsistencia "entrenamiento-explicación".

3.3 Control de Umbrales para la Extracción de Reglas: Cómo Evitar Generar 100,000 Sentencias IF Inválidas

Los modelos de árboles pueden tener cientos de árboles, cada uno con más de 10 niveles de profundidad. Desplegar todas las rutas generaría un número astronómico de reglas. Adoptamos un filtrado de tres niveles:

  1. Filtrado por Cobertura: Solo conservamos rutas que cubran ≥0.1% de las muestras de entrenamiento (correspondiente al campo nodes_samples en el nodo ONNX);
  2. Filtrado por Confianza: Solo conservamos reglas cuya probabilidad de predicción del nodo hoja sea ≥0.8 (campo class_weights);
  3. Filtrado por Relevancia Empresarial: Configuramos manualmente una lista blanca de características, como overdue_count, income para escenarios financieros; temperature, vibration_rms para escenarios de manufactura. Las rutas de otras características se descartan directamente.
# Pseudocódigo para fusión de reglas (versión simplificada)
def merge_rules(onnx_graph, whitelist_features):
    all_paths = extract_all_paths(onnx_graph) # Analizar ONNX para obtener todas las rutas de división
    filtered_paths = []
    for path in all_paths:
        if (path.coverage >= 0.001 and 
            path.confidence >= 0.8 and
            any(f in path.features for f in whitelist_features)):
            filtered_paths.append(path)
    
    # Fusionar rutas similares: Por ejemplo, Ruta 1: SI ingresos>5000 Y impagos=0 → aprobar
    #                        Ruta 2: SI ingresos>6000 Y impagos=0 → aprobar
    # Fusionar en: SI ingresos>5000 Y impagos=0 → aprobar
    merged_rules = merge_similar_paths(filtered_paths)
    return merged_rules

En pruebas prácticas con un modelo de crédito, se obtuvieron 127,432 rutas originales. Después del filtrado de tres niveles, quedaron 37 reglas centrales. De estas, 21 fueron confirmadas por el departamento de negocios como "reglas de efecto real", con una precisión del 100%. Estas 37 reglas generaron directamente un documento de Word, incrustado en el "Manual del Modelo", sirviendo como evidencia directa para la auditoría de cumplimiento.

3.4 Diseño de Semántica Empresarial para el Análisis de Perturbaciones: Por Qué la Perturbación Aleatoria es una Falsa Premisa

LIME perturba las características con ruido gaussiano, lo cual es matemáticamente elegante pero empresarialmente absurdo. Por ejemplo, añadir ruido a "estado civil" podría dar como resultado marital_status=1.37, que el modelo no puede procesar. Debemos diseñar perturbaciones según el dominio empresarial:

Tipo de Característica Método de Perturbación Ejemplo Significado Empresarial
Numérica (Continua) Perturbación por pasos de ±5%~10% income de 15000 a 14250, 15750 Simular pequeñas fluctuaciones de ingresos
Numérica (Discreta) Perturbación de niveles adyacentes education_level de 3 (Licenciatura) a 2 (Técnico) o 4 (Maestría) Simular ajustes menores en el nivel educativo
Categórica Reemplazar con otros valores de nivel similar industry de "Manufactura" a "Construcción" Simular cambios cercanos en la industria
Temporal Avance/retroceso de días fijos last_login_days de 30 a 25 o 35 Simular cambios en la frecuencia de inicio de sesión

En términos de implementación de código, no codificamos las lógicas de perturbación, sino que definimos una clase PerturbationStrategy y registramos estrategias por nombre de característica:

class PerturbationStrategy:
    def __init__(self):
        self.strategies = {}
    
    def register(self, feature_name, strategy_func):
        self.strategies[feature_name] = strategy_func
    
    def perturb(self, x, feature_name):
        if feature_name in self.strategies:
            return self.strategies[feature_name](x)
        else:
            raise ValueError(f"No perturbation strategy for {feature_name}")

# Registrar estrategias empresariales
perturb_strat = PerturbationStrategy()
perturb_strat.register('income', lambda x: [x*0.95, x*1.05])
perturb_strat.register('education_level', lambda x: [max(1, x-1), min(5, x+1)])
perturb_strat.register('industry', lambda x: ['Construcción', 'Finanzas'] if x=='Manufactura' else ['Manufactura'])

De esta manera, cuando el departamento de negocios dice: "Quiero saber qué pasaría si los ingresos del cliente bajan un 10%", el explicador llama directamente a perturb_strat.perturb(15000, 'income'), que devuelve [14250]. El modelo se alimenta con este valor y se observa el cambio en la salida. Todo el proceso es auditable, reproducible y explicable.

3.5 Corrección de Gradientes para el Rastreo de Atribución: Por Qué los Gradientes Originales Pueden Engañar al Juicio Empresarial

Calcular directamente el gradiente de las características de entrada con respecto a la salida usando torch.autograd.grad produce un vector cuyo valor numérico se ve en gran medida afectado por la escala de las características. Por ejemplo, la unidad de income es "yuanes", y el gradiente podría ser 1e-5; la unidad de age es "años", y el gradiente podría ser 0.3. Una simple comparación de magnitudes podría llevar erróneamente a la conclusión de que age es más importante. Adoptamos el concepto de Integrated Gradients (IG), pero con adaptaciones empresariales:

  • La línea base (baseline) no se establece en cero (que no tiene sentido para ingresos=0), sino en la media de la industria (por ejemplo, ingresos medios para trabajadores de manufactura);
  • El número de pasos de integración se fija en 50 para garantizar la precisión;
  • La puntuación de atribución final = gradient * (feature_value - baseline_value), eliminando el impacto de la escala.
def integrated_gradients(model, input_tensor, baseline_tensor, steps=50):
    # input_tensor: forma (1, n_features)
    # baseline_tensor: forma (1, n_features), por ejemplo, media de la industria
    scaled_inputs = [baseline_tensor + (float(i)/steps)*(input_tensor-baseline_tensor) 
                     for i in range(steps+1)]
    
    grads = []
    for scaled in scaled_inputs:
        scaled.requires_grad = True
        output = model(scaled)
        grad = torch.autograd.grad(output.sum(), scaled)[0]
        grads.append(grad.detach().numpy())
    
    # Promedio de gradientes, multiplicado por delta
    avg_grads = np.average(grads[:-1], axis=0) # excluir el último
    delta = (input_tensor - baseline_tensor).detach().numpy()
    attribution = avg_grads * delta
    return attribution.flatten()

# Uso de ejemplo
baseline = torch.tensor([[industry_mean_income, industry_mean_age, ...]])
attribution = integrated_gradients(model, x_single, baseline)
# attribution[i] es la puntuación de atribución empresarial para la i-ésima característica, se puede ordenar directamente

Esta puntuación de atribución se introduce directamente en el sistema de informes empresariales, generando la página "Perfil del Cliente - Atribución de Decisión": la parte superior muestra "Rechazado", y la parte inferior enumera las características con las puntuaciones de atribución más altas y sus valores específicos, como "Número de impagos = 3 (puntuación de atribución 0.62)" y "Ratio Deuda/Ingresos = 85% (puntuación de atribución 0.28)". Los usuarios empresariales identifican rápidamente los puntos clave.

3.6 Formato de Salida de Explicación: Por Qué JSON No Es Suficiente, Se Necesita Salida Doble YAML + Markdown

La salida del explicador no puede ser un simple JSON plano como {"feature": "overdue_count", "value": 3, "attribution": 0.62}. Los sistemas empresariales, las plataformas regulatorias y los portales de clientes requieren diferentes formatos:

  • Auditoría Regulatoria: Requiere YAML estructurado e inmutable, con marca de tiempo, versión del modelo e ID del operador.
  • Paneles de BI Internos: Requiere JSON validado por un esquema JSON, con nombres de campo que coinciden estrictamente con el almacén de datos.
  • Mensajes SMS de Notificación al Cliente: Requiere Markdown extremadamente conciso, que se renderiza automáticamente como una tarjeta de texto enriquecido.

Desarrollamos un ExplanationFormatter que toma la salida de atribución de manera unificada y genera según sea necesario:

import yaml
from datetime import datetime
# Asumiendo que get_current_user() y render_markdown están definidos

class ExplanationFormatter:
    def __init__(self, model_version):
        self.model_version = model_version
        
    def to_yaml(self, explanation_dict):
        # Añadir campos de auditoría
        audit_data = {
            'timestamp': datetime.now().isoformat(),
            'model_version': self.model_version,
            'operator_id': get_current_user(),
            'explanation': explanation_dict
        }
        return yaml.dump(audit_data, allow_unicode=True)
    
    def to_markdown(self, explanation_dict):
        # Generar texto amigable para el cliente
        md = f"## Explicación de la Decisión\n\nSu solicitud no fue aprobada. Las razones principales son:\n\n"
        for feat in explanation_dict.get('top_features', [])[:3]:
            md += f"- **{feat['name']}**: {feat['value']} (Nivel de impacto: {feat['score']:.2f})\n"
        md += "\n> Nota: Esta explicación se genera automáticamente basándose en la información proporcionada. Para más detalles, por favor contacte a atención al cliente."
        return md

# Ejemplo de salida
formatter = ExplanationFormatter(model_version="v2.3.1")
yaml_output = formatter.to_yaml(explanation_result)  # Guardar en la base de datos de auditoría
md_output = formatter.to_markdown(explanation_result)  # Enviar a la aplicación del cliente

Este mecanismo nos permitió pasar la revisión de materiales para el "Registro de Algoritmos" de la Oficina de Información del Ciberespacio en un proyecto gubernamental, ya que todas las salidas de explicación venían con información completa de rastreo.

3.7 Prueba de Carga de Rendimiento: La Trampa Oculta de Por Qué la Explicación de Muestra Única Debe Probarse con 1000 Concurrencias

Muchos equipos solo prueban el tiempo de ejecución de la explicación de una sola muestra, sintiéndose aliviados al ver "200 ms". Sin embargo, el entorno de producción es concurrente. Descubrimos un problema grave: shap.TreeExplainer compite por una cache compartida en entornos multihilo, lo que provoca un aumento de la CPU y una gran fluctuación de la latencia. Durante una prueba de carga, la latencia P95 se disparó a 3.2 segundos con 100 concurrencias, activando el disyuntor. La solución es: todos los explicadores deben ser sin estado y sin caché compartida. Reemplazamos TreeExplainer con nuestro motor de reglas ONNXRuleEngine desarrollado internamente, que no mantiene ningún estado en tiempo de ejecución y cada llamada es una ejecución puramente funcional:

import onnxruntime as ort

class ONNXRuleEngine:
    def __init__(self, onnx_path):
        # Solo carga el grafo ONNX, no inicializa ningún estado de runtime
        self.onnx_path = onnx_path
        self.graph = onnx.load(onnx_path)
        
    def explain(self, x_input_data):
        # Cada explain crea una nueva sesión para evitar competencia de hilos
        session = ort.InferenceSession(self.onnx_path)
        
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        
        # Ejecución pura, sin efectos secundarios
        pred = session.run([output_name], {input_name: x_input_data})[0]
        
        # Lógica de coincidencia de reglas (sin estado)
        rules = self.match_rules(x_input_data, self.graph) # Implementación de match_rules omitida por brevedad
        return {'prediction': pred, 'rules': rules}

# Prueba de rendimiento
# ... (código de prueba de carga omitido) ...

Resultado de la prueba de rendimiento: 180 ms para una sola muestra, P95 estable en 210 ms con 1000 concurrencias, utilización de CPU <40%. Esto es realmente desplegable.

Proceso Práctico: Construcción de un Módulo XAI Entregable Desde Cero (Ejemplo de Control de Riesgos de Crédito)

Ahora, unamos todos los puntos anteriores y recorramos el flujo completo de un proyecto real. Supongamos que te han asignado un modelo de crédito XGBoost ya implementado, y el departamento de negocios te pide "preparar una demostración para los reguladores la próxima semana sobre la capacidad de explicación para un solo cliente". Aquí está nuestro método estándar de siete pasos, con comandos, código y puntos de control para cada paso.

4.1 Paso 1: Preparación del Entorno y Dependencias (5 minutos)

No usamos conda, solo pip, para asegurar que la imagen Docker sea lo más pequeña posible:

# Crear un entorno virtual limpio
python -m venv xai_env
source xai_env/bin/activate  # Linux/Mac
# xai_env\Scripts\activate  # Windows

# Instalar dependencias centrales (versiones fijadas para evitar actualizaciones implícitas)
pip install \
    numpy==1.23.5 \
    pandas==1.5.3 \
    scikit-learn==1.2.2 \
    xgboost==1.7.5 \
    onnx==1.13.1 \
    onnxruntime==1.14.1 \
    skl2onnx==1.13.1 \
    pyyaml==6.0 \
    jinja2==3.1.2  # Para renderizado de plantillas Markdown

Experiencia Práctica: onnxruntime debe coincidir estrictamente con la versión de onnx. Usamos la combinación onnx==1.13.1 + onnxruntime==1.14.1, que ha sido la más estable en pruebas en AWS c5.2xlarge y Alibaba Cloud ecs.g7.2xlarge. Una vez, un error al cargar un modelo ONNX de LightGBM causado por onnxruntime==1.15.0, con el error InvalidGraph: This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (label) from node () does not match expected type 'tensor(float)', se resolvió con una simple degradación de versión.

4.2 Paso 2: Inmutabilización del Modelo y Preprocesador (15 minutos)

Supongamos que tienes model.pkl y preprocessor.pkl. Primero, verifica si se pueden cargar correctamente:

import joblib
import pandas as pd

# Cargar modelo y preprocesador originales
model = joblib.load('model.pkl')
preprocessor = joblib.load('preprocessor.pkl')

# Construir una muestra de prueba (debe tener la misma estructura que durante el entrenamiento)
test_sample = pd.DataFrame({
    'age': [35],
    'income': [15000],
    'overdue_count': [3],
    'employment_length': [5],
    'has_car': [1]
})

# Verificar si el preprocesamiento + predicción son consistentes
X_test = preprocessor.transform(test_sample)
pred_raw = model.predict_proba(X_test)[:, 1]
print(f"Probabilidad de predicción original: {pred_raw[0]:.4f}") # Debe coincidir con la línea

Luego, exportar a ONNX:

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnx

# Definir el tipo de entrada ONNX (Nota: debe coincidir con la dimensión de salida del preprocesador)
n_features = X_test.shape[1]
initial_type = [('float_input', FloatTensorType([None, n_features]))]

# Convertir el modelo (Nota: el modelo debe ser de estilo sklearn; XGBoost debe ser XGBClassifier, no Booster)
onx = convert_sklearn(model, initial_types=initial_type, target_opset=12)

# Guardar
with open("model.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# Exportar también el preprocesador a ONNX (skl2onnx admite Pipelines)
from sklearn.pipeline import Pipeline
full_pipeline = Pipeline([('preprocessor', preprocessor), ('model', model)])
onx_full = convert_sklearn(full_pipeline, initial_types=[('input', FloatTensorType([None, len(test_sample.columns)]))])
with open("pipeline.onnx", "wb") as f:
    f.write(onx_full.SerializeToString())

Punto de Control: Usa onnx.checker.check_model(onnx.load("model.onnx")) para verificar la validez del archivo ONNX. Si arroja un error como InvalidGraph: Node input cannot be empty, probablemente el parámetro target_opset de convert_sklearn es demasiado bajo; actualízalo a 12 o 13.

4.3 Paso 3: Extracción de Reglas y Verificación Empresarial (30 minutos)

Cargar ONNX y extraer reglas:

import onnx
from onnx import helper, numpy_helper
import numpy as np

def extract_xgb_rules(onnx_path, whitelist_features=['overdue_count', 'income']):
    model = onnx.load(onnx_path)
    graph = model.graph
    
    # Buscar nodos TreeEnsembleClassifier
    tree_nodes = [n for n in graph.node if n.op_type == 'TreeEnsembleClassifier']
    if not tree_nodes:
        raise ValueError("No TreeEnsembleClassifier found")
    
    rules = []
    for node in tree_nodes:
        # Analizar node.attribute para obtener información de división
        attrs = {a.name: a for a in node.attribute}
        # nodes_featureids, nodes_splits, nodes_falsenodeids, etc...
        # (La lógica de análisis específica se omite aquí; consulte xai-tools/onnx_rule_parser.py en GitHub)
        node_rules = parse_tree_node(attrs, whitelist_features)
        rules.extend(node_rules)
    
    # Filtrado de tres niveles
    filtered_rules = filter_rules(rules, min_coverage=0.001, min_confidence=0.8)
    merged = merge_similar_rules(filtered_rules)
    return merged

# Ejecutar
rules = extract_xgb_rules("model.onnx")
print(f"Extraídas {len(rules)} reglas centrales")
for r in rules[:3]:
    print(f"  {r}")

Ejemplo de salida:

SI overdue_count > 2 ENTONCES rechazar (confianza=0.92)
SI income < 8000 Y employment_length < 2 ENTONCES rechazar (confianza=0.85)
SI has_car == 0 Y overdue_count == 0 ENTONCES aprobar (confianza=0.88)

Verificar inmediatamente con el departamento de negocios. Una vez encontramos una regla IF age < 22 THEN reject, pero el manual de negocios decía age < 23. Inmediatamente rastreamos los datos de entrenamiento y descubrimos que todas las muestras de 22 años eran negativas, lo que provocó sobreajuste. Después de limpiar los datos, la regla desapareció.

4.4 Paso 4: Desrarollo del Analizador de Perturbaciones (20 minutos)

Escribir estrategias de perturbación según la semántica empresarial:

import numpy as np

class CreditPerturbator:
    def __init__(self):
        self.strategies = {}
        self._register_strategies()
    
    def _register_strategies(self):
        # Ingresos: ±5%
        self.strategies['income'] = lambda x: [x * 0.95, x * 1.05]
        # Número de impagos: ±1 (pero no inferior a 0)
        self.strategies['overdue_count'] = lambda x: [max(0, x-1), x+1]
        # Edad: ±2 años
        self.strategies['age'] = lambda x: [max(18, x-2), min(70, x+2)]
    
    def perturb(self, x, feature_name):
        if feature_name not in self.strategies:
            raise ValueError(f"Unknown feature {feature_name}")
        return self.strategies[feature_name](x)

# Pruebas
perturb = CreditPerturbator()
print(perturb.perturb(15000, 'income'))      # [14250.0, 15750.0]
print(perturb.perturb(3, 'overdue_count'))   # [2, 4]

4.5 Paso 5: Integración del Rastreador de Atribución (25 minutos)

Usar ONNX Runtime para rastrear gradientes (requiere habilitar enable_grad):

import onnxruntime as ort
import numpy as np
import torch # Se necesita PyTorch para IG
from onnx2pytorch import ConvertModel # Para convertir ONNX a PyTorch

class ONNXGradientTracker:
    def __init__(self, onnx_path):
        # ONNX Runtime no admite gradientes directamente; usamos un wrapper de PyTorch
        # Primero, carga ONNX como modelo PyTorch (requiere onnx2pytorch)
        self.torch_model = ConvertModel(onnx_path)
        self.torch_model.eval()
    
    def integrated_gradients(self, x, baseline, steps=50):
        # x: array numpy (1, n_features)
        # baseline: array numpy (1, n_features)
        x_tensor = torch.tensor(x, dtype=torch.float32, requires_grad=True)
        baseline_tensor = torch.tensor(baseline, dtype=torch.float32)
        
        # Cálculo de IG (igual que en la sección 3.5)
        # ... (reutilizar la función IG anterior) ...
        return attribution

# Obtener la línea base de la industria (consultar de la base de datos empresarial)
baseline = np.array([[38, 12000, 0, 4, 1]])  # age, income, overdue, emp_len, car

tracker = ONNXGradientTracker("model.onnx")
attribution = tracker.integrated_gradients(
    x=test_sample.values.astype(np.float32), 
    baseline=baseline
)
print("Puntuaciones de atribución:", attribution)

4.6 Paso 6: Formato de Salida y Empaquetado de API (15 minutos)

Usar Flask para crear rápidamente una API de explicación:

from flask import Flask, request, jsonify
import yaml
from jinja2 import Template

app = Flask(__name__)

# Asegúrate de que estas variables/funciones estén definidas en tu entorno
# predict_onnx, match_rules, perturb, tracker, baseline, test_sample, etc.

@app.route('/explain', methods=['POST'])
def explain():
    data = request.json
    x_input = np.array(data['features']).reshape(1, -1)  # [age, income, ...]
    
    # 1. Coincidencia de reglas
    rules = match_rules(x_input, "model.onnx") # Asegúrate de que match_rules use el grafo ONNX
    
    # 2. Análisis de perturbaciones
    perturb_results = {}
    perturbator = CreditPerturbator() # Instanciar perturbador
    feature_names = ['age', 'income', 'overdue_count', 'employment_length', 'has_car'] # Mapeo de índice a nombre
    
    for i, feat_name in enumerate(feature_names):
        if feat_name in perturbator.strategies: # Solo perturbar si hay estrategia definida
            try:
                original_value = x_input[0, i]
                perturbed_values = perturbator.perturb(original_value, feat_name)
                for p_val in perturbed_values:
                    x_perturbed = x_input.copy()
                    x_perturbed[0, i] = p_val
                    # Asegúrate de que predict_onnx tome la entrada correcta (por ejemplo, array numpy)
                    pred_p = predict_onnx(x_perturbed, "model.onnx") # Implementación de predict_onnx omitida
                    perturb_results[f"{feat_name}_to_{p_val}"] = float(pred_p[0]) # Asegúrate de que sea un float serializable
            except ValueError: # Ignorar características sin estrategia
                pass

    # 3. Puntuaciones de atribución
    # Asegúrate de que 'tracker' esté inicializado y 'baseline' sea un tensor numpy correcto
    attr = tracker.integrated_gradients(x_input, baseline)
    
    # 4. Formato
    # Crear un dict para las características top, asumiendo que 'attr' está en el orden correcto
    top_features_data = []
    # Supongamos que tenemos una forma de obtener los nombres de características correspondientes a 'attr'
    feature_names_for_attr = ['age', 'income', 'overdue_count', 'employment_length', 'has_car'] # Ejemplo
    
    # Crear pares (nombre, puntuación) y ordenarlos
    attributed_features = sorted(zip(feature_names_for_attr, attr), key=lambda item: item[1], reverse=True)
    
    for name, score in attributed_features[:3]: # Tomar los 3 principales
        # Necesitamos encontrar el valor original de la característica
        try:
            idx = feature_names_for_attr.index(name)
            value = x_input[0, idx]
            top_features_data.append({'name': name, 'value': value.item() if isinstance(value, np.ndarray) else value, 'score': float(score)})
        except ValueError:
            # El nombre de la característica no se encontró en la lista de mapeo
            pass
            
    explanation = {
        'customer_id': data.get('id'),
        # Asegúrate de que 'pred_raw' esté disponible aquí o recalcula la predicción base
        'prediction': float(pred_raw[0]) if 'pred_raw' in locals() else predict_onnx(x_input, "model.onnx")[0], 
        'top_features': top_features_data,
        'rules_triggered': rules[:2], # Mostrar las 2 reglas principales activadas
        'perturbation_analysis': perturb_results
    }
    
    # Devolver YAML (auditoría) y Markdown (cliente)
    # Asegúrate de que yaml.dump y render_markdown estén disponibles
    return jsonify({
        'yaml': yaml.dump(explanation, allow_unicode=True, sort_keys=False), # sort_keys=False para mantener el orden
        'markdown': render_markdown(explanation) # Implementación de render_markdown omitida
    })

# Módulos/Funciones auxiliares necesarios:
# predict_onnx(input_data, onnx_model_path): Ejecuta inferencia ONNX
# match_rules(input_data, onnx_graph): Implementa la lógica de coincidencia de reglas
# render_markdown(explanation_dict): Renderiza el dict de explicación en Markdown
# Asegúrate de que 'pred_raw', 'tracker', 'baseline', 'feature_names_for_attr' estén definidos y accesibles

if __name__ == '__main__':
    # Carga de modelo y configuración inicial aquí si es necesario para la ejecución local
    # Por ejemplo:
    # tracker = ONNXGradientTracker("model.onnx")
    # baseline = np.array([[38, 12000, 0, 4, 1]]) 
    # pred_raw = ... (calcular predicción base si se usa)
    app.run(host='0.0.0.0', port=5000) # Usar puerto estándar 5000

render_markdown usa una plantilla Jinja2:

## Explicación de su Solicitud de Crédito

Basado en la información proporcionada, el sistema ha determinado que su solicitud **NO APROBADA**.

### Factores Principales de Impacto
{% for feat in explanation.top_features %}
- **{{ feat.name }}**: {{ feat.value }} (Peso de Impacto: {{ feat.score|round(2) }})
{% endfor %}

### Reglas de Control de Riesgos Activadas
{% for rule in explanation.rules_triggered %}
- {{ rule }}
{% endfor %}

> Esta explicación se genera automáticamente solo como referencia. Para una revisión manual, por favor, contacte nuestra línea de atención al cliente.

4.7 Paso 7: Pruebas de Carga y Despliegue (10 minutos)

Usar locust para escribir un script simple de prueba de carga:

# locustfile.py
from locust import HttpUser, task, between
import json

class XAIUser(HttpUser):
    wait_time = between(1, 3)
    
    @task
    def explain(self):
        payload = {
            "id": "CUST_123456",
            "features": [35, 15000, 3, 5, 1]  # age, income, overdue, emp_len, car
        }
        self.client.post("/explain", json=payload)

# Ejecutar: locust -f locustfile.py --host http://localhost:5000

Objetivo: 100 concurrencias, latencia P95 <300 ms, tasa de error 0%. Una vez alcanzado, Dockerizar:

FROM python:3.9-slim
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . /app
WORKDIR /app
# Usar gunicorn para servir la aplicación Flask
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "your_flask_app_module:app"]

Construir y enviar, desplegar en K8s. Todo el proceso, desde obtener el modelo hasta la API implementada, nuestro récord del equipo es de 4 horas y 17 minutos.

Preguntas Frecuentes y Consejos de Depuración: Los Errores No Documentados

La implementación de XAI no siempre es fácil. Los siguientes problemas son los que encontramos con frecuencia en nuestros 11 proyectos, cada uno con registros de errores reales, análisis de causa raíz y soluciones rápidas. No se encuentran en ninguna documentación oficial, pero pueden ahorrarte dos semanas de problemas.

5.1 Problema: Error al Cargar el Modelo ONNX InvalidGraph: Node () has input edge () which is not produced by any node, pero onnx.checker dice que es válido.

  • Síntoma: onnx.load("model.onnx") tiene éxito, pero ort.InferenceSession("model.onnx") falla, indicando que falta un borde de entrada.
  • Causa Raíz: Cuando skl2onnx convierte el modelo, si hay entradas no utilizadas (como el parámetro base_score en algunos modelos de árboles), genera un nodo Constant aislado cuya salida no es consumida por ningún nodo. ONNX checker lo considera válido (una constante no tiene por qué ser utilizada), pero ONNX Runtime valida estrictamente la conectividad del grafo.
  • Depuración: Abre el archivo ONNX con Netron, examina todos los nodos y busca nodos de tipo Constant, verificando si sus salidas están "colgando".
  • Solución: Actualiza skl2onnx a la versión 1.14+ y, al convertir, añade el parámetro final_types=[('input', FloatTensorType(...)), ('unused_input', FloatTensorType([None, 0]))] para declarar explícitamente todas las entradas, evitando la generación de nodos colgantes.

5.2 Problema: El rendimiento de la explicación de TreeExplainer se degrada drásticamente bajo carga concurrente, aunque el uso de CPU no supera el 50%.

Etiquetas: XAI Inteligencia Artificial Explicable Python Modelos Black-Box ONNX

Publicado el 6-8 16:12