Módulo DDI Publicado en AAAI 2025: Fusión de Información Multiescala Mediante Doble MLP

Los métodos predominantes en el pronóstico de series temporales (TSF), como los basados en Transformer y MLP, enfrentan desafíos al tratar con datos del mundo real complejos. Las series temporales del mundo real suelen exhibir patrones diversos en diferentes escalas, y los cambios futuros están determinados por la interacción de estas escalas superpuestas. Los modelos Transformer, aunque eficaces para capturar dependencias a largo plazo, tienen alta complejidad computacional y son propensos al sobreajuste. Los modelos MLP son computacionalmente eficientes, pero les cuesta capturar patrones temporales en escalas complejas.

Para abordar este fenómeno de "entrelazamiento multiescala", se propone un novedoso marco de Descomposición Multiescala Adaptativa (AMD) basado en MLP. Este marco primero descompone la serie temporal en diferentes patrones temporales en múltiples escalas, luego modela las dependencias temporales y de canal en estos patrones, y finalmente utiliza la autocorrelación para agregar adaptativamente los resultados del pronóstico multiescala. El método tiene como objetivo superar las limitaciones de los enfoques existentes, mejorando la precisión del pronóstico al identificar y utilizar con precisión los patrones temporales dominantes, manteniendo al mismo tiempo la eficiencia computacional.

Información Básica del Artículo

  • Título: Adaptive Multi-Scale Decomposition Framework for Time Series Forecasting (Marco de Descomposición Multiescala Adaptativa para Pronóstico de Series Temporales)
  • Módulos Principales: Mezcla Descomponible Multiescala (MDM), Interacción de Dependencia Dual (DDI), Síntesis Adaptativa Multi-Predictor (AMS)

Marco Algorítmico y Módulos Principales

2.1 Marco Algorítmico

El marco AMD propuesto comienza con la serie temporal de entrada X que pasa por el módulo MDM para una descomposición multiescala y mezcla de información, generando un tensor U enriquecido con información multiescala. Luego, el módulo DDI procesa U, modelando simultáneamente las dependencias en las dimensiones temporal y de canal. Finalmente, el módulo AMS genera pesos de predictor independientes S para cada canal y realiza una suma ponderada de la salida de DDI para obtener el resultado final del pronóstico Y. Todo el marco se entrena de extremo a extremo.

2.2 Módulos Principales

Módulo 1: Mezcla Descomponible Multiescala (MDM)
  • Función Principal: Descomponer la serie temporal original en múltiples patrones de diferentes granularidades (de fino a grueso) y luego mezclar la información de estos patrones nivel por nivel, proporcionando una representación mejorada con información multiescala para los módulos posteriores.
  • Lógica de Implementación: El proceso de descomposición se logra mediante múltiples aplicaciones de promedio de agrupamiento (Average Pooling) en la serie de entrada, extrayendo capa por capa patrones de granularidad más gruesa τ. El proceso de mezcla utiliza una red residual feedforward de arriba hacia abajo (de grueso a fino) para incorporar la información de grano grueso ξ capa por capa en los patrones de grano fino.
    • Lógica de descomposición:
      τ_i = AvgPooling(τ_{i-1})
    • Lógica de mezcla:
      ξ_i = τ_i + MLP(ξ_{i+1})
  • Ventajas: Comparado con la descomposición tradicional estacional-tendencia, MDM puede descomponer la serie de manera más flexible en información de múltiples escalas complementarias, proporcionando una vista de la serie más completa y detallada, lo que mejora la capacidad del modelo para expresar dinámicas temporales complejas.
