https://github.com/google/jax

Revision Message Commit Date
535eb08 Flax layers are parametrizable with custom conv_general_dilated. PiperOrigin-RevId: 509637208 14 February 2023, 22:27:28 UTC
59e9746 Fix issue where HLO could not be generated for custom gradient. It appears that the custom gradient function must be traced in the same context as the context in which it was defined. Fixed by shuffling around the default graphs. PiperOrigin-RevId: 509618802 14 February 2023, 21:22:30 UTC
a9ef989 Merge pull request #14472 from nouiz:shmap_jep_fixes PiperOrigin-RevId: 509617771 14 February 2023, 21:14:33 UTC
5860cfd Merge pull request #14453 from jakevdp:dtypes-doc PiperOrigin-RevId: 509610755 14 February 2023, 20:48:40 UTC
33bed1e Opt into higher matmul precision for A100 and TPU tests. PiperOrigin-RevId: 509598465 14 February 2023, 20:03:12 UTC
aa98c99 Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit PiperOrigin-RevId: 509585801 14 February 2023, 19:24:55 UTC
93c9313 Use right fct name. 14 February 2023, 19:21:16 UTC
d2bb1e0 Be consistent in the index used 14 February 2023, 19:21:03 UTC
11e3219 DOC: add docs for jax.dtypes module 14 February 2023, 19:18:59 UTC
47dca67 Merge pull request #14456 from jakevdp:jax-typing-public PiperOrigin-RevId: 509585352 14 February 2023, 19:17:05 UTC
6735102 Small crash fixes 14 February 2023, 19:14:26 UTC
658a934 License cleanup. PiperOrigin-RevId: 509563977 14 February 2023, 18:07:02 UTC
1c651f2 Catch the NaNs and raise a better error message when jax_debug_nans flag is True. PiperOrigin-RevId: 509552717 14 February 2023, 17:27:36 UTC
995ef40 [JAX] Improve error message when jit tracer passed to a shape. Adds additional debugging message to the shape explaining why the value is a tracer. Fixes #14279 PiperOrigin-RevId: 509545985 14 February 2023, 17:13:01 UTC
4237f22 Clean up unneeded exception raise. PiperOrigin-RevId: 509545493 14 February 2023, 17:04:19 UTC
15196bc [sparse] enable bcsr_dot_general cusparse lowering PiperOrigin-RevId: 509537223 14 February 2023, 16:32:04 UTC
582c042 Implement lowering for convolutions with dynamic padding PiperOrigin-RevId: 509451627 14 February 2023, 08:55:45 UTC
0f7ffb8 Bump minimum Python version in "contributing" docs. PiperOrigin-RevId: 509385522 14 February 2023, 02:05:11 UTC
9e01ee4 Merge pull request #14457 from mattjj:djax-bug-fix PiperOrigin-RevId: 509377741 14 February 2023, 01:28:37 UTC
442aa02 Fix xmap staging rule to handle positional semantics PiperOrigin-RevId: 509356614 14 February 2023, 00:05:17 UTC
e1ff0c1 Make colab_gpu.ipynb compatible with newer JAX versions PiperOrigin-RevId: 509356393 13 February 2023, 23:56:58 UTC
7975192 Expose jax.typing & update docs 13 February 2023, 23:53:08 UTC
96c558d fix minor broadcasting bug Co-authored-by: Adam Paszke <apaszke@google.com> 13 February 2023, 23:13:13 UTC
d0eedf7 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr PiperOrigin-RevId: 509341618 13 February 2023, 22:58:20 UTC
4a523e3 Minimize exported names from jax.experimental.maps. Move implementation of maps to jax._src.maps. PiperOrigin-RevId: 509309092 13 February 2023, 20:57:54 UTC
2fc64be Change the `axis_resources` argument of `with_sharding_constraint` to `shardings` to match `pjit` and `jit`. PiperOrigin-RevId: 509275107 13 February 2023, 18:53:57 UTC
c49af18 Merge pull request #14365 from jakevdp:reducers-initial PiperOrigin-RevId: 509253981 13 February 2023, 17:43:46 UTC
83b7ba2 Merge pull request #14444 from jakevdp:fix-csr-lowering PiperOrigin-RevId: 509241948 13 February 2023, 16:59:37 UTC
58323d5 jax.numpy reductions: better validation of initial value 13 February 2023, 16:43:25 UTC
ddae1d0 fix change to csr lowering rule 13 February 2023, 16:39:05 UTC
2d47921 [shape_poly] Fix some tests for shape pol with native lowering Native lowering requires that dimension variables be computable from the shapes of the arguments that are kept by lowering. We had some tests that were using dimension variables but were not using the actual inputs. Then lowering removes the inputs and it is not possible anymore to recover the values of the dimension variables at invocation time. Here we primarily changed the tests to ensure they use not just the shape of the input but the actual value. In some cases we have disabled some testing, until https://github.com/google/jax/issues/14437 is fixed PiperOrigin-RevId: 509171805 13 February 2023, 10:48:11 UTC
6b70728 Fix sharding_test PiperOrigin-RevId: 509048406 12 February 2023, 18:26:17 UTC
002cdf6 Merge pull request #14432 from gnecula:call_tf_checks_2 PiperOrigin-RevId: 509043710 12 February 2023, 17:34:50 UTC
f280d31 [jax2tf] Minor fix: remove redundant check 12 February 2023, 15:02:53 UTC
6caaffc Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources. PiperOrigin-RevId: 508934327 11 February 2023, 23:30:14 UTC
bc1f5f1 Register missing standard primitives to shard_map.py. PiperOrigin-RevId: 508920824 11 February 2023, 20:58:22 UTC
654c1d3 Fix _standard_rep_rule in shard_map.py when in_rep is empty. set.intersection() with no arguments (in_rep is empty) raises an exception. PiperOrigin-RevId: 508910287 11 February 2023, 19:09:47 UTC
1089992 Merge pull request #14424 from stellaraccident:devenh PiperOrigin-RevId: 508902521 11 February 2023, 17:50:50 UTC
612a940 Minimize the set of names exported from jax.experimental.pjit. PiperOrigin-RevId: 508889911 11 February 2023, 15:37:32 UTC
9316188 [Rollback] Convert _arrays to return PyArray instead of PyBuffer. PiperOrigin-RevId: 508827908 11 February 2023, 05:36:56 UTC
c1e13bd A few developer workflow enhancements for working with jaxlib. It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:
* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc.).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
11 February 2023, 05:03:21 UTC
61da781 [JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array. PiperOrigin-RevId: 508822404 11 February 2023, 04:56:34 UTC
1bdcd5e Merge pull request #14415 from jakevdp:bcsr-matmul PiperOrigin-RevId: 508785095 11 February 2023, 00:55:05 UTC
a262314 prune unintended exports from `jax.interpreters.batching` PiperOrigin-RevId: 508784928 11 February 2023, 00:47:28 UTC
e548585 Add back loading TPU plugin for older jaxlib versions. This was removed in https://github.com/google/jax/commit/668b82d529e2649f1dbf7cdf9ec8d934fda09a19. PiperOrigin-RevId: 508777939 11 February 2023, 00:16:20 UTC
26ddf3b Merge pull request #14419 from jakevdp:spsolve-cpu-lowering PiperOrigin-RevId: 508777573 11 February 2023, 00:16:05 UTC
de8a77a [sparse] implement BCSR.__matmul__ 11 February 2023, 00:11:57 UTC
fc507f2 Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples PiperOrigin-RevId: 508777043 11 February 2023, 00:08:32 UTC
0d07372 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones. PiperOrigin-RevId: 508770290 10 February 2023, 23:40:25 UTC
552fc2c [sparse] add CPU lowering rule for sparse.linalg.spsolve 10 February 2023, 23:35:42 UTC
568a93b Convert _arrays to return PyArray instead of PyBuffer. PiperOrigin-RevId: 508769390 10 February 2023, 23:32:57 UTC
9538bc3 generalize vmap spmd_axis_name to accept tuples of axis names This brings the argument more in line with what can appear as positional arguments to the PartitionSpec constructor. 10 February 2023, 23:25:23 UTC
d531ec9 Merge pull request #14331 from canyon289:add_user_guide_sentence PiperOrigin-RevId: 508760750 10 February 2023, 22:53:17 UTC
2f80e46 [XLA:Python] Fix overly pessimistic handling of singleton dimensions in dlpack code. Requires an accompanying jaxlib change. Fixes https://github.com/google/jax/issues/14399 PiperOrigin-RevId: 508757315 10 February 2023, 22:44:22 UTC
dc6bf9b Merge pull request #14408 from lucashofer:scipy_spence PiperOrigin-RevId: 508756972 10 February 2023, 22:36:15 UTC
4deb12e Merge pull request #14411 from hawkinsp:kepler PiperOrigin-RevId: 508755353 10 February 2023, 22:28:32 UTC
18b251e Add notes and headers to user guides 10 February 2023, 22:17:15 UTC
1526c3e Improve the error message which is raised from `_get_and_check_device_assignment`.
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961 10 February 2023, 21:54:15 UTC
4636276 added scipy special spence added dtype to arrays in the _spence_poly function 10 February 2023, 20:33:47 UTC
57900d7 Merge pull request #14364 from jakevdp:fix-tril-indices PiperOrigin-RevId: 508723970 10 February 2023, 20:25:06 UTC
6ee6763 Split PyTorch interoperability tests into their own test. PiperOrigin-RevId: 508722180 10 February 2023, 20:17:11 UTC
5da5967 Merge pull request #14395 from jakevdp:bcsr-dot-general PiperOrigin-RevId: 508721790 10 February 2023, 20:09:29 UTC
ac647b9 [sparse] implement autodiff rules for bcsr_dot_general 10 February 2023, 20:00:30 UTC
ec56d71 Drop support for NVIDIA Kepler series GPUs in jaxlib builds. 10 February 2023, 19:15:15 UTC
7a864d7 Merge pull request #14394 from jakevdp:jax-array-methods PiperOrigin-RevId: 508694486 10 February 2023, 18:27:14 UTC
be21404 [jax2tf] Add shard_map tests Also fix tests to run on multiple devices in TF PiperOrigin-RevId: 508691872 10 February 2023, 18:18:19 UTC
d09f3c2 Merge pull request #11727 from gnecula:call_tf_checks PiperOrigin-RevId: 508685246 10 February 2023, 17:51:35 UTC
60256df [typing] define additional methods & properties on jax.Array These are the methods that are only valid for actual materialized arrays (i.e. not Tracers) In order to simplify the experience for users, we want to maintain only a single jax.Array type, so we define all methods here and raise explicit errors on Tracer instances. 10 February 2023, 17:42:32 UTC
7659a3a Enable call_tf_native_lowering_test. PiperOrigin-RevId: 508677359 10 February 2023, 17:16:53 UTC
9f0783f Merge pull request #14403 from gnecula:reduce_precision PiperOrigin-RevId: 508635187 10 February 2023, 13:38:59 UTC
f070557 Merge pull request #14400 from gnecula:native_bug1 PiperOrigin-RevId: 508635169 10 February 2023, 13:30:24 UTC
30fda87 [call_tf] Improve error reporting Add more checks to catch early the cases when the called TF function returns values that are not convertible to JAX values (arrays of numeric values). All these cases were resulting in errors even before but sometimes these errors were deep in the stack and harder to diagnose. 10 February 2023, 13:19:49 UTC
48c2538 [jax2tf] Add support for reduce_precision 10 February 2023, 12:29:46 UTC
ff6051f [shape_poly] Better error message for functions that do not use input arguments. Also:
* fixed some of the tests that were using the shape but not the value of the input arguments
* fix importing of mlir.py due to recent move of interpreters.mlir to _src.interpreters.mlir
10 February 2023, 09:59:46 UTC
82e5767 update hsplit and testHVDSplit for 1D array 10 February 2023, 07:27:37 UTC
54ff78d Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray. PiperOrigin-RevId: 508502470 10 February 2023, 00:11:48 UTC
357b48d Merge pull request #14391 from ROCmSoftwarePlatform:rocm_switch_to_rocm54 PiperOrigin-RevId: 508497281 09 February 2023, 23:50:49 UTC
1c84e4a migrate internal dependencies from `jax.interpreters.batching` to `jax._src.interpreters.batching` ... in preparation for paring down `jax.interpreters.batching`'s exported symbols. PiperOrigin-RevId: 508487887 09 February 2023, 23:11:57 UTC
12dc73d Merge pull request #14388 from jakevdp:bcsr-todense-ad PiperOrigin-RevId: 508477843 09 February 2023, 22:41:41 UTC
668b82d [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS. Loading TPU PJRT plugin is moved to make_tpu_client. This change is based on https://github.com/google/jax/pull/14011. PiperOrigin-RevId: 508477737 09 February 2023, 22:33:46 UTC
7651866 [sparse] implement autodiff rules for bcsr primitives 09 February 2023, 22:19:22 UTC
15c9bca [sparse] add cusparse lowering for simplest cases of bcsr_dot_general PiperOrigin-RevId: 508473938 09 February 2023, 22:18:44 UTC
253cd4d Merge pull request #14387 from ROCmSoftwarePlatform:rocm_reenable_dirichlet_test PiperOrigin-RevId: 508466026 09 February 2023, 21:50:06 UTC
88cc254 [JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array. PiperOrigin-RevId: 508463147 09 February 2023, 21:39:41 UTC
0c14e9a Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules. Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users. Fix some overly-strict type annotations. PiperOrigin-RevId: 508461199 09 February 2023, 21:31:40 UTC
adcceb2 Merge pull request #14384 from mattjj:pjit-pretty-print PiperOrigin-RevId: 508454299 09 February 2023, 21:04:58 UTC
a964dc3 simpler pretty-print for pjit, tweak custom pp rule signature 09 February 2023, 20:45:51 UTC
7d0d9b7 [ROCm]: Re-enable Dirichlet Tests on ROCm 09 February 2023, 20:19:07 UTC
023226e [ROCm]: Move dockerfile to ROCm5.4 09 February 2023, 20:08:35 UTC
8268cd5 Add infrastructure for managing deprecations. Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh. PiperOrigin-RevId: 508349776 09 February 2023, 13:48:40 UTC
3f8cb0a Merge pull request #14379 from mattjj:shmap-vmap-spmd-axis-name PiperOrigin-RevId: 508292029 09 February 2023, 08:14:30 UTC
6fb3ace [shard-map] add vmap spmd_axis_name support, fix vmap rule bug 09 February 2023, 07:54:28 UTC
bd7c227 Merge pull request #14373 from mattjj:shmap-check-rep-false PiperOrigin-RevId: 508219490 09 February 2023, 00:49:29 UTC
1a03f34 [shard-map] if check_rep=False, don't call rep rules in eager 08 February 2023, 23:42:35 UTC
ccb974a Merge pull request #14370 from jakevdp:argpartition-impl PiperOrigin-RevId: 508194466 08 February 2023, 23:10:50 UTC
a28b012 Move contents of jax.monitoring to jax._src.monitoring. PiperOrigin-RevId: 508191560 08 February 2023, 23:03:22 UTC
7350f00 Remove `jax_experimental_subjaxpr_lowering_cache` since it was only for `jit` and was `False` by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free. PiperOrigin-RevId: 508191408 08 February 2023, 22:55:56 UTC
4fbaee5 Implement jax.numpy.argpartition 08 February 2023, 22:41:39 UTC
cc8d7fa Move jax.interpreters.mlir to jax._src.interpreters.mlir. Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally. PiperOrigin-RevId: 508187063 08 February 2023, 22:39:01 UTC
3e349c7 Merge pull request #14361 from jakevdp:doc-topk PiperOrigin-RevId: 508181335 08 February 2023, 22:19:01 UTC
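Commit 654c1d3 notes that `set.intersection()` called with no arguments raises when `in_rep` is empty. The stdlib-only sketch below reproduces that failure and shows one way to guard it. `standard_rep_rule` is a hypothetical stand-in for `_standard_rep_rule` in shard_map.py: the real rule takes additional arguments, and returning an empty set for the no-input case is an assumption made here for illustration, not necessarily what the fix does.

```python
from typing import Set


def standard_rep_rule(*in_rep: Set[str]) -> Set[str]:
    """Hypothetical stand-in for shard_map's `_standard_rep_rule`.

    Each element of `in_rep` is the set of mesh axis names over which one
    input is replicated; the output is replicated only over the axes shared
    by every input.
    """
    if not in_rep:
        # set.intersection(*()) is set.intersection(), which raises TypeError.
        # Assumption: treat "no inputs" as "replicated over no axes".
        return set()
    return set.intersection(*in_rep)


def intersection_with_no_args_raises() -> bool:
    """Reproduce the failure described in the commit message."""
    try:
        set.intersection()
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    print(intersection_with_no_args_raises())
    print(standard_rep_rule({"x", "y"}, {"y", "z"}))
    print(standard_rep_rule())
```

With one or more inputs the guard is a no-op, so only the empty case changes behaviour.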
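Commits 8268cd5 and 0c14e9a describe infrastructure for managing deprecations and a `_deprecations` table in shim modules that users can override locally. A common Python mechanism for this is a module-level `__getattr__` (PEP 562). The sketch below assumes that approach; `make_deprecation_getattr`, `example_shim`, and the placeholder `PartitionSpec` class are hypothetical, and this is not JAX's actual implementation.

```python
import types
import warnings
from typing import Any, Callable, Dict, Tuple

# Maps a deprecated attribute name to (warning message, object to return).
DeprecationTable = Dict[str, Tuple[str, Any]]


def make_deprecation_getattr(
    module_name: str, deprecations: DeprecationTable
) -> Callable[[str], Any]:
    """Build a PEP 562 module-level __getattr__ (hypothetical helper)."""

    def __getattr__(name: str) -> Any:
        # Only called when normal attribute lookup on the module fails.
        if name in deprecations:
            message, value = deprecations[name]
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return value
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")

    return __getattr__


class PartitionSpec:
    """Placeholder object standing in for a relocated public name."""


# A stand-in shim module; in a real package this would be a .py file.
shim = types.ModuleType("example_shim")
shim._deprecations = {
    "PartitionSpec": (
        "example_shim.PartitionSpec is deprecated; import it from its new location.",
        PartitionSpec,
    ),
}
# The table stays visible (not hidden) so users can edit it locally,
# e.g. delete an entry to verify that nothing still uses the old name.
shim.__getattr__ = make_deprecation_getattr(shim.__name__, shim._deprecations)
```

Accessing `shim.PartitionSpec` emits a `DeprecationWarning` and returns the object. Because the closure reads the same dict stored as `shim._deprecations`, deleting an entry there makes further accesses raise `AttributeError`.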