Guía práctica: exponer ccmusic-database como API REST con Flask para consumo frontend

Origen y funcionamiento de ccmusic-database

ccmusic-database es un sistema de clasificación de géneros musicales que aprovecha modelos de visión por computadora preentrenados para analizar representaciones visuales del audio. En lugar de entrenar una red desde cero sobre señales sonoras, la herramienta convierte el audio en una imagen espectro-temporal mediante la transformada CQT (Constant-Q Transform) y la pasa por una VGG19_BN, una arquitectura de clasificación visual ya consolidada.

Esta transferencia entre modalidades permite identificar estilos musicales a partir de patrones de textura, bloques rítmicos y distribución energética prseentes en el espectrograma. El resultado es un modelo capaz de clasificar 16 géneros distintos sin requerir grandes volúmenes de audio etiquetado. La interfaz Gradio es solo una demostración; el valor real reside en la cadena de inferencia, compuesta por la carga del modelo, la extracción de características CQT y el envoltorio de predicción.

Por qué convertir el demo en una API REST

La interfaz Gradio resulta útil para pruebas rápidas, pero presenta limitaciones importantes en proyectos reales:

  • Acoplamiento con la interfaz: no se integra directamente en paneles de administración hechos con Vue, React o aplicaciones móviles.
  • Protocolo propietario: utiliza WebSocket y mensajes personalizados, lo que obliga a implementar un cliente específico.
  • Complicaciones de despliegue: su servidor embebido dificulta la integración con Nginx, Gunicorn o Kubernetes.

Una API REST estándar resuelve estos problemas: acepta peticiones HTTP, devuelve JSON y puede exponerse bajo un contrato estable. El frontend consume el servicio con fetch o axios, mientras que el backend se enriquece con autenticación, limitación de tasa, logs y monitoreo.

Migración del demo Gradio a un servicio Flask

Enfoque de desacoplamiento

El archivo app.py original combina tres elementos: la lógica del modelo, la transformación CQT y la capa de interfaz. El objetivo es conservar los dos primeros y reemplazar el tercero por una capa HTTP ligera. De este modo se mantiene intacto el núcleo de inferencia y se añade un envoltorio web estándar.

Estructura del proyecto

Se recomienda crear un directorio separado para la API, dejando el código original como dependencia de solo lectura:

ccmusic-api/
├── api/
│   ├── __init__.py
│   ├── inference.py    # carga del modelo y predicción
│   └── endpoints.py    # rutas Flask
├── music_genre/        # código original de ccmusic-database
│   ├── app.py
│   ├── vgg19_bn_cqt/
│   └── plot.py
├── requirements.txt
├── config.py
└── main.py

Implementación del modelo

El siguiente módulo encapsula la carga del modelo, la generación del espectrograma CQT y la inferencia. Se mantiene la coherencia con los parámetros originales a la vez que se reorganiza el código y se renombran variables y métodos:

# api/inference.py
import io
import torch
import librosa
import numpy as np
from PIL import Image
from torchvision import transforms
from music_genre.vgg19_bn_cqt.model import VGG19_BN_CQT


class GenrePredictor:
    def __init__(self, weights_path: str = "./music_genre/vgg19_bn_cqt/save.pt"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = VGG19_BN_CQT(num_classes=16).to(self.device)
        self.net.load_state_dict(torch.load(weights_path, map_location=self.device))
        self.net.eval()

        self.sample_rate = 22050
        self.bin_count = 84
        self.low_freq = 27.5
        self.hop_length = 512

        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _build_cqt_image(self, waveform: np.ndarray) -> Image.Image:
        spectrogram = np.abs(librosa.cqt(
            y=waveform,
            sr=self.sample_rate,
            fmin=self.low_freq,
            n_bins=self.bin_count,
            hop_length=self.hop_length
        ))

        spec_min = spectrogram.min()
        spec_max = spectrogram.max()
        normalized = (spectrogram - spec_min) / (spec_max - spec_min + 1e-8) * 255
        gray = normalized.astype(np.uint8)
        rgb = np.stack([gray] * 3, axis=-1)

        return Image.fromarray(rgb)

    def classify(self, audio_bytes: bytes) -> dict:
        try:
            waveform, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=self.sample_rate,
                mono=True
            )
        except Exception as exc:
            raise ValueError(f"No se pudo decodificar el audio: {exc}")

        max_samples = self.sample_rate * 30
        if waveform.shape[0] > max_samples:
            waveform = waveform[:max_samples]

        cqt_image = self._build_cqt_image(waveform)
        tensor = self.preprocess(cqt_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.net(tensor)
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        labels = [
            "Symphony", "Opera", "Solo", "Chamber", "Pop vocal ballad",
            "Adult contemporary", "Teen pop", "Contemporary dance pop",
            "Dance pop", "Classic indie pop", "Chamber cabaret & art pop",
            "Soul / R&B", "Adult alternative rock", "Uplifting anthemic rock",
            "Soft rock", "Acoustic pop"
        ]

        top_indices = probabilities.argsort(descending=True)[:5]
        predictions = [
            {"genre": labels[idx], "score": float(probabilities[idx])}
            for idx in top_indices
        ]

        return {"predictions": predictions}

Definición de rutas

La capa HTTP recibe el archivo de audio, valida su formato y delega la inferencia al predictor:

# api/endpoints.py
from flask import Blueprint, request, jsonify
from api.inference import GenrePredictor

api_v1 = Blueprint("api_v1", __name__)
estimator = GenrePredictor()

ALLOWED_EXTENSIONS = {".mp3", ".wav"}


@api_v1.route("/predict", methods=["POST"])
def predict_genre():
    if "file" not in request.files:
        return jsonify({"error": "El campo 'file' es obligatorio"}), 400

    uploaded = request.files["file"]
    if uploaded.filename == "":
        return jsonify({"error": "No se seleccionó ningún archivo"}), 400

    if not any(uploaded.filename.lower().endswith(ext) for ext in ALLOWED_EXTENSIONS):
        return jsonify({"error": "Solo se aceptan archivos MP3 o WAV"}), 400

    try:
        audio_payload = uploaded.read()
        result = estimator.classify(audio_payload)
        return jsonify(result)
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400
    except Exception:
        return jsonify({"error": "Error interno del servidor"}), 500

Punto de entrada de la aplicación

# main.py
import os
from flask import Flask
from api.endpoints import api_v1

app = Flask(__name__)
app.register_blueprint(api_v1, url_prefix="/api/v1")


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5000))
    app.run(host="0.0.0.0", port=port)