Módulo 2: Interacción de Dependencia Dual (DDI)
  • Función Principal: Apuntar a modelar simultáneamente las dependencias temporales y las dependencias de canal en la información multiescala mezclada de salida del módulo MDM.
  • Lógica de Implementación: Este módulo recibe la información multiescala U y la procesa por bloques (Patching). Utiliza alternativamente dos MLP: uno compartido en la dimensión temporal para mezclar información temporal, y otro compartido en la dimensión de canal (logrado mediante transposición de matriz) para mezclar información de canales. Se mantienen conexiones residuales. Se introduce un factor de escala β para equilibrar la importancia de ambas dependencias.
    • Mezcla temporal:
      Z_{t+P_t} = \hat{U}_{t+P_t} + MLP(\hat{V}_{t_{t-P}})
    • Mezcla de canal:
      \hat{V}_{t+P_t} = Z_{t+P_t} + β \cdot MLP((Z_{t+P_t})^T)^T
  • Ventajas: DDI puede capturar explícitamente las interacciones dinámicas entre diferentes pasos de tiempo y diferentes variables. El factor de escala β introducido permite que el modelo ajuste adaptativamente el grado de atención a las dependencias entre canales, suprimiendo el ruido de variables no relacionadas al mismo tiempo que se mejora la interacción de la información, aumentando así la robustez del modelo.
Módulo 3: Síntesis Adaptativa Multi-Predictor (AMS)
  • Función Principal: Adopta el concepto de Mezcla de Expertos (MoE), entrenando predictores especializados para diferentes patrones temporales y asignando dinámicamente pesos a estos predictores según la entrada, finalmente sintetizando el resultado pronosticado mediante una suma ponderada.
  • Lógica de Implementación: Este módulo consta de dos partes: 1) Selector de Patrones Temporales (TP-Selector) y 2) Proyección de Patrones Temporales (TP-Projection). TP-Selector genera pesos S para múltiples predictores a través de una red de compuerta con ruido. TP-Projection contiene múltiples predictores paralelos (todos MLP), cuyas salidas se multiplican por los pesos S generados por TP-Selector y se suman para obtener el valor pronosticado final.
    • Generación de pesos del selector (parcial):
      S = Softmax(TopK(Softmax(Q(u)), k))
    • Síntesis del valor pronosticado:
      \hat{y} = \sum_{j=0}^{m} S_j \cdot Predictor_j(v)
  • Ventajas: Comparado con un simple promedio de toda la información multiescala, AMS puede identificar y enfocarse en el patrón temporal dominante que influye en el futuro en diferentes períodos de tiempo, logrando un ajuste dinámico de la estrategia de pronóstico. Este mecanismo de ponderación adaptativa no solo mejora la precisión del pronóstico, sino que también mejora la interpretabilidad del modelo.

Tareas de Aplicación del Módulo

  • Escenario de Aplicación Principal: Pronóstico de series temporales largas y pronóstico de series temporales cortas. Especialmente adecuado para datos de series temporales del mundo real que exhiben múltiples periodicidades, tendencias y cambios dinámicos, como en los campos de meteorología, electricidad, flujo de tráfico y finanzas.
  • Metodología Principal: La idea esencial es "Descomponer-Modelar-Sintetizar Adaptativamente". Convierte un problema de pronóstico complejo, mediante la descomposición en diferentes escalas temporales, en un conjunto de subproblemas más simples para su modelado, y finalmente utiliza un mecanismo inteligente e impulsado por datos de mezcla de expertos para combinar dinámicamente las soluciones de los subproblemas, obteniendo así la solución global óptima final.
  • Extensión Inspiradora:
    1. Descomposición Adaptativa en el Dominio de la Frecuencia: Combinando trabajos como FITS mencionados en el artículo, en el futuro se podría explorar la descomposición y extracción adaptativa de patrones en el dominio de la frecuencia para reemplazar la descomposición actual en el dominio del tiempo basada en promedio de agrupamiento, lo que podría capturar características periódicas de manera más eficiente.
    2. Mejora del Problema de Desplazamiento de Distribución en Redes Profundas: El artículo utiliza RevIN en la capa de entrada para manejar el desplazamiento de distribución, pero señala que el cambio de distribución en redes profundas sigue siendo un desafío. En el futuro, se pueden investigar técnicas de normalización más avanzadas para resolver problemas de estabilidad en estructuras más profundas del marco AMD.

Resultados Experimentales y Análisis de Visualización

Experimentos y Conclusiones Principales

