Entendiendo los componentes del error mediante la descomposición de sesgo-varianza
La descomposición de sesgo-varianza es una técnica crucial para evaluar modelos de aprendizaje automático. Permite degslosar el error total en tres partes: sesgo al cuadrado, varianza y error irreducible. Esta descomposición ayuda a identificar si el problema de un modelo radica en su simplicidad (sesgo alto), en su sensibilidad a los datos de entrenamiento (varianza alta) o en el ruido inherente de los datos.
Componentes teóricos
El error esperado de un modelo \(\hat{f}\) en un punto \(x\) se puede expresar matemáticamente como:
Implementación práctica en Python
A continuación, se muestra una implementación manual de la descomposición. Utiliza remuestreo bootstrap para aproximar las expectativas y calcula los tres componentes del error.
import numpy as np
def calcular_descomposicion_sesgo_varianza(X_entrenamiento, y_entrenamiento,
X_prueba, y_prueba,
Modelo, n_muestras=100, semilla_aleatoria=42):
"""
Calcula la descomposición de sesgo-varianza para un modelo dado.
Parámetros:
X_entrenamiento, y_entrenamiento: Datos de entrenamiento.
X_prueba, y_prueba: Datos de prueba.
Modelo: Clase del modelo (ej. LinearRegression).
n_muestras: Número de remuestreos bootstrap.
semilla_aleatoria: Semilla para reproducibilidad.
Retorna:
Diccionario con sesgo2, varianza, error_irreducible, error_total.
"""
np.random.seed(semilla_aleatoria)
n_prueba = len(y_prueba)
predicciones_acumuladas = np.zeros((n_muestras, n_prueba))
for muestra in range(n_muestras):
# Remuestreo bootstrap con reemplazo
indices_bootstrap = np.random.choice(len(y_entrenamiento),
size=len(y_entrenamiento),
replace=True)
X_bootstrap = X_entrenamiento[indices_bootstrap]
y_bootstrap = y_entrenamiento[indices_bootstrap]
# Entrenar el modelo en la muestra bootstrap
modelo_instanciado = Modelo()
modelo_instanciado.fit(X_bootstrap, y_bootstrap)
# Predecir en el conjunto de prueba
predicciones_acumuladas[muestra, :] = modelo_instanciado.predict(X_prueba)
# Calcular componentes para cada punto de prueba
sesgo2 = np.zeros(n_prueba)
varianza = np.zeros(n_prueba)
error_total = np.zeros(n_prueba)
for idx in range(n_prueba):
preds_punto = predicciones_acumuladas[:, idx]
media_predicciones = np.mean(preds_punto)
# Error total: (valor real - predicción promedio)^2
error_total[idx] = (y_prueba[idx] - media_predicciones) ** 2
# Sesgo al cuadrado: (predicción promedio - valor real)^2
sesgo2[idx] = (media_predicciones - y_prueba[idx]) ** 2
# Varianza: dispersión de las predicciones
varianza[idx] = np.var(preds_punto, ddof=0)
# Error irreducible por diferencia
error_irreducible = error_total - sesgo2 - varianza
return {
'sesgo2': sesgo2,
'varianza': varianza,
'error_irreducible': error_irreducible,
'error_total': error_total,
'prediccion_media': np.mean(predicciones_acumuladas, axis=0)
}
Ejemplo con datos sintéticos
Generamos un conjunto de datos donde la relación real incluye un término de interacción que no se proporciona al modelo lineal. Esto introduce un sesgo sistemático.
import pandas as pd
from sklearn.linear_model import LinearRegression
# Generar datos con interacción oculta
np.random.seed(42)
n_datos = 500
caracteristica1 = np.random.randn(n_datos)
caracteristica2 = np.random.randn(n_datos)
# Relación real: y = 3*feat1 + 2*feat2 + 5*feat1*feat2 + ruido
y_real = 3 * caracteristica1 + 2 * caracteristica2 + 5 * caracteristica1 * caracteristica2 + np.random.randn(n_datos) * 2
# Dividir en entrenamiento y prueba
X_entrenamiento = np.column_stack([caracteristica1[:400], caracteristica2[:400]])
y_entrenamiento = y_real[:400]
X_prueba = np.column_stack([caracteristica1[400:], caracteristica2[400:]])
y_prueba = y_real[400:]
# Realizar descomposición
resultados = calcular_descomposicion_sesgo_varianza(
X_entrenamiento, y_entrenamiento, X_prueba, y_prueba,
LinearRegression, n_muestras=50
)
# Mostrar promedios
print("Promedio de sesgo²:", np.mean(resultados['sesgo2']))
print("Promedio de varianza:", np.mean(resultados['varianza']))
print("Promedio de error irreducible:", np.mean(resultados['error_irreducible']))
En este ejemplo, el sesgo² será dominante porque el modelo lineal no puede capturar la interacción. Al agregar características polinómicas, el sesgo disminuye y la varianza aumenta.
Aplicación a diferentes modelos
- Modelos lineales: Tienden a tener sesgo alto cuando la relación es no lineal, pero varianza baja gracias a la regularización.
- Árboles de decisión: Pueden tener varianza alta debido a su sansibilidad a los datos de entrenamiento, pero sesgo bajo si son suficientemente profundos.
- Ensamblados (ej. Random Forest): Reducen la varianza mediante bagging, manteniendo el sesgo bajo.
Solución de problemas comunes
Si la varianza calculada es cero, verifique que el remuestreo bootstrap sea efectivo y que el modelo no sea determinista. Si el error irreducible es negativo, asegúrese de que el error total se calcule como \((y - \bar{f})^2\) y no como el promedoi de errores individuales.
Para conjuntos de datos grandes, procese los datos en bloques para evitar problemas de memoria:
def calcular_descomposicion_por_bloques(X_entrenamiento, y_entrenamiento,
X_prueba, y_prueba, Modelo,
n_muestras=50, tamano_bloque=1000):
n_prueba = len(y_prueba)
sesgo2_total = 0
varianza_total = 0
error_irreducible_total = 0
for inicio in range(0, n_prueba, tamano_bloque):
fin = min(inicio + tamano_bloque, n_prueba)
X_bloque = X_prueba[inicio:fin]
y_bloque = y_prueba[inicio:fin]
resultados_bloque = calcular_descomposicion_sesgo_varianza(
X_entrenamiento, y_entrenamiento, X_bloque, y_bloque,
Modelo, n_muestras
)
sesgo2_total += np.sum(resultados_bloque['sesgo2'])
varianza_total += np.sum(resultados_bloque['varianza'])
error_irreducible_total += np.sum(resultados_bloque['error_irreducible'])
return {
'sesgo2_promedio': sesgo2_total / n_prueba,
'varianza_promedio': varianza_total / n_prueba,
'error_irreducible_promedio': error_irreducible_total / n_prueba
}
Consejos para la práctica
Use un número de remuestreos bootstrap de al menos 50 para estimaciones estables. En modelos con aleatoriedad adicional (ej. redes neuronales), aumente este número y controle las semillas. Monitoree los componentes del error en producción para detectar cambios en la distribución de datos o problemas de calidad.