Giter Club home page Giter Club logo

destillml's Introduction

Proyecto Final - Knowledge Distillation y técnicas de XAI para el aprendizaje

Knowledge destillation sobre imaganes naturales

  1. En principio se utilizo la red tutora VGG19, la red aprendiz era VGG19 con la mitad con la mitad de filtros convolucionales, sobre el 3% del dataset de ImageNet (1.4 millones de imagenes y 1000 clases).
  2. Surge el problema del gran cuello de botella de parametros en ambas redes, la gran necesidad de memoria y tiempo de entrenamiento por el tamaño de las redes y el tamaño del dataset a pesar de ser una fraccion.
  3. Se propone la utilizacion de redes que distribuyan mejor los parametros, como la familia de modelos ResNet y el dataset en principio completo de CIFAR-100 (60 mil imagenes y 100 clases).
  4. Para poder comparar resultados y hacer consistente la comparacion se aplicará el mismo esquema de entrenamiento para la red tutora y la red aprendiz, y se estudiarán técnicas para mejorar este baseline mediante knowledge destillation.
  5. Se aplican principalmente 3 técnicas de destilación de conocimiento.
    • Soft targets: Consiste en entrenar la red aprendiz con las salidas de la red tutora (logits), en lugar de las etiquetas originales.
    • Similarity preserving: Consiste en entrenar la red aprendiz con las salidas de la red tutora (logits) y las salidas de la red aprendiz, para que sean similares.
    • Attention transfer: Consiste en entrenar la red aprendiz con las salidas de las capas intermedias de la red tutora, para que la red aprendiz preste atención a las mismas características que la red tutora.
  6. En caso de resultados prometedores, se utilizará la técnica de mejor desempeño en un escenario de escasez de recursos, es decir, la red aprendiz con menos parametros y el dataset mas pequeño.
  7. Se compararán los resultados de las redes tutoras y aprendices con las métricas de accuracy y tiempo de entrenamiento además de técnias de XAI para entender el comportamiento de las redes como Grad-CAM y SHAP.
Debido a la cantidad de parámetros, se debieron considerar estos cambios debido a la capacidad computacional y la aplicabilidad del dataset al entrenar las diferentes arquitecturas de VGG.

Detalles de la impelementación

Conjunto de datos

Los datos de entrenamiento se componen principalmente de imágenes del conjunto de datos CIFAR-100. Para el conjunto de entrenamiento 45.000 y validación 5.000, se aplican transformaciones como volteo horizontal aleatorio, recorte aleatorio, ajuste de color, rotación aleatoria y recorte redimensionado aleatorio para aumentar la diversidad de los datos y mejorar la capacidad del modelo para generalizar. Además, se normalizan las imágenes utilizando la media y la desviación estándar proporcionadas para ImageNet (valor comúnmente usado por su gran diversidad de imagenes). Para el conjunto de prueba, se aplica una transformación estándar sin aumentación de datos. Se utilizan DataLoader para cargar los conjuntos de datos de entrenamiento, validación y prueba con el tamaño de lote especificado, y se configuran para el procesamiento paralelo y la asignación de memoria.

Entrenamiento de la red tutora y la red aprendiz

El experimento se enfoca en entrenar un modelo utilizando el optimizador SGD con una tasa de aprendizaje inicial de 0.5 y un momento de 0.9. Implementa un esquema de programación de la tasa de aprendizaje con un calentamiento inicial de 5 épocas lineal, seguido de una disminución cíclica basada en la función coseno hasta que se alcanza el límite de épocas definido. Se ajusta el peso decaimiento para no afectar a las capas de normalización por lotes. Durante el entrenamiento, se evalúa la precisión y la pérdida utilizando la entropía cruzada, con un suavizado de etiquetas de 0.1.

Entrenamiento de la red aprendiz con destilación de conocimiento

Se implementan tres técnicas de destilación de conocimiento: soft targets, similarity preserving y attention transfer. Para cada técnica, se entrena un modelo utilizando el optimizador SGD con una tasa de aprendizaje inicial de 0.5 y un momento de 0.9. Implementa un esquema de programación de la tasa de aprendizaje con un calentamiento inicial de 5 épocas lineal, seguido de una disminución cíclica basada en la función coseno hasta que se alcanza el límite de épocas definido. Se ajusta el peso decaimiento para no afectar a las capas de normalización por lotes. Durante el entrenamiento, se evalúa la precisión y la pérdida utilizando la entropía cruzada, sin suavizado de etiquetas.

Evaluación de la red tutora y la red aprendiz

Se evalúan los modelos entrenados utilizando el conjunto de prueba. Se calcula la precisión y la pérdida utilizando la entropía cruzada. Además, se aplican técnicas de XAI para comprender el comportamiento de los modelos. - Grad-CAM: Genera mapas de activación de clase para visualizar las regiones importantes de las imágenes. - LIME: Explica las predicciones de los modelos utilizando un modelo localmente interpretable. - SHAP: Explica las predicciones de los modelos utilizando valores Shapley.

Resultados

Se comparan los resultados de los modelos entrenados utilizando las métricas de precisión y pérdida. Además, se analizan las visualizaciones generadas por las técnicas de XAI para comprender el comportamiento de los modelos.

Referencias

Reproducibilidad

  1. Clonar el repositorio

    git clone https://github.com/M4thinking/DestillML.git && cd DestillML
  2. Crear ambiente virtual, activar, updatear pip e instalar dependencias:

    python -m venv env
    source ./env/bin/activate
    python -m pip install --upgrade pip
    pip install -r requirements.txt
  3. Ejecutar dataset a utilizar (cifar10, cifar100, imagenet):

    python dataset.py --dataset cifar100
  4. Entrenar red tutora:

    python trainer.py --dataset cifar100 --architecture ResNet101 --epochs 600 --batch-size 128

    Además, puedes utilizar --show_versions para ver si existen más modelos entrenados bajo la misma configuración de dataset y arquitectura. Con --version {version} puedes continuar el entrenamiento de un modelo existente entregando su respectiva versión.

    Por último para ver las principales métricas de entrenamiento y validación, además de guardar el onnx del modelo, puedes utilizar

    python metrics.py --dataset cifar100 --architecture ResNet101 --select_version 0
  5. Entrenar la red aprendiz de dos formas:

    1. Entrenar red aprendiz como modelo base igual a la red tutora:

      python trainer.py --dataset cifar100 --architecture ResNet18 --epochs 600 --batch-size 128
    2. Entrenar red aprendiz con destilación de conocimiento, para esto, primero debe guardar el modelo onnx de la red tutora en la carpeta de pretrained_models (puede usar metrics.py y mover el archivo onnx desde el checkpoint del experimento a la carpeta pretrained_models). Luego, puede entrenar la red aprendiz de la siguiente manera:

      python destiller.py --dataset cifar100 --student_architecture ResNet18 --epochs 600 --batch-size 128 --distillation soft_targets --teacher_architecture ResNet101

      Igual que antes, puedes utilizar --show_versions para ver si existen más modelos entrenados bajo la misma configuración de dataset y arquitectura. Con --version {version} puedes continuar el entrenamiento de un modelo existente entregando su respectiva versión.

destillml's People

Contributors

m4thinking avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.