✅ JAX is Autograd and XLA, brought together for high-performance machine learning research.
✅ JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
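A minimal sketch of the two points above: NumPy-style array code plus automatic differentiation via `jax.grad` (the function `f` here is just an illustrative example, not part of the JAX API):

```python
import jax
import jax.numpy as jnp

# A plain scalar function written with jax.numpy, just like NumPy.
def f(x):
    return jnp.sin(x) * x ** 2

# jax.grad returns a new function computing df/dx.
df = jax.grad(f)

# Analytically, d/dx [x^2 sin x] = 2x sin x + x^2 cos x.
print(float(df(1.0)))
```

The same code runs unchanged on CPU, GPU, or TPU; JAX dispatches to whatever backend is available.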
✅ Flax is a neural network library and ecosystem for JAX designed for flexibility.
✅ Flax delivers an end-to-end, flexible user experience for researchers building neural networks with JAX.
- GitHub: https://github.com/deepmind/optax
✅ Optax is a gradient processing and optimization library for JAX.
✅ Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.