https://github.com/google/jax

Revision Message Commit Date
b552c15 [Pallas] Add temporary skip for pallas out-of-bounds interpret mode test on GPU. PiperOrigin-RevId: 634046352 15 May 2024, 20:51:37 UTC
6a949ae Merge pull request #21251 from jakevdp:fix-pmap-warning PiperOrigin-RevId: 634061434 15 May 2024, 20:43:18 UTC
f5a59cc Fix warning filter in test_pmap_of_prng_key 15 May 2024, 20:20:05 UTC
1538197 Merge pull request #21247 from dfm:gh20481 PiperOrigin-RevId: 634045828 15 May 2024, 19:53:19 UTC
bb40075 [Mosaic GPU] Fix argument checking in case the input shape is not a tuple. PiperOrigin-RevId: 634044759 15 May 2024, 19:49:27 UTC
09a4b38 Add informative error for invalid unroll in scan As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an uninformative `ZeroDivisionError`. This PR adds a check which raises a `ValueError` for `unroll<=0`. 15 May 2024, 19:40:27 UTC
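For context, a minimal sketch of the behavior this commit adds (the exact error message text is an assumption):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Trivial scan body: accumulate a running sum.
    total = carry + x
    return total, total

xs = jnp.arange(8.0)
# Positive unroll values are valid and control loop unrolling.
total, outs = jax.lax.scan(step, 0.0, xs, unroll=2)

# After this change, unroll=0 raises an informative ValueError instead
# of an uninformative ZeroDivisionError.
try:
    jax.lax.scan(step, 0.0, xs, unroll=0)
except ValueError as e:
    print("ValueError:", e)
```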
4ccac4c Merge pull request #21237 from gnecula:clean_tokens PiperOrigin-RevId: 634003364 15 May 2024, 17:40:58 UTC
11a4513 [Mosaic GPU] Explicitly check the smem size so that we get a good error rather than a cryptic cuda error PiperOrigin-RevId: 633993345 15 May 2024, 17:10:46 UTC
874a20a [Mosaic GPU] Fix np.random API mistake in tests PiperOrigin-RevId: 633970332 15 May 2024, 15:53:03 UTC
7653db8 Fix CI test failure in line_search_test. A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not. PiperOrigin-RevId: 633949879 15 May 2024, 14:38:20 UTC
41153b1 Cleanup token handling during lowering Version 0.4.27 of jaxlib is now the minimum version and it supports real stablehlo tokens as module inputs and outputs. Hence we can now clean up `mlir.lower_jaxpr_to_fun` to not use the kwargs `create_tokens` and `replace_tokens_with_dummy` (both of them are always False now). We also remove `num_output_tokens` that is not used. 15 May 2024, 12:54:46 UTC
66a92c4 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678 PiperOrigin-RevId: 633816901 15 May 2024, 05:41:44 UTC
b8cec53 Update XLA dependency to use revision http://github.com/openxla/xla/commit/a0f5d76e3dac2ee3293de60a79c29a46668a53b0. PiperOrigin-RevId: 633771201 15 May 2024, 01:42:27 UTC
9af0a31 Merge pull request #21222 from tempoxylophone:patch-1 PiperOrigin-RevId: 633734390 14 May 2024, 23:05:17 UTC
48b9018 Merge pull request #21221 from nouiz:doc PiperOrigin-RevId: 633734299 14 May 2024, 23:04:56 UTC
1dff5b3 Merge pull request #21226 from Micky774:linalg-depr PiperOrigin-RevId: 633734284 14 May 2024, 23:01:50 UTC
d4f36e2 Merge pull request #21227 from jakevdp:eye-offset PiperOrigin-RevId: 633730820 14 May 2024, 22:48:53 UTC
e8b06cc Cholesky rank-1 update kernel for JAX. PiperOrigin-RevId: 633722940 14 May 2024, 22:21:38 UTC
e2918ca Added a very rough sketch of Mosaic GPU lowering for Pallas Almost nothing is supported, including * PyTree inputs/outputs * indexers * non-trivial grids * block specs * any primitives beyond the ones added here * etc etc PiperOrigin-RevId: 633713366 14 May 2024, 21:48:09 UTC
3024c78 jnp.eye: allow k to be dynamic 14 May 2024, 20:32:54 UTC
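A short sketch of what a dynamic `k` enables: the diagonal offset can now be a traced value under `jit` rather than a static Python integer.

```python
import jax
import jax.numpy as jnp

@jax.jit
def shifted_eye(k):
    # k is a traced (dynamic) value here; previously the diagonal
    # offset had to be a static constant.
    return jnp.eye(4, k=k)

print(shifted_eye(1))   # ones on the first superdiagonal
print(shifted_eye(-2))  # ones on the second subdiagonal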
0ad5167 Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support) PiperOrigin-RevId: 633677477 14 May 2024, 19:54:48 UTC
5cc255b Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv 14 May 2024, 19:53:54 UTC
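Usage after the rename, as a sketch (whether the old `rcond`/`tol` names still work with a deprecation warning is an assumption):

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [2.0, 4.0]])  # rank-deficient

# New keyword name: rtol, a relative tolerance on singular values.
rank = jnp.linalg.matrix_rank(a, rtol=1e-5)
pinv = jnp.linalg.pinv(a, rtol=1e-5)
print(rank, pinv.shape)
```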
0501d3d Merge pull request #21223 from jakevdp:kl-div-doc PiperOrigin-RevId: 633651462 14 May 2024, 18:29:07 UTC
bb5787d Finalize deprecations of several APIs PiperOrigin-RevId: 633634215 14 May 2024, 17:40:40 UTC
1d6ffde Reverts 85e91c2be4310d9728f7bfeefef921ee4a075135 PiperOrigin-RevId: 633622856 14 May 2024, 17:07:44 UTC
c3c2393 kl_div: fix incorrect formula in doc 14 May 2024, 16:25:04 UTC
1e3a4b5 Remove now-empty jax/_src/third_party/numpy Followup to https://github.com/google/jax/pull/21119 PiperOrigin-RevId: 633603237 14 May 2024, 16:03:54 UTC
4bcad15 Fix typo in advanced-autodiff.md: removed `doc.g` text from the first section. 14 May 2024, 15:52:41 UTC
f578d78 Update doc with the other error that can be thrown. 14 May 2024, 15:35:01 UTC
aea6b32 Merge pull request #21216 from jakevdp:pmap-prng PiperOrigin-RevId: 633582800 14 May 2024, 14:48:17 UTC
95c4ba9 [JAX] Use first process id instead of process 0 to share multi-host data. PiperOrigin-RevId: 633526220 14 May 2024, 10:45:41 UTC
84774b3 Fix sparse dot metadata loader Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs. PiperOrigin-RevId: 633513110 14 May 2024, 10:08:18 UTC
5150cfe Fix PRNGKey handling under jit-of-pmap 14 May 2024, 02:04:22 UTC
e735a00 Merge pull request #21215 from jakevdp:setops-docs PiperOrigin-RevId: 633400796 14 May 2024, 01:57:48 UTC
98af786 Update XLA dependency to use revision http://github.com/openxla/xla/commit/2365b5180382662ee10bd3643fce45a66535f407. PiperOrigin-RevId: 633391316 14 May 2024, 01:10:33 UTC
a56eb56 Merge pull request #21211 from elfiegg:main PiperOrigin-RevId: 633386931 14 May 2024, 00:53:26 UTC
6157d8e Merge pull request #21176 from jakevdp:shard-map-api-docs PiperOrigin-RevId: 633382922 14 May 2024, 00:34:54 UTC
7ed7780 Improve docs for jax.numpy set-like operations 13 May 2024, 22:19:14 UTC
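For reference, a sketch of the set-like operations covered and the static-shape `size` argument they accept for use under `jit` (the padding behavior shown is an assumption):

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4])
b = jnp.array([3, 4, 5])

print(jnp.union1d(a, b))      # [1 2 3 4 5]
print(jnp.intersect1d(a, b))  # [3 4]
print(jnp.setdiff1d(a, b))    # [1 2]
print(jnp.isin(a, b))         # [False False  True  True]

# Under jit, output shapes must be static, so these functions accept a
# size argument, padding the result with fill_value.
union = jax.jit(lambda x, y: jnp.union1d(x, y, size=8, fill_value=0))(a, b)
print(union)  # [1 2 3 4 5 0 0 0]
```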
cd6e012 Enable JAX memory tests for GPUs and CPUs PjRt GPU and CPU have recently gotten memory space support with just one memory space per device, so we enable the relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only `ShardingMemoriesTest` is enabled for now. PiperOrigin-RevId: 633335638 13 May 2024, 21:37:37 UTC
72e9eb9 shard_map: add API docs 13 May 2024, 20:04:15 UTC
72a81e5 Readd a default lowering rule for cumsum et al. A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules PiperOrigin-RevId: 633297839 13 May 2024, 19:34:51 UTC
1e48adc [Pallas] Pad input/outputs in interpret mode to fix errors in OOB memory accesses. PiperOrigin-RevId: 633283991 13 May 2024, 18:50:21 UTC
b8ed346 Merge pull request #21119 from jakevdp:linalg-cond PiperOrigin-RevId: 633281675 13 May 2024, 18:43:24 UTC
6189a55 Merge pull request #21048 from jakevdp:np-squeeze-doc PiperOrigin-RevId: 633273019 13 May 2024, 18:18:09 UTC
9e7830d Async dispatch expensive computations on the JAX CPU backend. By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one can opt out of the change and recover the old behavior. PiperOrigin-RevId: 633264117 13 May 2024, 17:53:09 UTC
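A sketch of the opt-out named in the commit, and of observing the effect (the timing pattern is illustrative, not from the commit):

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((4000, 4000))

# With async dispatch (the new CPU default), dispatch returns quickly
# and the matmul completes in the background.
t0 = time.perf_counter()
y = jnp.dot(x, x)
print("dispatch took", time.perf_counter() - t0)
y.block_until_ready()  # wait for the actual computation

# Opt out and recover the old synchronous behavior.
jax.config.update('jax_cpu_enable_async_dispatch', False)
```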
54d4072 Populate `propagated_out_mem_kinds` inside the branch where it's needed PiperOrigin-RevId: 633262630 13 May 2024, 17:48:52 UTC
1f6d902 jnp.linalg.cond: improve implementation & docs 13 May 2024, 17:36:50 UTC
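A quick usage sketch, mirroring `numpy.linalg.cond`:

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 0.0], [0.0, 10.0]])

# Default is the 2-norm condition number: sigma_max / sigma_min.
print(jnp.linalg.cond(a))           # ~10.0
# Other norms follow numpy.linalg.cond, e.g. the Frobenius norm.
print(jnp.linalg.cond(a, p='fro'))
```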
85e91c2 Merge pull request #21203 from gnecula:export_device_poly PiperOrigin-RevId: 633253709 13 May 2024, 17:23:28 UTC
7ceae95 Better documentation for several jax.numpy functions 13 May 2024, 17:09:52 UTC
98aead7 [export] Relax the check that exported modules are used with same number of devices as when exported Now we allow a module exported for 1 device and not using any sharding annotations to be called from a computation that uses multiple devices. Such exported modules can be parallelized trivially point-wise. 13 May 2024, 17:09:43 UTC
43d1916 Remove type promotion for mixed fp8 matmuls. 13 May 2024, 16:50:52 UTC
e4f3b3f Merge pull request #21169 from justinjfu/splash_precision_fix Disable bfloat16 on long seq lengths for splash attention kernel test 13 May 2024, 16:34:39 UTC
35a512d CI: update NumPy build version to 2.0.0rc2 PiperOrigin-RevId: 633233231 13 May 2024, 16:18:40 UTC
de14e3b Reverts 49bd4d6f01d6cda00f9b1bdfbda156636baae928 PiperOrigin-RevId: 633221195 13 May 2024, 15:35:40 UTC
e66a234 Merge pull request #21191 from gnecula:export_simplify PiperOrigin-RevId: 633179742 13 May 2024, 12:48:13 UTC
54ca3d4 Merge pull request #21202 from superbobry:pallas PiperOrigin-RevId: 633176367 13 May 2024, 12:30:57 UTC
1fed784 Merge pull request #20940 from piotrfilipiuk:changelist/623910451 PiperOrigin-RevId: 633170419 13 May 2024, 12:03:28 UTC
78d4d0a [export] Simplify export internals, prepare for integration with AOT APIs In preparation for a better integration of jax.experimental.export with the AOT APIs, we make several simplifications: * always turn on the generation of shape assertions in the presence of shape polymorphism. Previously, shape assertions were on unless the serialization version was less than 7 (possible only before March 27th, 2024, when the minimum serialization version was bumped to 9), or unless the user explicitly turned them off. It is not safe to turn off shape assertions, and I am not aware of an instance where somebody had to turn them off except for temporary debugging. We keep the `DisabledSafetyCheck.shape_assertions` API for now, for backwards compatibility, but it has no effect and emits a deprecation warning. * remove the code that was conditional on the serialization version being less than 9, e.g., for the lowering in the presence of effects. * remove a safety check that ensured that when `export` is used on JAX callables, i.e., not the result of `jax.jit`, the code contains no non-replicated sharding annotations. This usage of `export` is rare and will be removed once `export` is integrated with the AOT APIs. * remove code that was needed only by older jaxlib for `replace_tokens_with_dummy`. 13 May 2024, 11:41:51 UTC
8094d0d Guarded Pallas GPU import in tests/pallas/pallas_test.py We do not build Triton IR bindings on Windows. This should fix https://github.com/google/jax/actions/runs/9051189315/job/24867428634. 13 May 2024, 11:23:18 UTC
1c6855a Ensured that all Pallas GPU tests depend on :pallas_gpu This dependency is added implicitly by Google-internal infra, but we need it to be explicit for Bazel builds to avoid ImportErrors at lowering time. PiperOrigin-RevId: 633147268 13 May 2024, 10:07:22 UTC
ba8480a Register TPU profiler plugin when get_topology_desc is called with the tpu platform. This allows the TPU profiler to work with other plugin backends. Verified on a GPU VM (`pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`, then `TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py`): the new `JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices` passes even though no TPU topology is available, confirming that a serialized `.xplane.pb` trace containing host metadata and Python traces is written. PiperOrigin-RevId: 633076007 13 May 2024, 03:50:41 UTC
af4bddb [Mosaic GPU] Correct TileTransform.transform_index() Previously TileTransform.transform_index() would transform like so: x, y, z -> x, y // t_x, z // t_y, 0, 0. But it actually should be: x, y, z -> x, y // t_x, z // t_y, y % t_x, z % t_y. PiperOrigin-RevId: 633052067 13 May 2024, 01:23:25 UTC
dac5b75 [jax:mosaic-gpu] Test type conversion for tiled fragment array PiperOrigin-RevId: 633052027 13 May 2024, 01:23:07 UTC
c1f184d [Mosaic GPU] Cleanup matmul example mainly by removing the m/n block loops (1 step). PiperOrigin-RevId: 633051910 13 May 2024, 01:20:06 UTC
3e3d916 Update XLA dependency to use revision http://github.com/openxla/xla/commit/8dd9fc35394a2b98d249f872242a100d7d56286e. PiperOrigin-RevId: 633049132 13 May 2024, 01:02:26 UTC
b4f2145 Update XLA dependency to use revision http://github.com/openxla/xla/commit/d3e881ad668b7aa44283d47bb553a04b86b71315. PiperOrigin-RevId: 632860505 12 May 2024, 01:00:43 UTC
a527b71 [Mosaic GPU] Prepare for writing warp-specialized kernels PiperOrigin-RevId: 632854287 12 May 2024, 00:09:08 UTC
49bd4d6 Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a PiperOrigin-RevId: 632818524 11 May 2024, 19:24:06 UTC
93dfe05 Implements Ragged Dot API 11 May 2024, 13:40:18 UTC
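A sketch of ragged dot, assuming the API surfaces as `jax.lax.ragged_dot` with `lhs` of shape [m, k] grouped along m by `group_sizes`, `rhs` of shape [g, k, n], and an [m, n] result (treat the exact signature as an assumption):

```python
import jax
import jax.numpy as jnp

m, k, n, g = 8, 4, 5, 3
lhs = jnp.ones((m, k))
rhs = jnp.ones((g, k, n))
# Rows of lhs are partitioned into g contiguous groups; group i is
# multiplied against rhs[i].
group_sizes = jnp.array([3, 2, 3], dtype=jnp.int32)

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
print(out.shape)  # (8, 5)
```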
3b03e54 Raise a runtime error when trying to convert the `jax.Array` wrapped by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape. PiperOrigin-RevId: 632682906 11 May 2024, 04:08:06 UTC
20646eb Update XLA dependency to use revision http://github.com/openxla/xla/commit/0b3dc68410d57f6cd3d0a484f46946d6d12f03a8. PiperOrigin-RevId: 632655986 11 May 2024, 01:19:21 UTC
9ac1d38 Finish jax and jaxlib 0.4.28 release PiperOrigin-RevId: 632653310 11 May 2024, 01:06:52 UTC
979d9ca Merge pull request #21168 from 8bitmp3:upgrade-sharded--doc PiperOrigin-RevId: 632648408 11 May 2024, 00:44:15 UTC
a4693db Add a jaxpr interpreter for propagating memory kinds to outputs. It only triggers if we detect multiple memory kinds in the jaxpr. This hopefully should go away when XLA implements its own memory space propagation pass, or when JAX adds memory_kind to the type system of jaxpr, i.e., on avals. It's required to treat the following code blocks (1) and (2) as equivalent when lowering to stablehlo. In general, shardings should also be treated the same way, but we'll cross that bridge later. 1. `jit(f, out_shardings=s_host)` 2. ``` @jax.jit def f(x): return jax.device_put(x, s_host) ``` PiperOrigin-RevId: 632621025 10 May 2024, 22:34:57 UTC
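The two equivalent forms from the commit, fleshed out as a sketch (assumes a backend with memory-space support and a `pinned_host` memory kind, per the surrounding memory-space work):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
s_host = NamedSharding(mesh, P(), memory_kind='pinned_host')

# (1) Request host memory via out_shardings...
f1 = jax.jit(lambda x: x + 1, out_shardings=s_host)

# (2) ...or via device_put inside the jitted function; the new
# propagation pass makes these lower to the same stablehlo.
@jax.jit
def f2(x):
    return jax.device_put(x + 1, s_host)
```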
27c932a Do not import from lowering in tests/pallas/pallas_test.py This ensures that the test is importable even with a non-GPU jaxlib, which does not have Triton dialect bindings. PiperOrigin-RevId: 632603225 10 May 2024, 21:25:17 UTC
9ea3fcb Upgrade JAX Parallelism Sharded Computation 101 doc 10 May 2024, 21:24:16 UTC
17444fc Merge pull request #21174 from hawkinsp:spmm PiperOrigin-RevId: 632589433 10 May 2024, 20:35:04 UTC
dda428e Disable tests that trigger a warning if x64 mode isn't enabled. 10 May 2024, 19:58:22 UTC
c3cab2e Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3 PiperOrigin-RevId: 632579101 10 May 2024, 19:56:10 UTC
c231cd5 Merge pull request #21173 from hawkinsp:precision PiperOrigin-RevId: 632577567 10 May 2024, 19:50:07 UTC
24b4731 Force float32 matmuls in examples_test. This test started failing when we changed our CI to use L4 GPUs. Using highest precision resolves the problem. 10 May 2024, 19:30:02 UTC
0a3e432 [PJRT C API] Enable PJRT C API runtime in jax2tf dlpack. GetDefaultLayout added a fallback for GPU backend so it is no longer blocked by the fact that PJRT C API does not support GetDefaultLayout yet. PiperOrigin-RevId: 632555239 10 May 2024, 18:30:37 UTC
6c42533 Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306 PiperOrigin-RevId: 632548226 10 May 2024, 18:06:38 UTC
586568f Simplify JAX lowering rules for cumulative sum: rely on the XLA decomposition. JAX GPU microbenchmarks: 285us vs. 449us for cumsum over 1e8 elements. JAX CPU microbenchmarks: 1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements. PiperOrigin-RevId: 632547166 10 May 2024, 18:03:28 UTC
13a1955 Merge pull request #21167 from jakevdp:einsum-path-func PiperOrigin-RevId: 632538144 10 May 2024, 17:35:32 UTC
ebb9184 Disable bfloat16 on long seq lengths for splash attention kernel test 10 May 2024, 17:31:29 UTC
bac3a6f Allow tokens to be passed to `jit` and through dispatch, and to be returned from the jitted function. Fixes https://github.com/google/jax/issues/21160 PiperOrigin-RevId: 632531105 10 May 2024, 17:12:48 UTC
0267ed0 Replace xla_extension symlink with genrule that makes the xla_extension module accessible from jax._src.lib. The runfiles of the original targets were lost when the symlinked files were used. This change is needed for the future Hermetic CUDA implementation: Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the content of the runfiles is lost; with `genrule`, it is preserved. PiperOrigin-RevId: 632508121 10 May 2024, 15:48:12 UTC
d07951c jnp.einsum_path: improve docs & annotations 10 May 2024, 15:39:32 UTC
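A sketch of precomputing a contraction path and reusing it, mirroring `numpy.einsum_path` (feeding the list back into `jnp.einsum` is an assumption based on the NumPy behavior):

```python
import jax.numpy as jnp

a = jnp.ones((10, 20))
b = jnp.ones((20, 30))
c = jnp.ones((30, 5))

# Compute a contraction order once: path is a list of pairwise
# contractions and info is a human-readable summary.
path, info = jnp.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
print(info)

# Reuse the precomputed path in subsequent einsum calls.
out = jnp.einsum('ij,jk,kl->il', a, b, c, optimize=path)
```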
c2d78ab Merge pull request #21148 from jakevdp:einsum-path PiperOrigin-RevId: 632470123 10 May 2024, 12:58:30 UTC
2a01541 Merge pull request #21157 from djm3622:patch-1 PiperOrigin-RevId: 632469962 10 May 2024, 12:55:38 UTC
c3d3db9 jnp.einsum: support optimize=False, and improve docs for this keyword. 10 May 2024, 02:50:06 UTC
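With this change, `optimize=False` performs the contraction in the order given, as a sketch:

```python
import jax.numpy as jnp

a = jnp.ones((5, 100))
b = jnp.ones((100, 100))
c = jnp.ones((100, 5))

# optimize=False: contract left-to-right with no path optimization.
out_unopt = jnp.einsum('ij,jk,kl->il', a, b, c, optimize=False)
# Default: let the path optimizer pick a contraction order.
out_opt = jnp.einsum('ij,jk,kl->il', a, b, c)
assert jnp.allclose(out_unopt, out_opt)
```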
aa48a55 Update the line 127 error in debugging.md. The line should read `y, z = jnp.sin(x), jnp.cos(x)`, but `jnp.cos(x)` was nested within `jnp.sin(x)` as `jnp.sin(x, jnp.cos(x))`, which caused an error to be thrown. This change fixes that. 10 May 2024, 02:44:27 UTC
f21e3e8 Update XLA dependency to use revision http://github.com/openxla/xla/commit/e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4. PiperOrigin-RevId: 632333365 10 May 2024, 01:37:25 UTC
6f79093 [XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottom-up approach. Previously, output streaming took a top-down approach that indiscriminately checked whether a MoveToHost custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, while also improving efficiency with the bottom-up approach, so we only trace a single path in the graph. PiperOrigin-RevId: 632318740 10 May 2024, 00:34:21 UTC
a9460f2 Merge pull request #21130 from Micky774:reshape PiperOrigin-RevId: 632277406 09 May 2024, 21:51:56 UTC
e8ac011 Allow nested shard_map. PiperOrigin-RevId: 632275515 09 May 2024, 21:45:40 UTC
79005c1 Deprecate newshape argument of jnp.reshape 09 May 2024, 21:02:07 UTC
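A sketch of the migration: pass the target shape positionally rather than via the deprecated `newshape=` keyword (whether a `shape=` keyword replaces it, following the NumPy 2 rename, is an assumption):

```python
import jax.numpy as jnp

x = jnp.arange(6)

# Deprecated: jnp.reshape(x, newshape=(2, 3))
y = jnp.reshape(x, (2, 3))  # preferred: positional shape
print(y.shape)  # (2, 3)
```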
1cb6971 Merge pull request #21107 from justinjfu/pallas_splash_test_fix Fix failing smem_hbm_dma test 09 May 2024, 20:27:56 UTC
7245714 Add zeros initialization to failing smem-hbm copy test. 09 May 2024, 20:19:21 UTC