swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Message Date
2873271 Use get_api_version to choose gather/scatter Instead of using try/except - it's cleaner. PiperOrigin-RevId: 638835261 11 June 2024, 18:03:09 UTC
8199267 Merge pull request #21762 from rajasekharporeddy:testbranch2 PiperOrigin-RevId: 642310572 11 June 2024, 17:19:00 UTC
c5761b7 Merge pull request #21802 from superbobry:build PiperOrigin-RevId: 642301907 11 June 2024, 16:57:04 UTC
5b38549 [XLA:Mosaic] No need to assume a multiple of tile if tile dim size is 1. PiperOrigin-RevId: 642301822 11 June 2024, 16:53:13 UTC
27140fe Merge pull request #21772 from jakevdp:beta-dep PiperOrigin-RevId: 642275316 11 June 2024, 15:15:58 UTC
c6666e2 Merge pull request #21788 from jakevdp:top-k-doc PiperOrigin-RevId: 642272680 11 June 2024, 15:06:47 UTC
1256ceb [Mosaic GPU] Rearrange the pass pipeline (again) PiperOrigin-RevId: 642256145 11 June 2024, 13:59:50 UTC
3345952 Merge pull request #21803 from gnecula:export_typing PiperOrigin-RevId: 642251107 11 June 2024, 13:38:11 UTC
2aa2a39 Update code examples 11 June 2024, 12:17:36 UTC
1c06114 Merge pull request #21729 from superbobry:pallas PiperOrigin-RevId: 642225089 11 June 2024, 11:54:09 UTC
f147dd2 Merge pull request #21800 from superbobry:typing PiperOrigin-RevId: 642224964 11 June 2024, 11:50:09 UTC
e3faf85 [export] Cleaned up types of [in|out]_shardings Previously we declared Exported.in_shardings to be a sequence of `core.AbstractValue`, but in reality we only support `core.ShapedArray`. We change the type declaration and this allowed us to clean up some `# type: ignore` 11 June 2024, 11:46:44 UTC
e8f20ad Removed unused ``cuda_options`` from ``lower_jaxpr_to_triton_module`` I also re-enabled mypy in triton/pallas_call_registration.py as a drive by change. 11 June 2024, 11:27:18 UTC
70f6ab3 Updated the type annotations of *_spec= parameters of pl.pallas_call The previous type did not work for nested pytrees and for some reason neither pytype nor mypy flagged that. I also re-enabled type checking for most pallas/*.py files. 11 June 2024, 11:22:00 UTC
3b1b5fd Added filelock to test-requirements.txt and requirements lock files This is a follow up to #21741. 11 June 2024, 10:53:10 UTC
f847350 Removed kernel_regeneration_util from Mosaic It was only used for persisting kernel metadata, and that can be done via jax.named_scope instead. PiperOrigin-RevId: 642195336 11 June 2024, 09:36:41 UTC
11370b7 Merge pull request #21782 from jakevdp:rel-entr PiperOrigin-RevId: 642094313 11 June 2024, 01:51:22 UTC
02b5d47 Swap operands of dot if the LHS is fed by a parameter PiperOrigin-RevId: 642090766 11 June 2024, 01:33:05 UTC
9439f63 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap). PiperOrigin-RevId: 642085019 11 June 2024, 01:08:19 UTC
3d4ee0d Merge pull request #21791 from jakevdp:remove-deprecated PiperOrigin-RevId: 642068297 10 June 2024, 23:58:39 UTC
266028f Remove unused variable 10 June 2024, 23:30:44 UTC
956226c Raise an error if device_put sees an invalid value. PiperOrigin-RevId: 642053543 10 June 2024, 23:07:44 UTC
71c19b7 Rewrite `vector.contraction` with bf16 accumulator and output into a contraction with f32 accumulator and output, where the accumulator is extended and the output truncated. For targets that do not support bf16 matmul, the lhs and rhs are extended to f32. PiperOrigin-RevId: 642051952 10 June 2024, 23:02:46 UTC
9d9dd36 Adds test_compute_no_inputs_host_replicated in memories_test.py PiperOrigin-RevId: 642033992 10 June 2024, 22:02:34 UTC
bb24a92 Update XLA dependency to use revision http://github.com/openxla/xla/commit/af7fe24506bc1d988aa326a2db383a13071145d6. PiperOrigin-RevId: 642026581 10 June 2024, 21:38:29 UTC
6b8e2f3 DOC: jax.lax.top_k: fix docstring rendering & add example 10 June 2024, 20:57:21 UTC
af00430 Merge pull request #21516 from nouiz:paralell_computation PiperOrigin-RevId: 642004618 10 June 2024, 20:29:10 UTC
27de854 Merge pull request #21781 from hawkinsp:release PiperOrigin-RevId: 641994356 10 June 2024, 19:56:31 UTC
489febe Enable input fusion for a specific kernel pattern. cl/640530524 introduces batching support for some pallas calls that don't currently support it yet using dynamic slicing the input and dynamically updating the output. This CL ensures that XLA-guided input fusion into pallas kernel is working as expected for such pattern. We don't have support for fusion on the output side yet for pallas kernels. PiperOrigin-RevId: 641989012 10 June 2024, 19:37:49 UTC
f4dfa84 Merge pull request #21774 from jakevdp:tree-all-is-leaf PiperOrigin-RevId: 641978173 10 June 2024, 19:01:05 UTC
53daa0c [XLA:Mosaic] Fix infer layout for nested loop. - We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block. - Refactor the logic of assume layouts for block arguments to a helper function. - Add tests for nested fori loop and while loop. PiperOrigin-RevId: 641973011 10 June 2024, 18:49:01 UTC
f6ce973 Merge pull request #21745 from pkgoogle:better_right_shift_doc PiperOrigin-RevId: 641972495 10 June 2024, 18:45:38 UTC
a073476 chore: adopt new local wheel installation logic PiperOrigin-RevId: 641972325 10 June 2024, 18:41:52 UTC
6fa31e5 Update version numbers after v0.4.29 release. 10 June 2024, 18:37:53 UTC
afe088f Simplify definition of jax.scipy.special.kl_div 10 June 2024, 18:36:35 UTC
3fe7377 Merge pull request #21763 from gnecula:export_api PiperOrigin-RevId: 641959833 10 June 2024, 18:05:34 UTC
b33aca6 [export] Create the jax.export module APIs. The functionality comes from the jax.experimental.export module, which will be deprecated. The following APIs are introduced: ``` from jax import export def f(...): ... ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs) blob: bytearray = ex.serialize() rehydrated: export.Export = export.deserialize(blob) def caller(...): ... rehydrated.call(*args, **kwargs) ``` Module documentation will follow shortly. There are no changes for now in the jax.experimental.export APIs. Most of the changes in this PR are in tests due to some differences in the new jax.export APIs compared to jax.experimental.export: * Instead of `jax.experimental.export.call(exp)` we now write `exp.call` * The `jax.experimental.export.export` allowed the function argument to be any Python callable and it would wrap it with a `jax.jit`. This is not supported anymore by export, and instead the user must use `jax.jit`. 10 June 2024, 17:31:51 UTC
07d90e5 adding doc string to right_shift updated punctuation and phrasing fix example comment/code ordering reformatting of description and adding print_binary helper moving helper function to ufuncs.py moved print_binary definition to doc string fix in doc print_binary def and other edits 10 June 2024, 16:59:42 UTC
814b32a tree_all: add support for is_leaf 10 June 2024, 16:46:15 UTC
833c7ba Allow generation of sharding strategies with mixed mesh shapes by default. PiperOrigin-RevId: 641930205 10 June 2024, 16:38:39 UTC
990b475 jax.scipy.special.beta: deprecate x,y in favor of a,b 10 June 2024, 16:01:39 UTC
0739d52 [Mosaic GPU] Don't always run with llvm::DebugFlag enabled This slipped past during code review. PiperOrigin-RevId: 641899993 10 June 2024, 14:50:26 UTC
cd93b46 Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions. PiperOrigin-RevId: 641879836 10 June 2024, 13:21:16 UTC
991797a Merge pull request #21765 from hawkinsp:release PiperOrigin-RevId: 641876244 10 June 2024, 13:03:58 UTC
5e7ad60 Removed the double re-exporting of Pallas GPU/TPU APIs jax.experimental.pallas.{gpu,tpu} now import directly from the relevant jax._src.pallas.{triton,mosaic} submodules. PiperOrigin-RevId: 641875127 10 June 2024, 12:59:09 UTC
3b4039c [Mosaic GPU] Load LLVM lowering interfaces for all dialects Apparently we were missing interface registration code for LLVM lowering, which the gpu-to-llvm pass gracefully ignores unless compiled with debug assertions enabled. But, simply adding the assertions in fact makes the pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not what we want. That's why I've replaced it with a minimal version that is only responsible for handling the GPU dialect, making the lowering similar to the one prior to extra registrations. PiperOrigin-RevId: 641874183 10 June 2024, 12:55:01 UTC
e071053 Prepare for 0.4.29 release. 10 June 2024, 12:54:28 UTC
2ade7e7 [pallas] Move the hardware_generation query in the code path that needs it This change allows us to lower and export Pallas calls even on machines that do not have TPUs, in many cases. PiperOrigin-RevId: 641841079 10 June 2024, 10:13:36 UTC
af95803 Merge pull request #21759 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 641831969 10 June 2024, 09:29:12 UTC
479145b Add example code snippets to jax.scipy.signal.convolve2d, correlate and correlate2d docs 10 June 2024, 08:57:41 UTC
775c6f8 Fix Typos in docs and one error message 10 June 2024, 06:08:01 UTC
8fbe65b Update XLA dependency to use revision http://github.com/openxla/xla/commit/32ba408c0e2b7ef5f5821c0781601ba17d467076. PiperOrigin-RevId: 641736314 09 June 2024, 22:36:14 UTC
6617a0d Expand `device_put` benchmarks to run with different numbers of arrays and input types For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays. PiperOrigin-RevId: 641716268 09 June 2024, 20:01:51 UTC
a8246ea Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs. For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case. In a future release of JAX, this behavior will become an error. PiperOrigin-RevId: 641690427 09 June 2024, 16:18:29 UTC
14d87d3 [export] Move the export implementation to jax._src.export. This is part of the work to move the export APIs out of jax.experimental. For now, the way to use this implementation is still through `jax.experimental.export`. Had to add a few "#type ignore" to the _export.py because previously the file was exempt from internal pytype. Will try to fix these in a later PR. PiperOrigin-RevId: 641688200 09 June 2024, 15:59:50 UTC
aaa559a Update XLA dependency to use revision http://github.com/openxla/xla/commit/7667c639184a726f32eeeb856bab652f83215e4a. PiperOrigin-RevId: 641568627 08 June 2024, 22:45:36 UTC
b486a95 Merge pull request #21507 from renecotyfanboy:main PiperOrigin-RevId: 641429523 08 June 2024, 03:28:23 UTC
6c822c0 Update XLA dependency to use revision http://github.com/openxla/xla/commit/3195fdc8511c42cf975f07b0002cbd30fc361bde. PiperOrigin-RevId: 641387498 07 June 2024, 23:19:00 UTC
d324040 Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses". PiperOrigin-RevId: 641381432 07 June 2024, 22:52:35 UTC
751d59c increase default precision for hyp1f1 07 June 2024, 22:38:51 UTC
57826d8 Add a no input memories_test and enable memories test on vf 2x2 PiperOrigin-RevId: 641361865 07 June 2024, 21:40:44 UTC
0d047a1 Merge pull request #21718 from jakevdp:pallas-config PiperOrigin-RevId: 641349981 07 June 2024, 20:58:49 UTC
44a13c9 Merge code between `make_jaxpr` and `jit(f).trace`. The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't. Since we can keep the existing behavior and still merge the implementation is a good cleanup! Fixes https://github.com/google/jax/issues/21116 PiperOrigin-RevId: 641347140 07 June 2024, 20:48:31 UTC
25cc84b Merge pull request #21615 from selamw1:append_doc PiperOrigin-RevId: 641344856 07 June 2024, 20:39:57 UTC
dfc6076 Merge pull request #21744 from superbobry:typing PiperOrigin-RevId: 641339815 07 June 2024, 20:23:31 UTC
136289e Added filelock to py_deps This should unblock #21394, which uses filelock in the compilation cache. PiperOrigin-RevId: 641338150 07 June 2024, 20:16:33 UTC
7d913f7 Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix PiperOrigin-RevId: 641325047 07 June 2024, 19:29:31 UTC
0786da8 Removed unnecessary mypy exclusions from pyproject.toml * 2/3 files type check just fine now * the remaining one could be handled via a file-level directive 07 June 2024, 19:07:42 UTC
f4c6437 Merge pull request #21680 from ROCm:ci_spmm PiperOrigin-RevId: 641316410 07 June 2024, 18:57:12 UTC
af90464 Merge pull request #21733 from dfm:ffi-capsule-docstring PiperOrigin-RevId: 641307843 07 June 2024, 18:27:41 UTC
bd499a9 Merge pull request #21690 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 641292860 07 June 2024, 17:38:07 UTC
98d7235 Merge pull request #21501 from jakevdp:softmax-inf-doc PiperOrigin-RevId: 641291919 07 June 2024, 17:34:40 UTC
1459ac0 Merge pull request #21731 from tttc3:cross-product-typo PiperOrigin-RevId: 641285460 07 June 2024, 17:18:35 UTC
2899c9f Merge pull request #21692 from rajasekharporeddy:testbranch2 PiperOrigin-RevId: 641285369 07 June 2024, 17:15:22 UTC
30feb35 Merge pull request #21656 from yamlyeti:yamlyeti-patch-1 PiperOrigin-RevId: 641284969 07 June 2024, 17:12:02 UTC
1fa6659 Edit `pycapsule` docstring to provide a little bit more context The docstring for the recently added `pycapsule` function in `jax.extend.ffi` didn't conform to our usual docstring format, so I updated it and added a little bit more context. 07 June 2024, 17:07:03 UTC
5fcd50b Refactor kernel function assignment PiperOrigin-RevId: 641255192 07 June 2024, 15:20:31 UTC
f51af87 fp8 matmul in pallas PiperOrigin-RevId: 641254832 07 June 2024, 15:17:06 UTC
da8a7b2 Add in the tutorial the idea to test 1 process per node and 1 process per GPU. 07 June 2024, 14:00:04 UTC
3914cb4 [export] Remove old deprecated APIs for jax.experimental.export. See CHANGELOG.md. The deprecation period has passed. Also replace deprecated .call_exported with .call in tests. PiperOrigin-RevId: 641236222 07 June 2024, 13:52:10 UTC
21f71c6 fix typo in `jax.numpy.linalg.cross` docstring 07 June 2024, 12:43:51 UTC
5d6413c Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py PiperOrigin-RevId: 641149152 07 June 2024, 07:01:30 UTC
c01c984 Add missing arguments for jnp.extract's python binding signature. PiperOrigin-RevId: 641121305 07 June 2024, 04:34:38 UTC
6d94ae3 Improve docs for jnp.angle and jnp.flip 07 June 2024, 04:33:07 UTC
6d85c38 Improve documentation for jnp.fliplr and jnp.flipud 07 June 2024, 04:28:02 UTC
625ea07 Merge pull request #21710 from jakevdp:fix-jax2tf PiperOrigin-RevId: 641112498 07 June 2024, 03:45:57 UTC
ea6dfd1 rename `Specialized` to `Traced` (and `specialize` to `trace`) PiperOrigin-RevId: 641076488 07 June 2024, 00:43:08 UTC
dd40d88 Update XLA dependency to use revision http://github.com/openxla/xla/commit/9449b0851c855ae3f17e8675d95d1f1b68104abd. PiperOrigin-RevId: 641069331 07 June 2024, 00:12:57 UTC
a2c31f4 pallas/mosaic test: avoid leaking global config state 06 June 2024, 23:00:02 UTC
a1b5860 Merge pull request #21711 from jakevdp:setup-module PiperOrigin-RevId: 641049524 06 June 2024, 22:59:07 UTC
a861c55 test cleanup: use ExitStack to reduce test boilerplate 06 June 2024, 21:18:27 UTC
d457f9a Merge pull request #21716 from gnecula:exp_rename_sharding PiperOrigin-RevId: 641017765 06 June 2024, 21:17:10 UTC
01ee768 [export] Rename in_shardings and out_shardings fields. We rename `in_shardings` to `in_shardings_hlo` to remove confusion with JAX's use of `in_shardings`. We also rename `xla_compatible_in_sharding` to `in_shardings_jax` since we do not have a XLACompatibleSharding type anymore. 06 June 2024, 21:00:16 UTC
aee62e4 Implement `lower` in terms of `specialize` PiperOrigin-RevId: 641005643 06 June 2024, 20:39:07 UTC
90c83bb Merge pull request #21484 from dfm:custom-call-lowering PiperOrigin-RevId: 640996459 06 June 2024, 20:10:28 UTC
2c246df Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4 PiperOrigin-RevId: 640987745 06 June 2024, 19:41:17 UTC
fbf2a62 Remove `jaxpr` and `name` from `Lowered` because `specialize` already has those. This keeps the abstraction boundary clear. Adapt `export` to use `specialize`. PiperOrigin-RevId: 640968129 06 June 2024, 18:38:56 UTC
a65d3ae [Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops - Support non-zero minor offsets without having to relayout (they're still a no-op). - Remove restriction on tiling which now allows 1D packed types to work. PiperOrigin-RevId: 640967375 06 June 2024, 18:35:19 UTC
48355cd jax2tf_test: ensure no modification of global config 06 June 2024, 18:27:33 UTC
82516c5 Merge pull request #21694 from rajasekharporeddy:doc_typos PiperOrigin-RevId: 640956334 06 June 2024, 18:05:37 UTC