b552c15 | Justin Fu | 15 May 2024, 19:54:17 UTC | [Pallas] Add temporary skip for pallas out-of-bounds interpret mode test on GPU. PiperOrigin-RevId: 634046352 | 15 May 2024, 20:51:37 UTC |
6a949ae | jax authors | 15 May 2024, 20:43:18 UTC | Merge pull request #21251 from jakevdp:fix-pmap-warning PiperOrigin-RevId: 634061434 | 15 May 2024, 20:43:18 UTC |
f5a59cc | Jake VanderPlas | 15 May 2024, 20:20:05 UTC | Fix warning filter in test_pmap_of_prng_key | 15 May 2024, 20:20:05 UTC |
1538197 | jax authors | 15 May 2024, 19:53:19 UTC | Merge pull request #21247 from dfm:gh20481 PiperOrigin-RevId: 634045828 | 15 May 2024, 19:53:19 UTC |
bb40075 | jax authors | 15 May 2024, 19:48:37 UTC | [Mosaic GPU] Fix argument checking in case the input shape is not a tuple. PiperOrigin-RevId: 634044759 | 15 May 2024, 19:49:27 UTC |
09a4b38 | Dan Foreman-Mackey | 15 May 2024, 19:01:01 UTC | Add informative error for invalid unroll in scan As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an uninformative `ZeroDivisionError`. This PR adds a check which raises a `ValueError` for `unroll<=0`. | 15 May 2024, 19:40:27 UTC |
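The `unroll` check described above can be illustrated with a short sketch (a minimal example, not the PR's own test code): a positive `unroll` works as before, while `unroll=0` now raises a `ValueError` instead of the old `ZeroDivisionError`.

```python
import jax.numpy as jnp
from jax import lax

def cumsum_scan(xs, unroll=1):
    # A simple scan that carries a running sum and emits it at each step.
    return lax.scan(lambda c, x: (c + x, c + x), 0.0, xs, unroll=unroll)[1]

xs = jnp.arange(4.0)
print(cumsum_scan(xs, unroll=2))   # positive unroll: works as before
try:
    cumsum_scan(xs, unroll=0)      # now rejected with an informative error
except ValueError as e:
    print("ValueError:", e)
```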
4ccac4c | jax authors | 15 May 2024, 17:40:58 UTC | Merge pull request #21237 from gnecula:clean_tokens PiperOrigin-RevId: 634003364 | 15 May 2024, 17:40:58 UTC |
11a4513 | jax authors | 15 May 2024, 17:09:47 UTC | [Mosaic GPU] Explicitly check the smem size so that we get a good error rather than a cryptic cuda error PiperOrigin-RevId: 633993345 | 15 May 2024, 17:10:46 UTC |
874a20a | jax authors | 15 May 2024, 15:52:16 UTC | [Mosaic GPU] Fix np.random API mistake in tests PiperOrigin-RevId: 633970332 | 15 May 2024, 15:53:03 UTC |
7653db8 | Peter Hawkins | 15 May 2024, 14:37:24 UTC | Fix CI test failure in line_search_test. A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not. PiperOrigin-RevId: 633949879 | 15 May 2024, 14:38:20 UTC |
41153b1 | George Necula | 14 May 2024, 10:20:58 UTC | Cleanup token handling during lowering Version 0.4.27 of jaxlib is now the minimum version and it supports real stablehlo tokens as module inputs and outputs. Hence we can now clean up `mlir.lower_jaxpr_to_fun` to not use the kwargs `create_tokens` and `replace_tokens_with_dummy` (both of them are always False now). We also remove `num_output_tokens` that is not used. | 15 May 2024, 12:54:46 UTC |
66a92c4 | Yue Sheng | 15 May 2024, 05:40:53 UTC | Reverts 9e7830df2df9362edcf2e18e353d327fdecae678 PiperOrigin-RevId: 633816901 | 15 May 2024, 05:41:44 UTC |
b8cec53 | jax authors | 15 May 2024, 01:41:41 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a0f5d76e3dac2ee3293de60a79c29a46668a53b0. PiperOrigin-RevId: 633771201 | 15 May 2024, 01:42:27 UTC |
9af0a31 | jax authors | 14 May 2024, 23:05:17 UTC | Merge pull request #21222 from tempoxylophone:patch-1 PiperOrigin-RevId: 633734390 | 14 May 2024, 23:05:17 UTC |
48b9018 | jax authors | 14 May 2024, 23:04:56 UTC | Merge pull request #21221 from nouiz:doc PiperOrigin-RevId: 633734299 | 14 May 2024, 23:04:56 UTC |
1dff5b3 | jax authors | 14 May 2024, 23:01:50 UTC | Merge pull request #21226 from Micky774:linalg-depr PiperOrigin-RevId: 633734284 | 14 May 2024, 23:01:50 UTC |
d4f36e2 | jax authors | 14 May 2024, 22:48:53 UTC | Merge pull request #21227 from jakevdp:eye-offset PiperOrigin-RevId: 633730820 | 14 May 2024, 22:48:53 UTC |
e8b06cc | jax authors | 14 May 2024, 22:20:37 UTC | Cholesky rank-1 update kernel for JAX. PiperOrigin-RevId: 633722940 | 14 May 2024, 22:21:38 UTC |
e2918ca | Sergei Lebedev | 14 May 2024, 21:47:24 UTC | Added a very rough sketch of Mosaic GPU lowering for Pallas Almost nothing is supported, including * PyTree inputs/outputs * indexers * non-trivial grids * block specs * any primitives beyond the ones added here * etc etc PiperOrigin-RevId: 633713366 | 14 May 2024, 21:48:09 UTC |
3024c78 | Jake VanderPlas | 14 May 2024, 20:32:54 UTC | jnp.eye: allow k to be dynamic | 14 May 2024, 20:32:54 UTC |
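A minimal sketch of what "dynamic `k`" enables: the diagonal offset can now be a traced value, e.g. a `jit` argument, rather than a static Python int (the function name below is illustrative).

```python
import jax
import jax.numpy as jnp

@jax.jit
def shifted_identity(k):
    # `k` is a traced value here; jnp.eye accepts a dynamic offset.
    return jnp.eye(4, k=k)

print(shifted_identity(1))
```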
0ad5167 | Tomás Longeri | 14 May 2024, 19:54:02 UTC | Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support) PiperOrigin-RevId: 633677477 | 14 May 2024, 19:54:48 UTC |
5cc255b | Meekail Zain | 14 May 2024, 19:53:54 UTC | Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv | 14 May 2024, 19:53:54 UTC |
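After this rename, both functions take the relative tolerance under the single keyword `rtol`; a minimal usage sketch:

```python
import jax.numpy as jnp

a = jnp.eye(3)
# `rtol` replaces the old `tol` keyword in matrix_rank
# and the old `rcond` keyword in pinv.
rank = jnp.linalg.matrix_rank(a, rtol=1e-5)
a_pinv = jnp.linalg.pinv(a, rtol=1e-5)
```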
0501d3d | jax authors | 14 May 2024, 18:29:07 UTC | Merge pull request #21223 from jakevdp:kl-div-doc PiperOrigin-RevId: 633651462 | 14 May 2024, 18:29:07 UTC |
bb5787d | Jake VanderPlas | 14 May 2024, 17:39:43 UTC | Finalize deprecations of several APIs PiperOrigin-RevId: 633634215 | 14 May 2024, 17:40:40 UTC |
1d6ffde | Ashish Shenoy | 14 May 2024, 17:06:46 UTC | Reverts 85e91c2be4310d9728f7bfeefef921ee4a075135 PiperOrigin-RevId: 633622856 | 14 May 2024, 17:07:44 UTC |
c3c2393 | Jake VanderPlas | 14 May 2024, 16:25:04 UTC | kl_div: fix incorrect formula in doc | 14 May 2024, 16:25:04 UTC |
1e3a4b5 | Jake VanderPlas | 14 May 2024, 16:03:11 UTC | Remove now-empty jax/_src/third_party/numpy Followup to https://github.com/google/jax/pull/21119 PiperOrigin-RevId: 633603237 | 14 May 2024, 16:03:54 UTC |
4bcad15 | tempoxylophone | 14 May 2024, 15:52:41 UTC | Fix typo in advanced-autodiff.md: removed `doc.g` text from the first section. | 14 May 2024, 15:52:41 UTC |
f578d78 | Frederic Bastien | 14 May 2024, 15:35:01 UTC | Update doc with the other error that can be thrown. | 14 May 2024, 15:35:01 UTC |
aea6b32 | jax authors | 14 May 2024, 14:48:17 UTC | Merge pull request #21216 from jakevdp:pmap-prng PiperOrigin-RevId: 633582800 | 14 May 2024, 14:48:17 UTC |
95c4ba9 | jax authors | 14 May 2024, 10:45:05 UTC | [JAX] Use first process id instead of process 0 to share multi-host data. PiperOrigin-RevId: 633526220 | 14 May 2024, 10:45:41 UTC |
84774b3 | Sergey Kozub | 14 May 2024, 10:07:12 UTC | Fix sparse dot metadata loader Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs. PiperOrigin-RevId: 633513110 | 14 May 2024, 10:08:18 UTC |
5150cfe | Jake VanderPlas | 14 May 2024, 02:04:22 UTC | Fix PRNGKey handling under jit-of-pmap | 14 May 2024, 02:04:22 UTC |
e735a00 | jax authors | 14 May 2024, 01:57:48 UTC | Merge pull request #21215 from jakevdp:setops-docs PiperOrigin-RevId: 633400796 | 14 May 2024, 01:57:48 UTC |
98af786 | jax authors | 14 May 2024, 01:09:55 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/2365b5180382662ee10bd3643fce45a66535f407. PiperOrigin-RevId: 633391316 | 14 May 2024, 01:10:33 UTC |
a56eb56 | jax authors | 14 May 2024, 00:53:26 UTC | Merge pull request #21211 from elfiegg:main PiperOrigin-RevId: 633386931 | 14 May 2024, 00:53:26 UTC |
6157d8e | jax authors | 14 May 2024, 00:34:54 UTC | Merge pull request #21176 from jakevdp:shard-map-api-docs PiperOrigin-RevId: 633382922 | 14 May 2024, 00:34:54 UTC |
7ed7780 | Jake VanderPlas | 13 May 2024, 22:19:14 UTC | Improve docs for jax.numpy set-like operations | 13 May 2024, 22:19:14 UTC |
cd6e012 | Junwhan Ahn | 13 May 2024, 21:36:43 UTC | Enable JAX memory tests for GPUs and CPUs PjRt GPU and CPU have recently gotten memory space support with just one memory space per device, so we enable the relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only `ShardingMemoriesTest` is enabled for now. PiperOrigin-RevId: 633335638 | 13 May 2024, 21:37:37 UTC |
72e9eb9 | Jake VanderPlas | 13 May 2024, 20:04:15 UTC | shard_map: add API docs | 13 May 2024, 20:04:15 UTC |
72a81e5 | Peter Hawkins | 13 May 2024, 19:34:09 UTC | Readd a default lowering rule for cumsum et al. A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules PiperOrigin-RevId: 633297839 | 13 May 2024, 19:34:51 UTC |
1e48adc | Justin Fu | 13 May 2024, 18:49:51 UTC | [Pallas] Pad input/outputs in interpret mode to fix errors in OOB memory accesses. PiperOrigin-RevId: 633283991 | 13 May 2024, 18:50:21 UTC |
b8ed346 | jax authors | 13 May 2024, 18:43:24 UTC | Merge pull request #21119 from jakevdp:linalg-cond PiperOrigin-RevId: 633281675 | 13 May 2024, 18:43:24 UTC |
6189a55 | jax authors | 13 May 2024, 18:18:09 UTC | Merge pull request #21048 from jakevdp:np-squeeze-doc PiperOrigin-RevId: 633273019 | 13 May 2024, 18:18:09 UTC |
9e7830d | Yue Sheng | 13 May 2024, 17:52:31 UTC | Async dispatch expensive computations on the JAX CPU backend. By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior. PiperOrigin-RevId: 633264117 | 13 May 2024, 17:53:09 UTC |
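The opt-out mentioned in the commit is a one-line config change (a configuration fragment; the flag name is taken from the commit message itself):

```python
import jax

# Restore the old synchronous behavior on the CPU backend;
# the default after this change is async dispatch of expensive computations.
jax.config.update('jax_cpu_enable_async_dispatch', False)
```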
54d4072 | Yash Katariya | 13 May 2024, 17:48:04 UTC | Populate `propagated_out_mem_kinds` inside the branch where it's needed PiperOrigin-RevId: 633262630 | 13 May 2024, 17:48:52 UTC |
1f6d902 | Jake VanderPlas | 13 May 2024, 17:36:50 UTC | jnp.linalg.cond: improve implementation & docs | 13 May 2024, 17:36:50 UTC |
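For context on the API being improved, a minimal usage sketch: `jnp.linalg.cond` defaults to the 2-norm condition number, the ratio of the largest to smallest singular value.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 0.0],
               [0.0, 10.0]])
# Singular values are 10 and 1, so the 2-norm condition number is 10.
print(jnp.linalg.cond(a))
```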
85e91c2 | jax authors | 13 May 2024, 17:23:28 UTC | Merge pull request #21203 from gnecula:export_device_poly PiperOrigin-RevId: 633253709 | 13 May 2024, 17:23:28 UTC |
7ceae95 | Jake VanderPlas | 03 May 2024, 12:37:59 UTC | Better documentation for several jax.numpy functions | 13 May 2024, 17:09:52 UTC |
98aead7 | George Necula | 13 May 2024, 11:32:24 UTC | [export] Relax the check that exported modules are used with same number of devices as when exported Now we allow a module exported for 1 device and not using any sharding annotations to be called from a computation that uses multiple devices. Such exported modules can be parallelized trivially point-wise. | 13 May 2024, 17:09:43 UTC |
43d1916 | Elfie Guo | 13 May 2024, 16:50:52 UTC | Remove type promotion for mixed fp8 matmuls. | 13 May 2024, 16:50:52 UTC |
e4f3b3f | Justin Fu | 13 May 2024, 16:34:39 UTC | Merge pull request #21169 from justinjfu/splash_precision_fix Disable bfloat16 on long seq lengths for splash attention kernel test | 13 May 2024, 16:34:39 UTC |
35a512d | Jake VanderPlas | 13 May 2024, 16:18:04 UTC | CI: update NumPy build version to 2.0.0rc2 PiperOrigin-RevId: 633233231 | 13 May 2024, 16:18:40 UTC |
de14e3b | George Karpenkov | 13 May 2024, 15:34:53 UTC | Reverts 49bd4d6f01d6cda00f9b1bdfbda156636baae928 PiperOrigin-RevId: 633221195 | 13 May 2024, 15:35:40 UTC |
e66a234 | jax authors | 13 May 2024, 12:48:13 UTC | Merge pull request #21191 from gnecula:export_simplify PiperOrigin-RevId: 633179742 | 13 May 2024, 12:48:13 UTC |
54ca3d4 | jax authors | 13 May 2024, 12:30:57 UTC | Merge pull request #21202 from superbobry:pallas PiperOrigin-RevId: 633176367 | 13 May 2024, 12:30:57 UTC |
1fed784 | jax authors | 13 May 2024, 12:03:28 UTC | Merge pull request #20940 from piotrfilipiuk:changelist/623910451 PiperOrigin-RevId: 633170419 | 13 May 2024, 12:03:28 UTC |
78d4d0a | George Necula | 11 May 2024, 12:19:31 UTC | [export] Simplify export internals, prepare for integration with AOT APIs In preparation for a better integration of jax.experimental.export with the AOT APIs, we make several simplifications: * always turn on the generation of shape assertions in the presence of shape polymorphism. Previously, shape assertions were turned on unless the serialization version was less than 7 (possible only before March 27th, 2024, when the minimum serialization version was bumped to 9), or if the user explicitly specified that shape assertions should be turned off. It is not safe to turn off shape assertions, and I am not aware of an instance where somebody had to turn them off except for temporary debugging. We keep the `DisabledSafetyCheck.shape_assertions` API for now, for backwards compatibility, but it has no effect and emits a deprecation warning. * remove the code that was conditional on the serialization version being less than 9, e.g., for the lowering in the presence of effects. * remove a safety check ensuring that when `export` is used on JAX callables, i.e., not the result of `jax.jit`, the code does not contain non-replicated sharding annotations. This usage of `export` is rare and will be removed once `export` is integrated with the AOT APIs. * remove code that was needed only for older jaxlib to replace_tokens_with_dummy. | 13 May 2024, 11:41:51 UTC |
8094d0d | Sergei Lebedev | 13 May 2024, 11:23:18 UTC | Guarded Pallas GPU import in tests/pallas/pallas_test.py We do not build Triton IR bindings on Windows. This should fix https://github.com/google/jax/actions/runs/9051189315/job/24867428634. | 13 May 2024, 11:23:18 UTC |
1c6855a | Sergei Lebedev | 13 May 2024, 10:06:31 UTC | Ensured that all Pallas GPU tests depend on :pallas_gpu This dependency is added implicitly by Google-internal infra, but we need it to be explicit for Bazel builds to avoid ImportErrors at lowering time. PiperOrigin-RevId: 633147268 | 13 May 2024, 10:07:22 UTC |
ba8480a | Jieying Luo | 13 May 2024, 03:49:51 UTC | Register TPU profiler plugin when get_topology_desc is called with tpu platform. This allows the TPU profiler to work with other plugin backends. Tested on a GPU VM: $ pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html $ pip install -e . $ TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py [ RUN ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices (expected NOT_FOUND warnings: TPU accelerator type, worker number, and worker hostnames could not be determined; backends 'cuda', 'rocm', and 'tpu' fail to initialize and JAX falls back to cpu) [ OK ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices Ran 1 test in 2.549s OK In tests/cross_aot_test.py: class JaxAotTest(jtu.JaxTestCase): def test_tpu_profiler_registered_get_topology_from_devices(self): try: _ = topologies.get_topology_desc( topology_name='fake_topology', platform='tpu', ) except xla_extension.XlaRuntimeError: logging.info('Expected to fail to get topology') with tempfile.TemporaryDirectory() as tmpdir: try: jax.profiler.start_trace(tmpdir) jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( jnp.ones(jax.local_device_count()) ) finally: jax.profiler.stop_trace() proto_path = glob.glob( os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True ) self.assertLen(proto_path, 1) with open(proto_path[0], 'rb') as f: proto = f.read() # Sanity check that serialized proto contains host and Python traces without deserializing. self.assertIn(b'/host:metadata', proto) if jtu.test_device_matches(['tpu']): self.assertNotIn(b'/device:TPU', proto) self.assertIn(b'pxla.py', proto) PiperOrigin-RevId: 633076007 | 13 May 2024, 03:50:41 UTC |
af4bddb | jax authors | 13 May 2024, 01:20:48 UTC | [Mosaic GPU] Correct TileTransform.transform_index() Previously TileTransform.transform_index() would transform like so: x, y, z -> x, y // t_x, z // t_y, 0, 0 But it actually should be: x, y, z -> x, y // t_x, z // t_y, y % t_x, z % t_y PiperOrigin-RevId: 633052067 | 13 May 2024, 01:23:25 UTC |
dac5b75 | jax authors | 13 May 2024, 01:20:24 UTC | [jax:mosaic-gpu] Test type conversion for tiled fragment array PiperOrigin-RevId: 633052027 | 13 May 2024, 01:23:07 UTC |
c1f184d | jax authors | 13 May 2024, 01:19:20 UTC | [Mosaic GPU] Cleanup matmul example mainly by removing the m/n block loops (1 step). PiperOrigin-RevId: 633051910 | 13 May 2024, 01:20:06 UTC |
3e3d916 | jax authors | 13 May 2024, 01:01:44 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/8dd9fc35394a2b98d249f872242a100d7d56286e. PiperOrigin-RevId: 633049132 | 13 May 2024, 01:02:26 UTC |
b4f2145 | jax authors | 12 May 2024, 00:59:51 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/d3e881ad668b7aa44283d47bb553a04b86b71315. PiperOrigin-RevId: 632860505 | 12 May 2024, 01:00:43 UTC |
a527b71 | Adam Paszke | 12 May 2024, 00:08:24 UTC | [Mosaic GPU] Prepare for writing warp-specialized kernels PiperOrigin-RevId: 632854287 | 12 May 2024, 00:09:08 UTC |
49bd4d6 | Peter Hawkins | 11 May 2024, 19:23:23 UTC | Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a PiperOrigin-RevId: 632818524 | 11 May 2024, 19:24:06 UTC |
93dfe05 | piotrfilipiuk | 11 May 2024, 13:40:18 UTC | Implements Ragged Dot API | 11 May 2024, 13:40:18 UTC |
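A sketch of how the ragged dot API is shaped (assuming the `jax.lax.ragged_dot` entry point this change introduces): the rows of `lhs` are partitioned into contiguous groups, and each group is contracted against its own `rhs` matrix.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((6, 4))                       # (m, k)
rhs = jnp.ones((2, 4, 5))                    # (g, k, n): one matrix per group
group_sizes = jnp.array([2, 4], jnp.int32)   # group sizes sum to m
out = lax.ragged_dot(lhs, rhs, group_sizes)  # (m, n)
```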
3b03e54 | Yue Sheng | 11 May 2024, 04:07:18 UTC | Raise a runtime error when trying to convert the `jax.Array` wrapped by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape. PiperOrigin-RevId: 632682906 | 11 May 2024, 04:08:06 UTC |
20646eb | jax authors | 11 May 2024, 01:18:33 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/0b3dc68410d57f6cd3d0a484f46946d6d12f03a8. PiperOrigin-RevId: 632655986 | 11 May 2024, 01:19:21 UTC |
9ac1d38 | Jake VanderPlas | 11 May 2024, 01:06:08 UTC | Finish jax and jaxlib 0.4.28 release PiperOrigin-RevId: 632653310 | 11 May 2024, 01:06:52 UTC |
979d9ca | jax authors | 11 May 2024, 00:44:15 UTC | Merge pull request #21168 from 8bitmp3:upgrade-sharded--doc PiperOrigin-RevId: 632648408 | 11 May 2024, 00:44:15 UTC |
a4693db | Yash Katariya | 10 May 2024, 22:34:03 UTC | Add a jaxpr interpreter for propagating memory kinds to output. It only triggers if we detect multiple memory kinds in the jaxpr. This hopefully should go away when XLA implements its own memory space propagation pass or JAX adds memory_kind to the type system of jaxpr, i.e. on avals. It's required to treat the following code blocks (1) and (2) as equivalent when lowering to stablehlo. In general shardings should also be treated the same way, but we'll cross that bridge later. 1. `jit(f, out_shardings=s_host)` 2. ``` @jax.jit def f(x): return jax.device_put(x, s_host) ``` PiperOrigin-RevId: 632621025 | 10 May 2024, 22:34:57 UTC |
27c932a | Sergei Lebedev | 10 May 2024, 21:24:29 UTC | Do not import from lowering in tests/pallas/pallas_test.py This ensures that the test is importable even with a non-GPU jaxlib, which does not have Triton dialect bindings. PiperOrigin-RevId: 632603225 | 10 May 2024, 21:25:17 UTC |
9ea3fcb | 8bitmp3 | 10 May 2024, 15:52:29 UTC | Upgrade JAX Parallelism Sharded Computation 101 doc | 10 May 2024, 21:24:16 UTC |
17444fc | jax authors | 10 May 2024, 20:35:04 UTC | Merge pull request #21174 from hawkinsp:spmm PiperOrigin-RevId: 632589433 | 10 May 2024, 20:35:04 UTC |
dda428e | Peter Hawkins | 10 May 2024, 19:39:01 UTC | Disable tests that trigger warning if x64 mode isn't enabled. | 10 May 2024, 19:58:22 UTC |
c3cab2e | jax authors | 10 May 2024, 19:55:17 UTC | Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3 PiperOrigin-RevId: 632579101 | 10 May 2024, 19:56:10 UTC |
c231cd5 | jax authors | 10 May 2024, 19:50:07 UTC | Merge pull request #21173 from hawkinsp:precision PiperOrigin-RevId: 632577567 | 10 May 2024, 19:50:07 UTC |
24b4731 | Peter Hawkins | 10 May 2024, 17:47:43 UTC | Force float32 matmuls in examples_test. This test started failing when we changed our CI to use L4 GPUs. Using highest precision resolves the problem. | 10 May 2024, 19:30:02 UTC |
0a3e432 | Jieying Luo | 10 May 2024, 18:29:49 UTC | [PJRT C API] Enable PJRT C API runtime in jax2tf dlpack. GetDefaultLayout added a fallback for GPU backend so it is no longer blocked by the fact that PJRT C API does not support GetDefaultLayout yet. PiperOrigin-RevId: 632555239 | 10 May 2024, 18:30:37 UTC |
6c42533 | Peter Hawkins | 10 May 2024, 18:05:47 UTC | Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306 PiperOrigin-RevId: 632548226 | 10 May 2024, 18:06:38 UTC |
586568f | George Karpenkov | 10 May 2024, 18:02:47 UTC | Simplify JAX lowering rules for cumulative sum Rely on XLA decomposition. # JAX GPU microbenchmarks: 285us for cumsum over 1e8 elements; 449us for cumsum over 1e8 elements. # JAX CPU microbenchmarks: 1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements PiperOrigin-RevId: 632547166 | 10 May 2024, 18:03:28 UTC |
13a1955 | jax authors | 10 May 2024, 17:35:32 UTC | Merge pull request #21167 from jakevdp:einsum-path-func PiperOrigin-RevId: 632538144 | 10 May 2024, 17:35:32 UTC |
ebb9184 | Justin Fu | 10 May 2024, 17:31:29 UTC | Disable bfloat16 on long seq lengths for splash attention kernel test | 10 May 2024, 17:31:29 UTC |
bac3a6f | Yash Katariya | 10 May 2024, 17:11:55 UTC | Allow tokens being passed to `jit` and through dispatch and being returned from the jitted function. Fixes https://github.com/google/jax/issues/21160 PiperOrigin-RevId: 632531105 | 10 May 2024, 17:12:48 UTC |
0267ed0 | jax authors | 10 May 2024, 15:47:22 UTC | Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib. The runfiles of the original targets were lost when the symlinked files were used. This change is needed for the future Hermetic CUDA implementation. Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved. PiperOrigin-RevId: 632508121 | 10 May 2024, 15:48:12 UTC |
d07951c | Jake VanderPlas | 10 May 2024, 15:39:32 UTC | jnp.einsum_path: improve docs & annotations | 10 May 2024, 15:39:32 UTC |
c2d78ab | jax authors | 10 May 2024, 12:58:30 UTC | Merge pull request #21148 from jakevdp:einsum-path PiperOrigin-RevId: 632470123 | 10 May 2024, 12:58:30 UTC |
2a01541 | jax authors | 10 May 2024, 12:55:38 UTC | Merge pull request #21157 from djm3622:patch-1 PiperOrigin-RevId: 632469962 | 10 May 2024, 12:55:38 UTC |
c3d3db9 | Jake VanderPlas | 10 May 2024, 02:50:06 UTC | jnp.einsum: support optimize=False, and improve docs for this keyword. | 10 May 2024, 02:50:06 UTC |
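A minimal sketch of the newly supported keyword value: with `optimize=False`, `jnp.einsum` skips the contraction-path search and contracts operands in the order given.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
# optimize=False: no opt_einsum path optimization, plain left-to-right contraction.
out = jnp.einsum('ij,jk->ik', a, b, optimize=False)
```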
aa48a55 | djm3622 | 10 May 2024, 02:44:27 UTC | Fix error on line 127 of debugging.md: in "y, z = jnp.sin(x), jnp.cos(x)", the call jnp.cos(x) was mistakenly nested inside jnp.sin(x), i.e. jnp.sin(x, jnp.cos(x)), which caused an error to be thrown. This change fixes that. | 10 May 2024, 02:44:27 UTC |
f21e3e8 | jax authors | 10 May 2024, 01:36:35 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4. PiperOrigin-RevId: 632333365 | 10 May 2024, 01:37:25 UTC |
6f79093 | Jackson Stokes | 10 May 2024, 00:33:20 UTC | [XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottom-up approach. Previously, output streaming took a top-down approach which indiscriminately checks whether a MoveToHost custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, while also improving efficiency with the bottom-up approach so we only trace a single path in the graph. PiperOrigin-RevId: 632318740 | 10 May 2024, 00:34:21 UTC |
a9460f2 | jax authors | 09 May 2024, 21:51:56 UTC | Merge pull request #21130 from Micky774:reshape PiperOrigin-RevId: 632277406 | 09 May 2024, 21:51:56 UTC |
e8ac011 | Parker Schuh | 09 May 2024, 21:44:37 UTC | Allow nested shard_map. PiperOrigin-RevId: 632275515 | 09 May 2024, 21:45:40 UTC |
79005c1 | Meekail Zain | 09 May 2024, 21:02:07 UTC | Deprecate newshape argument of jnp.reshape | 09 May 2024, 21:02:07 UTC |
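A minimal sketch of the migration this deprecation implies: pass the target shape positionally (or via `shape=`) instead of the `newshape` keyword.

```python
import jax.numpy as jnp

x = jnp.arange(6)
y = jnp.reshape(x, (2, 3))          # preferred: shape passed positionally
# jnp.reshape(x, newshape=(2, 3))   # deprecated as of this change
```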
1cb6971 | Justin Fu | 09 May 2024, 20:27:56 UTC | Merge pull request #21107 from justinjfu/pallas_splash_test_fix Fix failing smem_hbm_dma test | 09 May 2024, 20:27:56 UTC |
7245714 | Justin Fu | 07 May 2024, 16:52:02 UTC | Add zeros initialization to failing smem-hbm copy test. | 09 May 2024, 20:19:21 UTC |