El experimento central que mejor refleja la contribución de este artículo es el estudio de ablación de componentes del modelo (Ablation Study), especialmente el análisis del módulo AMS.

  • Propósito del Experimento: Este experimento tiene como objetivo verificar la necesidad y eficacia de cada módulo principal (MDM, DDI, AMS) en el marco AMD. Específicamente, para validar si la estrategia de ponderación adaptativa del módulo AMS supera a otros métodos de agregación (como ordenación aleatoria, ponderación promedio o un solo predictor) y demostrar que la mejora del rendimiento no se debe únicamente al aumento en la cantidad de parámetros del modelo.
  • Resultados Clave:
    • w/o AMS: Al eliminar el módulo AMS y reemplazarlo por un único predictor lineal, el rendimiento del modelo disminuyó significativamente, demostrando las ventajas de la arquitectura multi-predictor.
    • Peso Promedio: Al reemplazar la ponderación adaptativa de AMS con una simple ponderación promedio, el rendimiento también disminuyó. Esto indica que asignar igual importancia a diferentes patrones temporales es menos efectivo que la estrategia de ponderación dinámica de AMS.
    • w/o MDM: Al eliminar el módulo MDM y usar solo una escala para el pronóstico, el error también aumentó, lo que demuestra la necesidad de la descomposición multiescala.
  • Conclusión del Autor: Basándose en este experimento, los autores concluyen que el rendimiento superior del marco AMD proviene efectivamente de su mecanismo único de descomposición multiescala y síntesis adaptativa. A diferencia de los mecanismos de autoatención que dependen de la invariencia al orden, AMS puede preservar la información temporal de la serie y utilizar eficazmente los patrones dominantes que cambian con el tiempo. Comparado con la simple agregación promedio, AMS logra pronósticos más precisos al asignar adaptativamente pesos a diferentes patrones temporales. Esto demuestra la razonabilidad y eficiencia del diseño del marco AMD.

Código del Módulo Plug-and-Play

Mezcla Descomponible Multiescala (MDM)

  • Función principal: Descomponer la serie por múltiples escalas mediante promedio de agrupamiento y mezclarlas de vuelta a una representación de grano fino usando un feedforward residual por niveles.
  • Ventaja principal: Obtener información multiescala complementaria de manera ligera, mejorando la capacidad de expresión para series temporales complejas.
  • Código principal (fragmento):
class MDM(nn.Module):
    def __init__(self, input_shape, niveles=3, factor=2, norma_capa=True):
        super(MDM, self).__init__()
        self.longitud_sec = input_shape[0]
        self.niveles = niveles
        if self.niveles > 0:
            self.escalas_lista = [factor ** i for i in range(niveles, 0, -1)]
            self.capas_agrupamiento = nn.ModuleList([nn.AvgPool1d(kernel_size=escala, stride=escala) for escala in self.escalas_lista])
            self.lineales = nn.ModuleList(
                [
                    nn.Sequential(nn.Linear(self.longitud_sec // escala, self.longitud_sec // escala),
                                  nn.GELU(),
                                  nn.Linear(self.longitud_sec // escala, self.longitud_sec * factor // escala),
                                  )
                    for escala in self.escalas_lista
                ]
            )
        self.norma_capa = norma_capa
        if self.norma_capa:
            self.norma = nn.BatchNorm1d(input_shape[0] * input_shape[-1])

    def forward(self, x):
        if self.norma_capa:
            x = self.norma(torch.flatten(x, 1, -1)).reshape(x.shape)
        if self.niveles == 0:
            return x
        # x [tamaño_lote, num_caracteristicas, longitud_secuencia]
        patrones_escala = []
        for i, escala in enumerate(self.escalas_lista):
            patrones_escala.append(self.capas_agrupamiento[i](x))
        patrones_escala.append(x)
        num_patrones = len(patrones_escala)
        for i in range(num_patrones - 1):
            transformado = self.lineales[i](patrones_escala[i])
            patrones_escala[i + 1] = torch.add(patrones_escala[i + 1], transformado, alpha=1.0)
        # [tamaño_lote, num_caracteristicas, longitud_secuencia]
        return patrones_escala[num_patrones - 1]

Interacción de Dependencia Dual (DDI)

  • Función principal: Mezclar alternativamente y superponer residuos en la dimensión temporal y de canal, modelando la doble dependencia por bloques.
  • Ventaja principal: Capturar explícitamente la interacción dinámica entre pasos temporales y variables, y suprimir el ruido irrelevante con intensidad ajustable.
  • Código principal (fragmento):
class DDI(nn.Module):
    def __init__(self, input_shape, tasa_abandono=0.2, ventana=12, beta=0.0, norma_capa=True):
        super(DDI, self).__init__()
        # input_shape[0] = longitud_secuencia    input_shape[1] = num_caracteristicas
        self.input_shape = input_shape
        if beta > 0.0:
            self.ff_dim = 2 ** math.ceil(math.log2(self.input_shape[-1]))
            self.bloque_fc = nn.Sequential(
                nn.Linear(self.input_shape[-1], self.ff_dim),
                nn.GELU(),
                nn.Dropout(tasa_abandono),
                nn.Linear(self.ff_dim, self.input_shape[-1]),
                nn.GELU(),
                nn.Dropout(tasa_abandono),
            )

        self.n_historial = 1
        self.beta = beta
        self.ventana = ventana

        self.norma_capa = norma_capa
        if self.norma_capa:
            self.norma = nn.BatchNorm1d(self.input_shape[0] * self.input_shape[-1])
        self.norma1 = nn.BatchNorm1d(self.n_historial * ventana * self.input_shape[-1])
        if self.beta > 0.0:
            self.norma2 = nn.BatchNorm1d(ventana * self.input_shape[-1])

        self.agregacion = nn.Linear(self.n_historial * ventana, ventana)
        self.abandono_t = nn.Dropout(tasa_abandono)

    def forward(self, x):
        # [tamaño_lote, num_caracteristicas, longitud_secuencia]
        if self.norma_capa:
            x = self.norma(torch.flatten(x, 1, -1)).reshape(x.shape)

        salida = torch.zeros_like(x)
        salida[:, :, :self.n_historial * self.ventana] = x[:, :, :self.n_historial * self.ventana].clone()
        for i in range(self.n_historial * self.ventana, self.input_shape[0], self.ventana):
            # entrada [tamaño_lote, num_caracteristicas, n_historial * ventana]
            entrada = salida[:, :, i - self.n_historial * self.ventana: i]
            # entrada [tamaño_lote, num_caracteristicas, n_historial * ventana]
            entrada = self.norma1(torch.flatten(entrada, 1, -1)).reshape(entrada.shape)
            # agregación
            # [tamaño_lote, num_caracteristicas, ventana]
            entrada = F.gelu(self.agregacion(entrada))  # n_historial * ventana -> ventana
            entrada = self.abandono_t(entrada)
            # entrada [tamaño_lote, num_caracteristicas, ventana]
            tmp = entrada + x[:, :, i: i + self.ventana]

            residuo = tmp

            # [tamaño_lote, num_caracteristicas, ventana]
            if self.beta > 0.0:
                tmp = self.norma2(torch.flatten(tmp, 1, -1)).reshape(tmp.shape)
                tmp = torch.transpose(tmp, 1, 2)
                # [tamaño_lote, ventana, num_caracteristicas]
                tmp = self.bloque_fc(tmp)
                tmp = torch.transpose(tmp, 1, 2)
            salida[:, :, i: i + self.ventana] = residuo + self.beta * tmp

        # [tamaño_lote, num_caracteristicas, longitud_secuencia]
        return salida

Síntesis Adaptativa Multi-Predictor (AMS)

  • Función principal: Utilizar un selector por compuerta para asignar dinámicamente pesos a predictores paralelos y sintetizar el pronóstico ponderado.
  • Ventaja principal: Enfocarse adaptativamente en el patrón temporal dominante, superando significativamente al promedio o al predictor único.
  • Código principal (fragmento):
class CompuertaTopK(nn.Module):
    def __init__(self, dim_entrada, num_expertos, k_superior=2, epsilon_ruido=1e-5):
        super(CompuertaTopK, self).__init__()
        self.compuerta = nn.Linear(dim_entrada, num_expertos)
        self.k_superior = k_superior
        self.epsilon_ruido = epsilon_ruido
        self.num_expertos = num_expertos
        self.w_ruido = nn.Parameter(torch.zeros(num_expertos, num_expertos), requires_grad=True)
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)

    def descomposicion_tp(self, x, alfa=10):
        # x [tamaño_lote, longitud_secuencia]
        salida = torch.zeros_like(x)
        # [tamaño_lote]
        valor_k_esimo, _ = torch.kthvalue(x, self.num_expertos - self.k_superior + 1)
        # [tamaño_lote, num_experto]
        matriz_k_esima = valor_k_esimo.unsqueeze(1).expand(-1, self.num_expertos)
        mascara = x < matriz_k_esima
        x = self.softmax(x)
        salida[mascara] = alfa * torch.log(x[mascara] + 1)
        salida[~mascara] = alfa * (torch.exp(x[~mascara]) - 1)
        # Ablación MoE Disperso
        # salida[mascara] = 0
        # [tamaño_lote, longitud_secuencia]
        return salida

    def forward(self, x):
        # [tamaño_lote, longitud_secuencia]
        x = self.compuerta(x)
        logits_limpios = x
        # [tamaño_lote, num_expertos]

        if self.training:
            desviacion_ruido_crudo = x @ self.w_ruido
            desviacion_ruido = ((self.softplus(desviacion_ruido_crudo) + self.epsilon_ruido))
            logits_ruidosos = logits_limpios + (torch.randn_like(logits_limpios) * desviacion_ruido)
            logits = logits_ruidosos
        else:
            logits = logits_limpios

        logits = self.descomposicion_tp(logits)
        compuertas = self.softmax(logits)
        return compuertas


class AMS(nn.Module):
    def __init__(self, input_shape, longitud_pron, ff_dim=2048, tasa_abandono=0.2, coef_perdida=1.0, num_expertos=4, k_superior=2):
        super(AMS, self).__init__()
        # input_shape[0] = longitud_secuencia    input_shape[1] = num_caracteristicas
        self.num_expertos = num_expertos
        self.k_superior = k_superior
        self.longitud_pron = longitud_pron

        self.compuerta = CompuertaTopK(input_shape[0], num_expertos, k_superior)

        self.expertos = nn.ModuleList(
            [Experto(input_shape[0], longitud_pron, dim_oculto=ff_dim, tasa_abandono=tasa_abandono) for _ in range(num_expertos)])
        self.coef_perdida = coef_perdida
        assert (self.k_superior <= self.num_expertos)

    def cv_cuadrado(self, x):
        eps = 1e-10
        # si solo num_expertos = 1
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def forward(self, x, incrustacion_tiempo):
        # [tamaño_lote, num_caracteristicas, longitud_secuencia]
        tam_lote = x.shape[0]
        num_caract = x.shape[1]
        # [num_caracteristicas, tamaño_lote, longitud_secuencia]
        x = torch.transpose(x, 0, 1)
        incrustacion_tiempo = torch.transpose(incrustacion_tiempo, 0, 1)

        salida = torch.zeros(num_caract, tam_lote, self.longitud_pron).to(x.device)
        perdida = 0

        for i in range(num_caract):
            entrada = x[i]
            info_tiempo = incrustacion_tiempo[i]
            # x[i]  [tamaño_lote, longitud_secuencia]
            compuertas = self.compuerta(info_tiempo)

            # salidas_expertos [tamaño_lote, num_expertos, longitud_pron]
            salidas_expertos = torch.zeros(self.num_expertos, tam_lote, self.longitud_pron).to(x.device)

            for j in range(self.num_expertos):
                salidas_expertos[j, :, :] = self.expertos[j](entrada)
            salidas_expertos = torch.transpose(salidas_expertos, 0, 1)
            # compuertas [tamaño_lote, num_expertos, longitud_pron]
            compuertas = compuertas.unsqueeze(-1).expand(-1, -1, self.longitud_pron)
            # salida_lote [tamaño_lote, longitud_pron]
            salida_lote = (compuertas * salidas_expertos).sum(1)
            salida[i, :, :] = salida_lote

            importancia = compuertas.sum(0)
            perdida += self.coef_perdida * self.cv_cuadrado(importancia)

        # [num_caracteristicas, tamaño_lote, longitud_secuencia]
        salida = torch.transpose(salida, 0, 1)
        # [tamaño_lote, num_caracteristicas, longitud_secuencia]
        return salida, perdida

Etiquetas: Series Temporales Pronóstico Redes MLP Descomposición Multiescala Módulo DDI

Publicado el 6-8 09:32