Introducción
Desplegar modelos de inteligencia artificial en producción puede presentar desafíos complejos, desde el entrenamiento hasta la puesta en servicio. ResNet18 es una arquitectura de red neuronal convolucional conocida por su equilibrio entre rendimiento y eficiencia, ideal para tareas de clasificación de imágenes con recursos limitados. Este tutorial explica paso a paso cómo entrenar, optimizar y desplegar un modelo ResNet18 en un entorno cloud, sin requerir conocimientos avanzados en IA.
- Configuración del entorno y preparación de datos
1.1 Inicialización del entorno computacional
Comienza instalando PyTorch y sus dependencias. Asegúrate de tener acceso a una GPU para acelerar el proceso:
import torch
# Verificar disponibilidad de GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Dispositivo utilizado: {device}")
1.2 Carga y preprocesamiento del conjunto de datos
Utilizaremos el conjunto de datos CIFAR-10 para este ejemplo. Aplicamos transformaciones estándar para adaptar las imágenes al formato requerido por ResNet:
from torchvision import datasets, transforms
# Definir pipeline de preprocesamiento
preprocessing = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Cargar datos de entrenamiento y prueba
train_set = datasets.CIFAR10(root="./datasets", train=True, download=True, transform=preprocessing)
test_set = datasets.CIFAR10(root="./datasets", train=False, transform=preprocessing)
# Configurar dataloaders
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
- Entrenamiento y evaluación del modelo
2.1 Arquitectura del modelo con ajustes
Cargamos ResNet18 con pesos preentrenados y modificamos la capa final para adaptarla a 10 clases:
from torchvision import models
import torch.nn as nn
# Cargar modelo base
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Modificar capa de salida
num_classes = 10
input_features = net.fc.in_features
net.fc = nn.Linear(input_features, num_classes)
# Mover modelo a GPU si está disponible
net.to(device)
2.2 Proceso de entrenamiento
Configuramos el optimizador y la función de pérdida, luego ejecutamos el bucle de entrenamiento:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
num_epochs = 15
for epoch in range(num_epochs):
net.train()
epoch_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# Evaluación en conjunto de prueba
net.eval()
correct_predictions = 0
total_samples = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
_, predicted = torch.max(outputs, 1)
total_samples += targets.size(0)
correct_predictions += (predicted == targets).sum().item()
accuracy = 100 * correct_predictions / total_samples
print(f"Epoch {epoch+1}/{num_epochs} - Pérdida: {epoch_loss/len(train_loader):.4f} - Precisión: {accuracy:.2f}%")
2.3 Persistencia del modelo entrenado
Guardamos los parámetros del modelo para su uso posterior:
torch.save({
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': accuracy
}, "resnet18_custom.pth")
- Optimización del modelo y creación de servicio
3.1 Conversión a formato TorchScript
Transformamos el modelo PyTorch a TorchScript para mejorar el rendimiento en producción:
checkpoint = torch.load("resnet18_custom.pth")
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()
# Crear ejemplo de entrada para trazado
sample_tensor = torch.randn(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(net, sample_tensor)
# Guardar modelo optimizado
traced_model.save("optimized_resnet18.pt")
3.2 Desarrollo de API con Flask
Construimos un endpoint REST para realizar inferencias:
from flask import Flask, request, jsonify
from PIL import Image
import io
application = Flask(__name__)
# Cargar modelo TorchScript
inference_model = torch.jit.load("optimized_resnet18.pt")
inference_model.eval()
CLASS_LABELS = [
"avión", "automóvil", "ave", "gato", "ciervo",
"perro", "rana", "caballo", "barco", "camión"
]
def preprocess_image(image_data):
transform_pipeline = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open(io.BytesIO(image_data))
return transform_pipeline(image).unsqueeze(0)
@application.route('/classify', methods=['POST'])
def classify_image():
if 'image' not in request.files:
return jsonify({"error": "No se proporcionó imagen"}), 400
file = request.files['image']
image_bytes = file.read()
input_tensor = preprocess_image(image_bytes)
with torch.no_grad():
output = inference_model(input_tensor.to(device))
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top_class = torch.argmax(probabilities).item()
return jsonify({
"predicted_class": CLASS_LABELS[top_class],
"confidence": float(probabilities[top_class])
})
if __name__ == '__main__':
application.run(host='0.0.0.0', port=8080, debug=False)
- Implementación en entorno cloud
4.1 Empaquetado de la aplicación
Preparamos los archivos necesarios para el despliegue:
- optimized_resnet18.pt: modelo TorchScript
- app.py: aplicación Flask
- requirements.txt: dependencias del proyecto
Contenido de requirements.txt:
flask>=2.3.0
torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0
4.2 Despliegue y verificación
Después de subir los archivos a tu proveedor cloud preferido, ejecuta el servicio y realiza una prueba:
# Ejemplo de prueba con curl
curl -X POST -F "image=@imagen_prueba.jpg" http://tu-servidor.com:8080/classify
Respuesta esperada:
{
"predicted_class": "gato",
"confidence": 0.9234
}
- Mejoras de rendimiento y escalabilidad
5.1 Técnicas de optimización
- Cuantización del modelo: Reducir precisión de 32-bit a 16-bit para acelerar inferencia
- Procesamiento por lotes: Modificar API para aceptar múltiples imágenes simultáneamente
- Caché de modelos: Implementar caché para modelos TorchScript ya cargados
5.2 Solución a problemas comunes
- Errores de memoria: Reducir tamaño de batch o implementar gradient checkpointing
- Alta latencia: Convertir modelo a formato ONNX para mayor compatibilidad
- Fallos de conexión: Configurar timeout y reintentos en la API
5.3 Estrategias de escalado
Para entornos de producción, considera:
- Contenedorización con Docker para despliegue consistente
- Orquestación con Kubernetes para autoescalado
- Integración con sistemas de monitoreo como Prometheus
- Implementación de balanceadores de carga para distribuir tráfico