https://github.com/google/jax
- HEAD
- refs/heads/LenaMartens-patch-1
- refs/heads/add-tpu-core-count
- refs/heads/adx3
- refs/heads/apaszke-debug-gpu
- refs/heads/array_tutorial
- refs/heads/avals-with-names
- refs/heads/bazel_tpu
- refs/heads/cache_log
- refs/heads/calltf_test
- refs/heads/ci_v3-8
- refs/heads/cse
- refs/heads/debug-nans-test-block-async
- refs/heads/dependabot/pip/matplotlib-3.9.0
- refs/heads/dependabot/pip/pluggy-1.5.0
- refs/heads/dependabot/pip/pytest-xdist-3.6.1
- refs/heads/dependabot/pip/setuptools-70.0.0
- refs/heads/dependabot/pip/zipp-3.18.2
- refs/heads/disable_test
- refs/heads/dynamic-experiments
- refs/heads/dynamic-scoping
- refs/heads/effect-types
- refs/heads/fix
- refs/heads/fix-notebooks
- refs/heads/fix-notebooks2
- refs/heads/fix_cloud_tpu_check
- refs/heads/fixed-point
- refs/heads/furo
- refs/heads/gda1
- refs/heads/gda2
- refs/heads/gmap
- refs/heads/gnecula-patch-1
- refs/heads/gpu-determinism-note
- refs/heads/hmm-example
- refs/heads/hoist-consts
- refs/heads/ijit
- refs/heads/index
- refs/heads/initial-state
- refs/heads/initial-style-autodidax
- refs/heads/issue2263
- refs/heads/issue3040
- refs/heads/issue3285
- refs/heads/issue3620
- refs/heads/issue768
- refs/heads/jax-attrs
- refs/heads/jax_release_0.3.6
- refs/heads/jb/broadcast-tie-in
- refs/heads/jb/debug-strong-zeros
- refs/heads/jb/partial-eval-scan
- refs/heads/jb/refs
- refs/heads/jb/scan-consts
- refs/heads/jb/scan-consts-wip
- refs/heads/jb/staxperiments-wip
- refs/heads/jb/tagging
- refs/heads/jet
- refs/heads/jet2
- refs/heads/kokoro-fix
- refs/heads/layer-scan-remat
- refs/heads/libtpu_import_fix
- refs/heads/libtpu_install
- refs/heads/linearize
- refs/heads/lint-action-status
- refs/heads/log-resharding
- refs/heads/main
- refs/heads/make-custom-vjp-bwd-nones-more-robust
- refs/heads/masking-revisions
- refs/heads/maxfail
- refs/heads/mlir
- refs/heads/mutable-array-scan
- refs/heads/nanobind
- refs/heads/nightly
- refs/heads/no-more-post-process
- refs/heads/on_demand_ci_jax_gpu
- refs/heads/pargmax-primitive
- refs/heads/pdot
- refs/heads/pjrt_c_api
- refs/heads/pjrt_c_api_tests
- refs/heads/pjrt_tpu
- refs/heads/pmap
- refs/heads/precision-flag-name-updating
- refs/heads/pretty-print-improvements
- refs/heads/primitives
- refs/heads/print-partial-eval-stats
- refs/heads/print_refactor
- refs/heads/prngkey-linearity-brainstorming
- refs/heads/profiler_test
- refs/heads/pullrequest
- refs/heads/rafi
- refs/heads/random-docstring-fix
- refs/heads/ray
- refs/heads/refs-in-vjps
- refs/heads/rejames3
- refs/heads/restore_jax_remat_opt_barrier_4.2.0
- refs/heads/revive-leak-checker
- refs/heads/rvjp
- refs/heads/sc
- refs/heads/setupcfg
- refs/heads/sharadmv-patch-1
- refs/heads/sharadmv-patch-2
- refs/heads/sharding_docstrings
- refs/heads/shoyer-jit-stuff
- refs/heads/side-effects
- refs/heads/skye-patch-1
- refs/heads/skye-patch-2
- refs/heads/skye_tpu_ci_testing
- refs/heads/software-pipeline
- refs/heads/solvers
- refs/heads/source-line-info-experiments
- refs/heads/speed-up-hips
- refs/heads/stax2k
- refs/heads/staxperiments
- refs/heads/test_336130333
- refs/heads/test_351557292
- refs/heads/test_355704839
- refs/heads/test_357720302
- refs/heads/test_363455815
- refs/heads/test_372280020
- refs/heads/test_377534845
- refs/heads/test_381553014
- refs/heads/test_388304561
- refs/heads/test_388307100
- refs/heads/test_390018950
- refs/heads/test_394110078
- refs/heads/test_395690233
- refs/heads/test_397174680
- refs/heads/test_398845683
- refs/heads/test_401921218
- refs/heads/test_404746466
- refs/heads/test_407577871
- refs/heads/test_409589575
- refs/heads/test_410618045
- refs/heads/test_410822358
- refs/heads/test_412110232
- refs/heads/test_415522599
- refs/heads/test_417636490
- refs/heads/test_418424182
- refs/heads/test_419001739
- refs/heads/test_419670164
- refs/heads/test_420040086
- refs/heads/test_423358730
- refs/heads/test_423914979
- refs/heads/test_424101639
- refs/heads/test_428105412
- refs/heads/test_428681711
- refs/heads/test_429097675
- refs/heads/test_429404946
- refs/heads/test_429407854
- refs/heads/test_430363082
- refs/heads/test_431602697
- refs/heads/test_432560862
- refs/heads/test_433530047
- refs/heads/test_435398406
- refs/heads/test_435759524
- refs/heads/test_436391492
- refs/heads/test_437633657
- refs/heads/test_438033664
- refs/heads/test_440450701
- refs/heads/test_440927765
- refs/heads/test_441041293
- refs/heads/test_442162369
- refs/heads/test_442573037
- refs/heads/test_444674770
- refs/heads/test_445478827
- refs/heads/test_446530978
- refs/heads/test_447366636
- refs/heads/test_448898047
- refs/heads/test_449567602
- refs/heads/test_452087344
- refs/heads/test_454250510
- refs/heads/test_454708810
- refs/heads/test_456562272
- refs/heads/test_456931912
- refs/heads/test_457007257
- refs/heads/test_458532437
- refs/heads/test_459875360
- refs/heads/test_460144999
- refs/heads/test_460216807
- refs/heads/test_460531663
- refs/heads/test_460834259
- refs/heads/test_461075508
- refs/heads/test_461210795
- refs/heads/test_462219572
- refs/heads/test_464830257
- refs/heads/test_464993276
- refs/heads/test_465655514
- refs/heads/test_466488206
- refs/heads/test_467774521
- refs/heads/test_468028234
- refs/heads/test_469397862
- refs/heads/test_470094033
- refs/heads/test_471352829
- refs/heads/test_471795639
- refs/heads/test_472517963
- refs/heads/test_477322238
- refs/heads/test_478456770
- refs/heads/test_480202723
- refs/heads/test_480224686
- refs/heads/test_482313959
- refs/heads/test_482602398
- refs/heads/test_482900882
- refs/heads/test_482900887
- refs/heads/test_482900889
- refs/heads/test_482900892
- refs/heads/test_482966582
- refs/heads/test_484448150
- refs/heads/test_487929724
- refs/heads/test_489763439
- refs/heads/test_490629581
- refs/heads/test_490662809
- refs/heads/test_493164942
- refs/heads/test_493596333
- refs/heads/test_493952215
- refs/heads/test_496483150
- refs/heads/test_496643067
- refs/heads/test_497091885
- refs/heads/test_497225669
- refs/heads/test_497604432
- refs/heads/test_497854051
- refs/heads/test_498025158
- refs/heads/test_498264649
- refs/heads/test_498270669
- refs/heads/test_498435892
- refs/heads/test_500618944
- refs/heads/test_501135700
- refs/heads/test_501345828
- refs/heads/test_502156425
- refs/heads/test_502769787
- refs/heads/test_502853874
- refs/heads/test_503474380
- refs/heads/test_505159976
- refs/heads/test_505998738
- refs/heads/test_507395004
- refs/heads/test_508168519
- refs/heads/test_508168521
- refs/heads/test_508168522
- refs/heads/test_508464286
- refs/heads/test_508726173
- refs/heads/test_509261214
- refs/heads/test_509637208
- refs/heads/test_512183375
- refs/heads/test_512192439
- refs/heads/test_514760173
- refs/heads/test_515135429
- refs/heads/test_515216183
- refs/heads/test_515494408
- refs/heads/test_516810517
- refs/heads/test_516810518
- refs/heads/test_516810520
- refs/heads/test_516810523
- refs/heads/test_517283818
- refs/heads/test_517772432
- refs/heads/test_518102665
- refs/heads/test_518506421
- refs/heads/test_518850240
- refs/heads/test_518859003
- refs/heads/test_518870432
- refs/heads/test_520744762
- refs/heads/test_522180444
- refs/heads/test_522435255
- refs/heads/test_523544053
- refs/heads/test_526822983
- refs/heads/test_527801553
- refs/heads/test_530376469
- refs/heads/test_530618498
- refs/heads/test_533228031
- refs/heads/test_536873968
- refs/heads/test_539130561
- refs/heads/test_541092273
- refs/heads/test_541402440
- refs/heads/test_541402556
- refs/heads/test_542362963
- refs/heads/test_545420249
- refs/heads/test_545655471
- refs/heads/test_547626617
- refs/heads/test_547626618
- refs/heads/test_547906057
- refs/heads/test_547924761
- refs/heads/test_548696078
- refs/heads/test_548696079
- refs/heads/test_549523682
- refs/heads/test_551644794
- refs/heads/test_552737402
- refs/heads/test_552949448
- refs/heads/test_552950630
- refs/heads/test_553240003
- refs/heads/test_553454659
- refs/heads/test_554489856
- refs/heads/test_554506940
- refs/heads/test_554515251
- refs/heads/test_555009598
- refs/heads/test_555362945
- refs/heads/test_557405952
- refs/heads/test_557606593
- refs/heads/test_557639059
- refs/heads/test_559223667
- refs/heads/test_559235820
- refs/heads/test_559282859
- refs/heads/test_559537321
- refs/heads/test_559832084
- refs/heads/test_560102989
- refs/heads/test_562875610
- refs/heads/test_563106152
- refs/heads/test_563852473
- refs/heads/test_564778874
- refs/heads/test_564832306
- refs/heads/test_565394366
- refs/heads/test_565598091
- refs/heads/test_566338492
- refs/heads/test_566829535
- refs/heads/test_567743154
- refs/heads/test_567883345
- refs/heads/test_568984719
- refs/heads/test_569230764
- refs/heads/test_570062727
- refs/heads/test_570610183
- refs/heads/test_573316975
- refs/heads/test_573862211
- refs/heads/test_573943286
- refs/heads/test_575248203
- refs/heads/test_577461132
- refs/heads/test_578841721
- refs/heads/test_578847678
- refs/heads/test_578946432
- refs/heads/test_579940420
- refs/heads/test_581954801
- refs/heads/test_584888085
- refs/heads/test_587518619
- refs/heads/test_587806540
- refs/heads/test_588037215
- refs/heads/test_588041819
- refs/heads/test_589017719
- refs/heads/test_589018175
- refs/heads/test_589025336
- refs/heads/test_591066347
- refs/heads/test_591416564
- refs/heads/test_592754615
- refs/heads/test_592998231
- refs/heads/test_595845422
- refs/heads/test_595886400
- refs/heads/test_595897224
- refs/heads/test_597341374
- refs/heads/test_597350925
- refs/heads/test_597796534
- refs/heads/test_598509568
- refs/heads/test_606813605
- refs/heads/test_607616480
- refs/heads/test_607748347
- refs/heads/test_607752766
- refs/heads/test_607900615
- refs/heads/test_608638040
- refs/heads/test_610593276
- refs/heads/test_611694090
- refs/heads/test_611798938
- refs/heads/test_613248331
- refs/heads/test_613354286
- refs/heads/test_613369789
- refs/heads/test_613383623
- refs/heads/test_613418799
- refs/heads/test_613719003
- refs/heads/test_614048400
- refs/heads/test_614361032
- refs/heads/test_614378644
- refs/heads/test_614684199
- refs/heads/test_616247522
- refs/heads/test_616865795
- refs/heads/test_616869664
- refs/heads/test_616889174
- refs/heads/test_617800884
- refs/heads/test_617824989
- refs/heads/test_617862518
- refs/heads/test_618058831
- refs/heads/test_618721726
- refs/heads/test_618824221
- refs/heads/test_619414418
- refs/heads/test_619787143
- refs/heads/test_620111786
- refs/heads/test_620893492
- refs/heads/test_622817772
- refs/heads/test_622843729
- refs/heads/test_626392371
- refs/heads/test_627162247
- refs/heads/test_628066252
- refs/heads/test_628538605
- refs/heads/test_629379237
- refs/heads/test_629497642
- refs/heads/test_629530367
- refs/heads/test_630342579
- refs/heads/test_631169957
- refs/heads/test_631422730
- refs/heads/test_631488763
- refs/heads/test_631594796
- refs/heads/test_632221941
- refs/heads/test_632526631
- refs/heads/test_632583965
- refs/heads/test_632679634
- refs/heads/test_633190586
- refs/heads/test_633711385
- refs/heads/test_634046352
- refs/heads/test_634071280
- refs/heads/test_634465693
- refs/heads/test_634532880
- refs/heads/test_634604111
- refs/heads/test_634604461
- refs/heads/test_635414459
- refs/heads/test_635474730
- refs/heads/test_635560292
- refs/heads/test_635567131
- refs/heads/test_635618263
- refs/heads/test_635691001
- refs/heads/test_635818597
- refs/heads/test_635965009
- refs/heads/test_635986104
- refs/heads/test_636053242
- refs/heads/test_636160192
- refs/heads/test_636172317
- refs/heads/test_fix
- refs/heads/test_no_log_spam
- refs/heads/test_timeout_tpu
- refs/heads/timeout
- refs/heads/tpu-ext-macos
- refs/heads/tpu_build_from_head
- refs/heads/tpu_ci_disable_tcmalloc
- refs/heads/tpu_ci_fix_compat_build
- refs/heads/tpu_ci_pjrt
- refs/heads/tpu_ci_remove_workaround
- refs/heads/tpu_ci_strings
- refs/heads/tpu_py_version
- refs/heads/tpu_py_version2
- refs/heads/tweaks
- refs/heads/typo
- refs/heads/undo-tree
- refs/heads/update-pypi
- refs/heads/variance_scaling_axes
- refs/heads/yashk2810-patch-1
- refs/heads/yashk2810-patch-10
- refs/heads/yashk2810-patch-11
- refs/heads/yashk2810-patch-12
- refs/heads/yashk2810-patch-13
- refs/heads/yashk2810-patch-14
- refs/heads/yashk2810-patch-15
- refs/heads/yashk2810-patch-16
- refs/heads/yashk2810-patch-17
- refs/heads/yashk2810-patch-18
- refs/heads/yashk2810-patch-19
- refs/heads/yashk2810-patch-2
- refs/heads/yashk2810-patch-20
- refs/heads/yashk2810-patch-3
- refs/heads/yashk2810-patch-4
- refs/heads/yashk2810-patch-5
- refs/heads/yashk2810-patch-6
- refs/heads/yashk2810-patch-7
- refs/heads/yashk2810-patch-8
- refs/heads/yashk2810-patch-9
- refs/remotes/upstream/rename-mastertrace
- refs/tags/0.3.20
- refs/tags/0.3.21
- refs/tags/0.3.22
- refs/tags/0.3.23
- refs/tags/0.3.24
- refs/tags/0.3.25
- refs/tags/jax-v0.1.49
- refs/tags/jax-v0.1.55
- refs/tags/jax-v0.1.58
- refs/tags/jax-v0.1.70
- refs/tags/jax-v0.2.20
- refs/tags/jax-v0.2.21
- refs/tags/jax-v0.2.23
- refs/tags/jax-v0.2.24
- refs/tags/jax-v0.2.26
- refs/tags/jax-v0.3.0
- refs/tags/jax-v0.3.1
- refs/tags/jax-v0.3.10
- refs/tags/jax-v0.3.11
- refs/tags/jax-v0.3.12
- refs/tags/jax-v0.3.21-rc
- refs/tags/jax-v0.3.6
- refs/tags/jax-v0.3.9
- refs/tags/jaxlib-v0.1.32
- refs/tags/jaxlib-v0.1.33
- refs/tags/jaxlib-v0.1.34
- refs/tags/jaxlib-v0.1.35
- refs/tags/jaxlib-v0.1.36
- refs/tags/jaxlib-v0.1.37
- refs/tags/jaxlib-v0.1.38
- refs/tags/jaxlib-v0.1.39
- refs/tags/jaxlib-v0.1.40
- refs/tags/jaxlib-v0.1.42
- refs/tags/jaxlib-v0.1.43
- refs/tags/jaxlib-v0.1.44
- refs/tags/jaxlib-v0.1.45
- refs/tags/jaxlib-v0.1.46
- refs/tags/jaxlib-v0.1.47
- refs/tags/jaxlib-v0.1.50
- refs/tags/jaxlib-v0.1.51
- refs/tags/jaxlib-v0.1.55
- refs/tags/jaxlib-v0.1.70
- refs/tags/jaxlib-v0.1.71
- refs/tags/jaxlib-v0.1.72
- refs/tags/jaxlib-v0.1.73
- refs/tags/jaxlib-v0.1.75
- refs/tags/jaxlib-v0.1.76
- refs/tags/jaxlib-v0.3.0
- refs/tags/jaxlib-v0.3.10
- refs/tags/test-jaxlib
- refs/tags/workflow_test
- test-docs
- jaxlib-v0.4.9
- jaxlib-v0.4.7
- jaxlib-v0.4.6
- jaxlib-v0.4.4
- jaxlib-v0.4.3
- jaxlib-v0.4.28
- jaxlib-v0.4.27
- jaxlib-v0.4.26
- jaxlib-v0.4.25
- jaxlib-v0.4.24
- jaxlib-v0.4.23
- jaxlib-v0.4.22
- jaxlib-v0.4.21
- jaxlib-v0.4.20
- jaxlib-v0.4.2
- jaxlib-v0.4.19
- jaxlib-v0.4.18
- jaxlib-v0.4.17
- jaxlib-v0.4.16
- jaxlib-v0.4.15
- jaxlib-v0.4.14
- jaxlib-v0.4.13
- jaxlib-v0.4.12
- jaxlib-v0.4.11
- jaxlib-v0.4.10
- jaxlib-v0.4.1
- jaxlib-v0.4.0
- jaxlib-v0.3.7
- jaxlib-v0.3.5
- jaxlib-v0.3.25
- jaxlib-v0.3.24
- jaxlib-v0.3.22
- jaxlib-v0.3.20
- jaxlib-v0.3.2
- jaxlib-v0.3.15
- jaxlib-v0.3.14
- jaxlib-v0.1.74
- jaxlib-v0.1.69
- jaxlib-v0.1.68
- jaxlib-v0.1.67
- jaxlib-v0.1.66
- jaxlib-v0.1.65
- jaxlib-v0.1.64
- jaxlib-v0.1.63
- jaxlib-v0.1.62
- jaxlib-v0.1.61
- jaxlib-v0.1.60
- jaxlib-v0.1.59
- jaxlib-v0.1.58
- jaxlib-v0.1.57
- jaxlib-v0.1.56
- jaxlib-v0.1.52
- jaxlib-v0.1.49
- jaxlib-v0.1.48
- jax-v0.4.9-rc
- jax-v0.4.9
- jax-v0.4.8
- jax-v0.4.7-rc1
- jax-v0.4.7-rc
- jax-v0.4.7
- jax-v0.4.6-rc
- jax-v0.4.6
- jax-v0.4.5
- jax-v0.4.4-rc
- jax-v0.4.4
- jax-v0.4.3-rc
- jax-v0.4.3
- jax-v0.4.28-rc
- jax-v0.4.28
- jax-v0.4.27-rc
- jax-v0.4.27
- jax-v0.4.26-rc
- jax-v0.4.26
- jax-v0.4.25-rc
- jax-v0.4.25
- jax-v0.4.24-rc
- jax-v0.4.24
- jax-v0.4.23-rc
- jax-v0.4.23
- jax-v0.4.22-rc2
- jax-v0.4.22-rc
- jax-v0.4.22
- jax-v0.4.21-rc
- jax-v0.4.21
- jax-v0.4.20-rc
- jax-v0.4.20
- jax-v0.4.2-rc
- jax-v0.4.2
- jax-v0.4.19-rc
- jax-v0.4.19
- jax-v0.4.18-rc
- jax-v0.4.18
- jax-v0.4.17-rc
- jax-v0.4.17
- jax-v0.4.16-rc1
- jax-v0.4.16-rc
- jax-v0.4.16
- jax-v0.4.15-rc
- jax-v0.4.15
- jax-v0.4.14-rc
- jax-v0.4.14
- jax-v0.4.13-rc
- jax-v0.4.13
- jax-v0.4.12-rc
- jax-v0.4.12
- jax-v0.4.11-rc
- jax-v0.4.11
- jax-v0.4.10-rc
- jax-v0.4.10
- jax-v0.4.1-rc
- jax-v0.4.1
- jax-v0.4.0-rc1
- jax-v0.4.0-rc
- jax-v0.4.0
- jax-v0.3.8
- jax-v0.3.7
- jax-v0.3.5
- jax-v0.3.4
- jax-v0.3.3
- jax-v0.3.25-rc3
- jax-v0.3.25-rc2
- jax-v0.3.25-rc1
- jax-v0.3.25-rc
- jax-v0.3.25
- jax-v0.3.24-rc1
- jax-v0.3.24-rc
- jax-v0.3.24
- jax-v0.3.23-rc
- jax-v0.3.23
- jax-v0.3.22-rc
- jax-v0.3.22
- jax-v0.3.21
- jax-v0.3.20-rc
- jax-v0.3.20
- jax-v0.3.2
- jax-v0.3.19
- jax-v0.3.18-rc
- jax-v0.3.18
- jax-v0.3.17
- jax-v0.3.16
- jax-v0.3.15-rc
- jax-v0.3.15
- jax-v0.3.14
- jax-v0.3.13
- jax-v0.2.9
- jax-v0.2.8
- jax-v0.2.7
- jax-v0.2.6
- jax-v0.2.5
- jax-v0.2.4
- jax-v0.2.3
- jax-v0.2.28
- jax-v0.2.27
- jax-v0.2.25
- jax-v0.2.22
- jax-v0.2.2
- jax-v0.2.19
- jax-v0.2.18
- jax-v0.2.17
- jax-v0.2.16
- jax-v0.2.15
- jax-v0.2.14
- jax-v0.2.13
- jax-v0.2.12
- jax-v0.2.11
- jax-v0.2.10
- jax-v0.2.1
- jax-v0.2.0
- jax-v0.1.77
- jax-v0.1.76
- jax-v0.1.75
- jax-v0.1.74
- jax-v0.1.73
- jax-v0.1.72
- jax-v0.1.71
- jax-v0.1.69
- jax-v0.1.68
- jax-v0.1.67
- jax-v0.1.66
- jax-v0.1.65
- jax-v0.1.64
- jax-v0.1.63
- jax-v0.1.62
- jax-v0.1.61
- jax-v0.1.60
- jax-v0.1.59
Raw File
Take a new snapshot of a software origin
If the archived software origin currently browsed is not synchronized with its upstream version (for instance when new commits have been issued), you can explicitly request Software Heritage to take a new snapshot of it.
Use the form below to proceed. Once a request has been submitted and accepted, it will be processed as soon as possible. You can then check its processing state by visiting this dedicated page. ![swh spinner](/static/img/swh-spinner.gif)
Processing "take a new snapshot" request ...
Permalinks
To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.
Tip revision: 58e46b48e6641d72f8a236485b4af6d39a497351 authored by Yash Katariya on 16 February 2023, 16:35:21 UTC
Prepare for jax and jaxlib 0.4.4 release
Prepare for jax and jaxlib 0.4.4 release
Tip revision: 58e46b4
mnist_classifier.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
The mini-library jax.example_libraries.stax is for neural network building, and
the mini-library jax.example_libraries.optimizers is for first-order stochastic
optimization.
"""
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from examples import datasets
def loss(params, batch):
    """Return the negated mean over examples of sum(predict(...) * targets).

    Args:
      params: network parameters accepted by `predict`.
      batch: an `(inputs, targets)` pair; targets are presumably one-hot
        label rows (verify against the dataset loader).

    Returns:
      A scalar. Because the network ends in a LogSoftmax layer, this is the
      average negative log-likelihood when targets are one-hot.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    # Weight each class score by its target entry, then reduce per example.
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
def accuracy(params, batch):
    """Return the fraction of examples whose arg-max prediction matches the target.

    Args:
      params: network parameters accepted by `predict`.
      batch: an `(inputs, targets)` pair; the target class of each row is
        taken to be the arg-max along axis 1.

    Returns:
      The mean of the element-wise equality between predicted and target
      class indices.
    """
    inputs, targets = batch
    scores = predict(params, inputs)
    guessed = jnp.argmax(scores, axis=1)
    expected = jnp.argmax(targets, axis=1)
    return jnp.mean(guessed == expected)
# Fully-connected classifier: two 1024-unit ReLU hidden layers followed by a
# 10-way log-softmax output layer.
_layers = (
    Dense(1024),
    Relu,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
init_random_params, predict = stax.serial(*_layers)
if __name__ == "__main__":
    # Settings for the momentum optimizer and the training schedule.
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    # A trailing, shorter batch covers examples that do not fill a full batch.
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + (1 if leftover else 0)

    def data_stream():
        """Yield (images, labels) minibatches forever, reshuffling each pass."""
        shuffler = npr.RandomState(0)
        while True:
            order = shuffler.permutation(num_train)
            for start in range(0, num_batches * batch_size, batch_size):
                chosen = order[start:start + batch_size]
                yield train_images[chosen], train_labels[chosen]

    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        """Take one optimizer step on `batch` and return the new state."""
        current = get_params(opt_state)
        grads = grad(loss)(current, batch)
        return opt_update(i, grads, opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    batches = data_stream()
    # Global step counter shared across epochs; passed to the optimizer.
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            step = next(itercount)
            opt_state = update(step, opt_state, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate on the full train and test sets after every epoch.
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        print(f"Test set accuracy {test_acc}")