Arquitectura y Despliegue de Modelos Semisupervisados para Segmentación de Imágenes Médicas

Introducción al Marco de Trabajo

La segmentación de imágenes médicas mediante aprendizaje profundo enfrenta el desafío crítico de la escasez de anotaciones manuales expertas. El marco SSL4MIS aborda esta limitación mediante la implementación de técnicas de aprendizaje semisupervisado (SSL), permitiendo el entrenamiento de modelos robustos utilizando un subconjunto mínimo de datos etiquetados junto con grandes volúmenes de datos no etiquetados.

Integración de Algoritmos y Arquitecturas

El framework soporta múltiples paradigmas de consistencia y pseudo-etiquetado. Entre los métodos implementados se encuentran la Supervisión Pseudo-Cruzada (CPS), Mean Teacher y FixMatch. Estos algoritmos operan bajo la premisa de regularizar la salida del modelo mediante perturbaciones en los datos de entrada o en la propia arquitectura de la red.

En cuanto a los extractores de características, se proporciona soporte nativo para arquitecturas convolucionales como U-Net y sus variantes 3D, así como transformadores visionarios como Swin Transformer, adaptados para procesar volúmenes médicos de alta resolución (CT y MRI).

Despliegue del Entorno y Preparación de Datos

Para garantizar la reproducibilidad, la configuración del entorno se gestiona mediante scripts de automatización que definen variables de ruta y nombres de entornos virtuales.

#!/bin/bash
REPO_SOURCE="https://gitcode.com/gh_mirrors/ss/SSL4MIS.git"
WORKSPACE="medical_ssl_workspace"
CONDA_ENV="med_seg_env"

git clone ${REPO_SOURCE} ${WORKSPACE}
cd ${WORKSPACE}

conda env create -f environment.yml -n ${CONDA_ENV}
source activate ${CONDA_ENV}

Los datos deben estructurarse en directorios específicos (data/ACDC/, data/BraTS2019/), acompañados de archivos de manifiesto (.list) que dividen los conjuntos de entrenamiento y validación. Los scripts de preprocesamiento en code/dataloaders/ normalizan los vóxeles y ajustan las dimensiones espaciales.

Ejecución de Entrenamiento y Evaluación

El pipeline de entrenamiento expone scripts modulares para diferentes estrategias SSL. En lugar de ejecutar directamente los scripts de shell, se puede implementar un wrapper en Python para gestionar los argumentos y el registro de métricas durante la inferencia.

import argparse
import subprocess
import sys

def execute_evaluation_pipeline(checkpoint_path, spatial_dim="3D"):
    evaluation_cmd = [
        sys.executable, f"test_{spatial_dim}.py",
        "--weights", checkpoint_path,
        "--output_dir", "./eval_results",
        "--compute_hd95"
    ]
    
    process = subprocess.run(evaluation_cmd, capture_output=True, text=True)
    if process.returncode != 0:
        print(f"Error en la evaluación: {process.stderr}")
        sys.exit(1)
    print(process.stdout)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pipeline de evaluación SSL")
    parser.add_argument("--ckpt", type=str, required=True, help="Ruta al modelo entrenado")
    parser.add_argument("--dim", type=str, default="3D", choices=["2D", "3D"])
    args = parser.parse_args()
    
    execute_evaluation_pipeline(args.ckpt, args.dim)

Optimización de Funciones de Pérdida y Aumento de Datos

Para mejorar la convergencia en regiones de baja frecuencia o bordes difusos, es recomendable combinar la entropía cruzada estándar con la pérdida de Dice. El siguiente fragmento muestra cómo estructurar una función de pérdida combinada dentro del módulo de utilidades del proyecto.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CombinedSegmentationLoss(nn.Module):
    def __init__(self, ce_weight=0.5, dice_weight=0.5, smooth=1e-5):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, predictions, targets):
        ce_loss = self.cross_entropy(predictions, targets)
        
        probs = F.softmax(predictions, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=probs.shape[1]).permute(0, 3, 1, 2).float()
        
        intersection = (probs * targets_one_hot).sum(dim=(2, 3))
        dice_loss = 1 - ((2. * intersection + self.smooth) / 
                         (probs.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3)) + self.smooth))
        
        total_loss = (self.ce_weight * ce_loss) + (self.dice_weight * dice_loss.mean())
        return total_loss

La selección de la estrategia de aumento de datos, como CTAugment, es crucial para generar las vistas perturbadas necesarias en los métodos de consistencia, asegurando que el modelo aprenda representaciones invariantes a las variaciones clínicas típicas.

Etiquetas: semi-supervised-learning medical-image-segmentation ssl4mis PyTorch u-net

Publicado el 6-23 22:07