
πŸš€ JAX & FLAX Demos

Welcome to the JAX & FLAX Demos repository! This collection of notebooks is designed to take you on a journey through the powerful capabilities of JAX, from basic linear regression to distributed neural network training.

Whether you're a seasoned researcher or a curious learner, dive in to see how JAX's composable transformations can supercharge your machine learning workflows! ⚑


πŸ“š What's Inside?

1. πŸ“ˆ Linear Regression with JAX

File: Linear_Regression___JAX.ipynb

Start here! This notebook covers the fundamentals:

  • Data Generation: Creating synthetic data for regression.
  • Model Definition: Building a simple linear model $y = wx + b$.
  • JAX Magic: Using jax.grad for automatic differentiation and jax.jit for Just-In-Time compilation.
  • Training Loop: A manual training loop to optimize parameters using gradient descent.
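The pieces above compose into a very short program. A minimal sketch (the data and true parameters here are illustrative, not taken from the notebook):

```python
import jax
import jax.numpy as jnp

# Noiseless synthetic data for y = 3x + 2 (illustrative true parameters).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100,))
y = 3.0 * x + 2.0

def loss_fn(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates the loss; jax.jit compiles the result with XLA.
grad_fn = jax.jit(jax.grad(loss_fn))

params = (0.0, 0.0)   # (w, b) -- a tuple is already a valid PyTree
lr = 0.1
for _ in range(200):
    g = grad_fn(params, x, y)
    params = tuple(p - lr * gi for p, gi in zip(params, g))
# params is now close to (3.0, 2.0)
```

Because `loss_fn` takes its parameters as an ordinary tuple, `jax.grad` returns the gradients in the same tuple structure, which keeps the manual update loop a one-liner.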

2. πŸ”’ MLP for MNIST

File: MLP_for_MNIST___JAX.ipynb

Level up to deep learning with the classic MNIST digit classification task:

  • Neural Network: Implementing a Multi-Layer Perceptron (MLP) from scratch.
  • Vectorization: Leveraging jax.vmap for efficient batched predictions without manual looping.
  • PyTree Handling: Managing model parameters (weights & biases) as PyTrees.
  • Training: Optimized training loop with accuracy tracking.
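The `vmap` + PyTree combination can be sketched in a few lines. Layer sizes and helper names below are illustrative, not the notebook's own:

```python
import jax
import jax.numpy as jnp

# MLP parameters as a PyTree: a list of (weights, biases) tuples.
def init_mlp(key, sizes):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, wk = jax.random.split(key)
        params.append((jax.random.normal(wk, (n_in, n_out)) * 0.1,
                       jnp.zeros(n_out)))
    return params

def predict(params, x):           # x: ONE flattened image, shape (784,)
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b              # logits

# vmap maps predict over the batch axis of x while sharing params (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = init_mlp(jax.random.PRNGKey(0), [784, 128, 10])
batch = jnp.ones((32, 784))
logits = batched_predict(params, batch)   # shape (32, 10)
```

Note that `predict` is written for a single example; `vmap` supplies the batching, so there is no manual loop and no hand-written batch dimension in the model code.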

3. 🌐 Distributed Neural Network Training

File: Distributed_Neural_Network_Training___JAX.ipynb

Unlock the power of parallelism! This advanced notebook demonstrates:

  • Parallelism (pmap): Using jax.pmap to distribute computation across multiple devices (TPUs/GPUs).
  • Data Sharding: Splitting data efficiently for data-parallel training.
  • Gradient Synchronization: Syncing gradients across devices using jax.lax.pmean.
  • Advanced Autodiff: Cool tricks like freezing layers with stop_gradient and computing per-sample gradients.
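The `pmap` + `pmean` pattern can be sketched on a toy one-parameter model. This is a minimal sketch, not the notebook's code; on a plain CPU `jax.local_device_count()` is 1, so the same code runs unchanged but without real parallelism:

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()   # >1 only on multi-GPU/TPU hosts

def loss_fn(w, x, y):
    return jnp.mean((x * w - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    g = jax.grad(loss_fn)(w, x, y)
    g = jax.lax.pmean(g, axis_name="devices")  # average grads across devices
    return w - 0.1 * g                         # replicas stay in sync

# Shard the data: the leading axis must equal the device count.
x = jnp.linspace(0.0, 1.0, 8 * n_dev).reshape(n_dev, 8)
y = 2.0 * x                        # true slope is 2.0
w = jnp.zeros(n_dev)               # one replica of the parameter per device

for _ in range(200):
    w = train_step(w, x, y)
# every replica of w has converged toward 2.0
```

The key invariant: each device computes gradients on its own shard, `pmean` averages them over the named axis, and because every replica applies the same averaged gradient, the parameter copies never drift apart.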

⚑ Why JAX?

JAX is not just another deep-learning framework; it's NumPy on steroids! πŸ’ͺ

  • Autograd: Differentiate through native Python and NumPy code.
  • XLA Compilation: jit compiles your code to XLA (Accelerated Linear Algebra) for blazing speed.
  • Vectorization: vmap automatically vectorizes your functions.
  • Parallelism: pmap makes multi-device training a breeze.
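The real magic is that these transformations compose. A tiny sketch (the function is arbitrary, chosen only to make the result easy to check):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2      # scalar function of a scalar

df = jax.grad(f)                   # its derivative: 2*sin(x)*cos(x) = sin(2x)
df_batched = jax.jit(jax.vmap(df)) # vectorized AND compiled, still one line

xs = jnp.linspace(0.0, 1.0, 5)
grads = df_batched(xs)             # elementwise derivatives over the batch
```

Nothing about `f` had to change: `grad`, `vmap`, and `jit` are plain function transformations, so stacking them is just function composition.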

πŸ› οΈ Getting Started

To run these notebooks, you'll need a Python environment with JAX installed.

Installation

pip install jax jaxlib matplotlib tensorflow-datasets numpy

(Note: For GPU/TPU support, please refer to the official JAX installation guide.)

Usage

  1. Clone this repository.
  2. Launch Jupyter Notebook or JupyterLab.
  3. Open any of the .ipynb files and run the cells!

🀝 Contributing

Got a cool JAX trick or a new demo idea? Contributions are welcome! Feel free to open an issue or submit a pull request.

Happy JAX-ing! πŸ§ͺ✨
