Google Jax puede superar a Numpy en la investigación del machine learning

UN

La diferenciación automática puede marcar una gran diferencia en el éxito de un modelo de aprendizaje profundo, reduciendo el tiempo de desarrollo para iterar modelos y experimentos.

Anteriormente, los programadores tenían que diseñar sus propios gradientes, lo que hacía que el modelo fuera vulnerable a errores, además de consumir mucho tiempo, lo que con el tiempo resultó ser desastroso. Para realizar un seguimiento de los gradientes en una red neuronal, se utilizan bibliotecas como TensorFlow y PyTorch . Estas bibliotecas ayudan a desarrollar funcionalidades normales, pero para desarrollar un modelo que está fuera de su alcance, no son suficientes. Autograd es una biblioteca destinada a la diferenciación automática de Python nativoy código NumPy.


Banner_frasco-suscripcion-800x250

JAX, la versión improvisada de Autograd, puede combinar aceleración de hardware y diferenciación automática con XLA (álgebra lineal acelerada), un compilador específico de dominio para álgebra lineal que puede acelerar los modelos de flujo de tensor. En resumen, JAX es una biblioteca de Python con funciones NumPy que se utilizan para eliminar operaciones triviales de aprendizaje automático.

¿Por qué debería usar JAX?

Además de admitir la diferenciación automática, que principalmente tiene el fuerte para el aprendizaje profundo, JAX puede mejorar enormemente la velocidad, una funcionalidad única por la que JAX es apreciado por muchos desarrolladores.

Como las operaciones de JAX se basan en XLA, es posible compilar a una velocidad más rápida de lo normal, es decir, alrededor de 7,3 veces más rápido con un entrenamiento normal y una velocidad acelerada de 12 veces a largo plazo. La función de compilación JIT (Just In Time) de JAX ayuda aún más a mejorar su velocidad al agregar un decorador de funciones simple. JAX ayuda enormemente a los desarrolladores a reducir la redundancia a través de la vectorización.

El proceso de aprendizaje automático

Comprende varias iteraciones, en las que se utiliza una sola función para modelar una gran cantidad de conjuntos de datos. La vectorización automática que ofrece JAX a través de la transformación vmap permite el paralelismo de datos mediante la transformación pmap. Por lo general, JAX se considera un marco alternativo de aprendizaje profundo, sin embargo, sus aplicaciones van más allá de las funcionalidades de la biblioteca. Lino, haiku, y elegy son las bibliotecas que se construyen sobre JAX para procesos de aprendizaje profundo. En particular, las hessianas realizan una optimización de orden superior, en la que JAX es bueno computando, todo gracias a XLA.

JXA contra Numpy:

Como JXA es altamente compatible con GPU, tiene compatibilidad inherente con CPU, a diferencia de Numpy, que solo es compatible con CPU. JAX tiene una API similar a Numpy, por lo que puede compilar automáticamente el código directamente en aceleradores como GPU y TPU, lo que hace que el proceso sea perfecto. Esto significa que un código escrito en la sintaxis de Numpy se puede ejecutar tanto en CPU como en GPU sin fallas. A pesar de tener construcciones especializadas, JAX se encuentra en un nivel más bajo con un nivel de control más bajo que el aprendizaje profundo, lo que lo convierte en un reemplazo perfecto para NumPy y, debido a su estructura básica, puede usarse para todo tipo de desarrollo además del aprendizaje profundo. . Con todo, JAX puede considerarse una versión aumentada de Numpy para realizar las funciones antes mencionadas, con la versión numpy de JAX dirigida como Jax.numPy, y JAX es casi numpy excepto que JAX puede ejecutar código en aceleradores.

Conoce más aquí

Banner_azules
Reciba las últimas noticias de la industria en su casilla:

Suscribirse ✉