Dependencias y arrranque

El archivo requirements.txt se mantiene minimalista:

Flask==2.3.3
torch==2.0.1
torchaudio==2.0.2
librosa==0.10.1
numpy==1.24.3
Pillow==10.0.0

Para ejecutar el servicio:

pip install -r requirements.txt
python main.py

La API quedará disponible en http://localhost:5000/api/v1/predict.

Consumo desde el frontend

Desde cualquier aplicación web moderna se puede llamar al endpoint con fetch. El siguiente ejemplo funciona en un navegador sin frameworks adicionales:


<html lang="es">
<head>
  <meta charset="UTF-8">
  <title>Clasificación de género musical</title>
</head>
<body>
  <input type="file" id="audioInput" accept=".mp3,.wav">
  <button id="analyzeBtn">Clasificar</button>
  <div id="output"></div>

  <script>
    document.getElementById("analyzeBtn").addEventListener("click", async () => {
      const file = document.getElementById("audioInput").files[0];
      if (!file) return;

      const formData = new FormData();
      formData.append("file", file);

      try {
        const response = await fetch("http://localhost:5000/api/v1/predict", {
          method: "POST",
          body: formData
        });
        const data = await response.json();
        const container = document.getElementById("output");

        if (data.error) {
          container.textContent = "Error: " + data.error;
        } else {
          const first = data.predictions[0];
          const items = data.predictions
            .map(p => `${p.genre} (${(p.score * 100).toFixed(1)}%)`)
            .join(" | ");

          container.innerHTML = `
            <p><strong>Género principal:</strong> ${first.genre} (${(first.score * 100).toFixed(1)}%)</p>
            <p><strong>Top 5:</strong> ${items}</p>
          `;
        }
      } catch (err) {
        container.textContent = "Fallo de conexión: " + err.message;
      }
    });
  </script>
</body>
</html>

También es posible probar el servicio desde la línea de comandos:

curl -X POST http://localhost:5000/api/v1/predict \
  -F "file=@./ejemplo.mp3"

Optimización y consideraciones productivas

Rendimiento y estabilidad

  • GPU: el modelo se carga automáticamente en CUDA si está disponible.
  • Procesamiento por lotes: se puede añadir un endpoint /predict/batch para recibir múltiples archivos.
  • Control de tamaño: limitar el cuerpo de la petición a 50 MB para evitar bloqueos.
  • Cache del modelo: la instancia GenrePredictor se mantiene viva durante todo el ciclo de vida de la aplicación, evitando cargas repetidas del archivo de pesos.

Patrones de despliegue

Escenario Recomendación Comentario
Desarrollo local python main.py Arranque rápido y validación inmediata.
Producción pequeña Gunicorn + Nginx gunicorn -w 4 -b 0.0.0.0:5000 main:app
Contenedores Docker + Docker Compose Imagen base pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime.
Alta disponibilidad Kubernetes + HPA Escalar según CPU/memoria y exponer un endpoint de salud.

Seguridad y observabilidad

Se recomienda añadir al menos un endpoint de salud para monitoreo y orquestadores:

@api_v1.route("/healthz")
def health_check():
    return jsonify({"status": "ok", "model_loaded": True})

Además, conviene registrar logs estructurados con el nombre del archivo, el tiempo de procesamiento y la predicción principle. La validación de entrada debe rechazar formatos no soportados y, si es necesario, auditar la frecuencia de muestreo para evitar distorsiones en el CQT.

Etiquetas: Flask PyTorch VGG19 CQT REST API

Publicado el 7-5 03:44