[inventory_dynamics] Update code to JAX and latest style guide#623
[inventory_dynamics] Update code to JAX and latest style guide#623
Conversation
|
📖 Netlify Preview Ready! Preview URL: https://pr-623--sunny-cactus-210e3e.netlify.app (2b7c4fe) 📚 Changed Lecture Pages: inventory_dynamics |
There was a problem hiding this comment.
Pull Request Overview
This PR updates the inventory dynamics lecture to use JAX instead of numba and numpy, modernizing the codebase and improving performance. The changes align with the latest style guide and fix minor formatting issues.
Key changes:
- Replaces numba-based
Firmclass with JAX-based functions usingNamedTuple - Converts simulation functions to use JAX's functional programming paradigm with
jax.lax.scanandjax.vmap - Updates all numpy operations to JAX numpy equivalents
|
📖 Netlify Preview Ready! Preview URL: https://pr-623--sunny-cactus-210e3e.netlify.app (54d1cac) 📚 Changed Lecture Pages: inventory_dynamics |
HumphreyYang
left a comment
There was a problem hiding this comment.
Many thanks @kp992! These are really nice! Just two very minor observations.
I think we should remove the grid in the figures to maintain consistency in the style:
for ax in axes:
ax.grid(alpha=0.4)
| # Generate independent random keys for each firm | ||
| firm_keys = random.split(key, (num_firms, sim_length)) | ||
| # Run simulation for all firms | ||
| restock_indicators = vectorized_simulate(firm_keys) | ||
| # Compute frequency (fraction of firms that restocked > 1 times) | ||
| frequency = jnp.mean(restock_indicators) | ||
| return frequency |
There was a problem hiding this comment.
I think it might look better if we add some spaces in between:
| # Generate independent random keys for each firm | |
| firm_keys = random.split(key, (num_firms, sim_length)) | |
| # Run simulation for all firms | |
| restock_indicators = vectorized_simulate(firm_keys) | |
| # Compute frequency (fraction of firms that restocked > 1 times) | |
| frequency = jnp.mean(restock_indicators) | |
| return frequency | |
| # Generate independent random keys for each firm | |
| firm_keys = random.split(key, (num_firms, sim_length)) | |
| # Run simulation for all firms | |
| restock_indicators = vectorized_simulate(firm_keys) | |
| # Compute frequency (fraction of firms that restocked > 1 times) | |
| frequency = jnp.mean(restock_indicators) | |
| return frequency |
This PR: