Welcome to the JAX & FLAX Demos repository! This collection of notebooks is designed to take you on a journey through the powerful capabilities of JAX, from basic linear regression to distributed neural network training.
Whether you're a seasoned researcher or a curious learner, dive in to see how JAX's composable transformations can supercharge your machine learning workflows! ⚡
File: Linear_Regression___JAX.ipynb
Start here! This notebook covers the fundamentals:
- Data Generation: Creating synthetic data for regression.
- Model Definition: Building a simple linear model $y = wx + b$.
- JAX Magic: Using `jax.grad` for automatic differentiation and `jax.jit` for Just-In-Time compilation.
- Training Loop: A manual training loop to optimize parameters using gradient descent.
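The steps above can be sketched in a few lines. This is a minimal illustration of the workflow, not the notebook's exact code; the variable names, learning rate, and step count are illustrative:

```python
import jax
import jax.numpy as jnp

def model(params, x):
    w, b = params
    return w * x + b

def loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((model(params, x) - y) ** 2)

# Synthetic data from the true line y = 3x + 1.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100,))
y = 3.0 * x + 1.0

# jax.grad gives the gradient function; jax.jit compiles it with XLA.
grad_fn = jax.jit(jax.grad(loss))

params = (0.0, 0.0)  # (w, b)
for _ in range(200):
    g = grad_fn(params, x, y)
    params = tuple(p - 0.1 * gp for p, gp in zip(params, g))
# params is now close to (3.0, 1.0)
```

Because `loss` is pure, `grad` and `jit` compose freely; the same pattern scales to the deeper models in the later notebooks.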
File: MLP_for_MNIST___JAX.ipynb
Level up to deep learning with the classic MNIST digit classification task:
- Neural Network: Implementing a Multi-Layer Perceptron (MLP) from scratch.
- Vectorization: Leveraging `jax.vmap` for efficient batched predictions without manual looping.
- PyTree Handling: Managing model parameters (weights & biases) as PyTrees.
- Training: Optimized training loop with accuracy tracking.
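The core `vmap` idea looks roughly like this: write `predict` for a single example, then let `jax.vmap` add the batch dimension. A hedged sketch with illustrative layer sizes (784-128-10, matching MNIST's flattened images), not the notebook's exact implementation:

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # Parameters as a PyTree: a list of (weights, biases) tuples.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, wkey = jax.random.split(key)
        params.append((jax.random.normal(wkey, (n_out, n_in)) * 0.1,
                       jnp.zeros(n_out)))
    return params

def predict(params, image):
    # Written for a SINGLE flattened image; no batch dimension here.
    x = image
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w, b = params[-1]
    return w @ x + b

params = init_params(jax.random.PRNGKey(0), [784, 128, 10])
batch = jnp.ones((32, 784))

# vmap maps over axis 0 of the images while broadcasting the params.
logits = jax.vmap(predict, in_axes=(None, 0))(params, batch)
# logits has shape (32, 10): one row of class scores per image.
```

`in_axes=(None, 0)` is the key detail: the parameter PyTree is shared across the batch while the image argument is mapped over its leading axis.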
File: Distributed_Neural_Network_Training___JAX.ipynb
Unlock the power of parallelism! This advanced notebook demonstrates:
- Parallelism (pmap): Using `jax.pmap` to distribute computation across multiple devices (TPUs/GPUs).
- Data Sharding: Splitting data efficiently for data-parallel training.
- Gradient Synchronization: Syncing gradients across devices using `jax.lax.pmean`.
- Advanced Autodiff: Cool tricks like freezing layers with `jax.lax.stop_gradient` and computing per-sample gradients.
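A minimal sketch of the data-parallel pattern, assuming a toy scalar model so the mechanics stay visible (the function names and learning rate are illustrative). It runs on however many devices are available, including a single CPU:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x * w - y) ** 2)

def step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients across devices so every replica applies
    # the same update and the replicas stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads

p_step = jax.pmap(step, axis_name="devices")

n = jax.local_device_count()
# Replicate the parameter and shard the data along the leading axis:
# each device sees one slice of size 4.
w = jnp.ones((n,))
x = jnp.arange(float(n * 4)).reshape(n, 4)
y = 2.0 * x

new_w = p_step(w, x, y)  # one (identical) updated copy per device
```

The leading axis of every input to a `pmap`-ed function must equal the device count; `jax.lax.pmean` is what turns independent per-device gradients into a synchronized update.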
JAX is not just another Deep Learning framework; it's NumPy on steroids! 💪
- Autograd: Differentiate through native Python and NumPy code.
- XLA Compilation: `jit` compiles your code to XLA (Accelerated Linear Algebra) for blazing speed.
- Vectorization: `vmap` automatically vectorizes your functions.
- Parallelism: `pmap` makes multi-device training a breeze.
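The real payoff is that these transformations compose. For example, nesting `grad` inside `vmap` inside `jit` gives compiled per-example gradients in one line (a toy function, chosen only to show the composition):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum(jnp.tanh(w * x))

# grad differentiates w.r.t. w, vmap maps over the examples in x
# (sharing w), and jit compiles the whole pipeline.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

xs = jnp.linspace(-1.0, 1.0, 8)
g = per_example_grads(1.0, xs)  # one gradient per input example
# g has shape (8,)
```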
To run these notebooks, you'll need a Python environment with JAX installed.
```bash
pip install jax jaxlib matplotlib tensorflow-datasets numpy
```

(Note: For GPU/TPU support, please refer to the official JAX installation guide.)
- Clone this repository.
- Launch Jupyter Notebook or JupyterLab.
- Open any of the `.ipynb` files and run the cells!
Got a cool JAX trick or a new demo idea? Contributions are welcome! Feel free to open an issue or submit a pull request.
Happy JAX-ing! 🧪✨