# https://github.com/google/jax
# Raw File
# Tip revision: 838bc454895ed2086563301936fb0d6d852fd198 authored by jax authors on 25 January 2023, 01:48:19 UTC
# Merge pull request #14148 from skye:version
# Tip revision: 838bc45
# core.py
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from __future__ import annotations

from jax._src.core import (
  AbstractToken as AbstractToken,
  AbstractValue as AbstractValue,
  Atom as Atom,
  AvalMapHandlerPair as AvalMapHandlerPair,
  AxisEnvFrame as AxisEnvFrame,
  AxisName as AxisName,
  AxisPrimitive as AxisPrimitive,
  AxisSize as AxisSize,
  AxisSubst as AxisSubst,
  Bot as Bot,
  CallPrimitive as CallPrimitive,
  ClosedCallPrimitive as ClosedCallPrimitive,
  ClosedJaxpr as ClosedJaxpr,
  ConcreteArray as ConcreteArray,
  ConcretizationTypeError as ConcretizationTypeError,
  CustomPpEqnRule as CustomPpEqnRule,
  DArray as DArray,
  DArrayDimHandler as DArrayDimHandler,
  DBIdx as DBIdx,
  DConcreteArray as DConcreteArray,
  DShapedArray as DShapedArray,
  DimSize as DimSize,
  DimensionHandler as DimensionHandler,
  DropVar as DropVar,
  DuplicateAxisNameError as DuplicateAxisNameError,
  Effect as Effect,
  Effects as Effects,
  EvalTrace as EvalTrace,
  FLAGS as FLAGS,
  HashableFunction as HashableFunction,
  HashableWrapper as HashableWrapper,
  InDBIdx as InDBIdx,
  InconclusiveDimensionOperation as InconclusiveDimensionOperation,
  InputType as InputType,
  Jaxpr as Jaxpr,
  JaxprEqn as JaxprEqn,
  JaxprPpContext as JaxprPpContext,
  JaxprPpSettings as JaxprPpSettings,
  JaxprTypeError as JaxprTypeError,
  Literal as Literal,
  MainTrace as MainTrace,
  MapPrimitive as MapPrimitive,
  NameGatheringSubst as NameGatheringSubst,
  NamedShape as NamedShape,
  OpaqueDType as OpaqueDType,
  OutDBIdx as OutDBIdx,
  OutputType as OutputType,
  ParamDict as ParamDict,
  Primitive as Primitive,
  Shape as Shape,
  ShapedArray as ShapedArray,
  Sublevel as Sublevel,
  TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
  ThreadLocalState as ThreadLocalState,
  Token as Token,
  Trace as Trace,
  TraceStack as TraceStack,
  TraceState as TraceState,
  Tracer as Tracer,
  TracerArrayConversionError as TracerArrayConversionError,
  TracerIntegerConversionError as TracerIntegerConversionError,
  UnexpectedTracerError as UnexpectedTracerError,
  UnshapedArray as UnshapedArray,
  Value as Value,
  Var as Var,
  _SPECIAL_DIMENSION_HANDLERS as _SPECIAL_DIMENSION_HANDLERS,
  _TempAxisName as _TempAxisName,
  _canonicalize_dimension as _canonicalize_dimension,
  _check_call as _check_call,
  _check_closed_call as _check_closed_call,
  _check_jaxpr as _check_jaxpr,
  _check_map as _check_map,
  _compact_eqn_should_include as _compact_eqn_should_include,
  _dim_handler_and_canonical as _dim_handler_and_canonical,
  _dimension_handler_int as _dimension_handler_int,
  _dtype_object as _dtype_object,
  _effect_free_abstract_eval as _effect_free_abstract_eval,
  _encode_digits_alphabetic as _encode_digits_alphabetic,
  _forward_to_value as _forward_to_value,
  _get_special_dim_handler as _get_special_dim_handler,
  _initialize_jax_jit_thread_local_state as _initialize_jax_jit_thread_local_state,
  _invalid_shape_error as _invalid_shape_error,
  _jaxpr_type_to_callable_annotation as _jaxpr_type_to_callable_annotation,
  _jaxpr_vars as _jaxpr_vars,
  _map_dshaped_array as _map_dshaped_array,
  _map_shaped_array as _map_shaped_array,
  _param_uses_outfeed as _param_uses_outfeed,
  _short_dtype_name as _short_dtype_name,
  _unmap_dshaped_array as _unmap_dshaped_array,
  _unmap_shaped_array as _unmap_shaped_array,
  _update_thread_local_jit_state as _update_thread_local_jit_state,
  _why_alive as _why_alive,
  _why_alive_container_info as _why_alive_container_info,
  abstract_token as abstract_token,
  annotations as annotations,
  apply_todos as apply_todos,
  as_hashable_function as as_hashable_function,
  as_named_shape as as_named_shape,
  attrgetter as attrgetter,
  aval_mapping_handlers as aval_mapping_handlers,
  aval_method as aval_method,
  aval_property as aval_property,
  axis_frame as axis_frame,
  axis_substitution_rules as axis_substitution_rules,
  bint as bint,
  bot as bot,
  call as call,
  call_bind_with_continuation as call_bind_with_continuation,
  call_impl as call_impl,
  call_p as call_p,
  canonicalize_dim as canonicalize_dim,
  canonicalize_shape as canonicalize_shape,
  check_eqn as check_eqn,
  check_jaxpr as check_jaxpr,
  check_type as check_type,
  check_valid_jaxtype as check_valid_jaxtype,
  closed_call_p as closed_call_p,
  collections as collections,
  concrete_aval as concrete_aval,
  concrete_or_error as concrete_or_error,
  concretization_function_error as concretization_function_error,
  config as config,
  contextmanager as contextmanager,
  cur_sublevel as cur_sublevel,
  curry as curry,
  custom_typechecks as custom_typechecks,
  dataclass as dataclass,
  dedup_referents as dedup_referents,
  diff_dim as diff_dim,
  diff_shape as diff_shape,
  dilate_dim as dilate_dim,
  dilate_shape as dilate_shape,
  dimension_as_value as dimension_as_value,
  divide_shape_sizes as divide_shape_sizes,
  do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
  dtypes as dtypes,
  ensure_compile_time_eval as ensure_compile_time_eval,
  escaped_tracer_error as escaped_tracer_error,
  eval_context as eval_context,
  eval_jaxpr as eval_jaxpr,
  extend_axis_env as extend_axis_env,
  extend_axis_env_nd as extend_axis_env_nd,
  find_top_trace as find_top_trace,
  full_lower as full_lower,
  functools as functools,
  gc as gc,
  gensym as gensym,
  get_aval as get_aval,
  get_referent as get_referent,
  greater_equal_dim as greater_equal_dim,
  greater_equal_shape as greater_equal_shape,
  has_opaque_dtype as has_opaque_dtype,
  inspect as inspect,
  is_constant_dim as is_constant_dim,
  is_constant_shape as is_constant_shape,
  is_dim as is_dim,
  is_empty_shape as is_empty_shape,
  is_opaque_dtype as is_opaque_dtype,
  is_special_dim_size as is_special_dim_size,
  it as it,
  jax_config as jax_config,
  jax_jit as jax_jit,
  jaxpr_as_fun as jaxpr_as_fun,
  jaxpr_uses_outfeed as jaxpr_uses_outfeed,
  jaxprs_in_params as jaxprs_in_params,
  join_effects as join_effects,
  join_named_shapes as join_named_shapes,
  lattice_join as lattice_join,
  leaked_tracer_error as leaked_tracer_error,
  literalable_types as literalable_types,
  lu as lu,
  map as map,
  map_bind as map_bind,
  map_bind_with_continuation as map_bind_with_continuation,
  mapped_aval as mapped_aval,
  maybe_find_leaked_tracers as maybe_find_leaked_tracers,
  namedtuple as namedtuple,
  new_base_main as new_base_main,
  new_jaxpr_eqn as new_jaxpr_eqn,
  new_main as new_main,
  new_sublevel as new_sublevel,
  no_axis_name as no_axis_name,
  no_effects as no_effects,
  np as np,
  opaque_dtypes as opaque_dtypes,
  operator as operator,
  ordered_effects as ordered_effects,
  outfeed_primitives as outfeed_primitives,
  partial as partial,
  partialmethod as partialmethod,
  pp as pp,
  pp_aval as pp_aval,
  pp_eqn as pp_eqn,
  pp_eqn_rules as pp_eqn_rules,
  pp_eqns as pp_eqns,
  pp_jaxpr as pp_jaxpr,
  pp_jaxpr_eqn_range as pp_jaxpr_eqn_range,
  pp_jaxpr_skeleton as pp_jaxpr_skeleton,
  pp_jaxprs as pp_jaxprs,
  pp_kv_pair as pp_kv_pair,
  pp_kv_pairs as pp_kv_pairs,
  pp_var as pp_var,
  pp_vars as pp_vars,
  primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
  primitive_uses_outfeed as primitive_uses_outfeed,
  process_env_traces_call as process_env_traces_call,
  process_env_traces_map as process_env_traces_map,
  prod as prod,
  pytype_aval_mappings as pytype_aval_mappings,
  raise_as_much_as_possible as raise_as_much_as_possible,
  raise_to_shaped as raise_to_shaped,
  raise_to_shaped_mappings as raise_to_shaped_mappings,
  ref as ref,
  reset_trace_state as reset_trace_state,
  safe_map as safe_map,
  safe_zip as safe_zip,
  same_referent as same_referent,
  same_shape_sizes as same_shape_sizes,
  source_info_util as source_info_util,
  stash_axis_env as stash_axis_env,
  str_eqn_compact as str_eqn_compact,
  stride_dim as stride_dim,
  stride_shape as stride_shape,
  subjaxprs as subjaxprs,
  subst_axis_names as subst_axis_names,
  subst_axis_names_eqn as subst_axis_names_eqn,
  subst_axis_names_jaxpr as subst_axis_names_jaxpr,
  subst_axis_names_var as subst_axis_names_var,
  substitute_vars_in_output_ty as substitute_vars_in_output_ty,
  sum_dim as sum_dim,
  sum_shapes as sum_shapes,
  symbolic_equal_dim as symbolic_equal_dim,
  symbolic_equal_one_of_dim as symbolic_equal_one_of_dim,
  symbolic_equal_shape as symbolic_equal_shape,
  thread_local_state as thread_local_state,
  threading as threading,
  token as token,
  total_ordering as total_ordering,
  trace_state_clean as trace_state_clean,
  traceback_util as traceback_util,
  traverse_jaxpr_params as traverse_jaxpr_params,
  tuple_delete as tuple_delete,
  tuple_insert as tuple_insert,
  typecheck as typecheck,
  typecompat as typecompat,
  typematch as typematch,
  types as types,
  typing as typing,
  unmapped_aval as unmapped_aval,
  unsafe_map as unsafe_map,
  unsafe_zip as unsafe_zip,
  used_axis_names as used_axis_names,
  used_axis_names_jaxpr as used_axis_names_jaxpr,
  valid_jaxtype as valid_jaxtype,
  warnings as warnings,
  weakref_lru_cache as weakref_lru_cache,
  zip as zip,
)

from typing import Any, Callable, Dict, Type

# TODO(mattjj,frostig): remove these stubs (pytype workaround)
#
# These are bare module-level annotations (PEP 526): each one adds an entry
# to this module's ``__annotations__`` but does NOT bind the name. At runtime,
# for example, ``jax.core.unit`` is therefore still undefined. None of these
# names appears in the ``from jax._src.core import (...)`` re-export list
# above. The declarations only make the names visible to static type checkers
# (pytype, per the TODO).
# Because of ``from __future__ import annotations`` at the top of the file,
# the annotation expressions are not evaluated at runtime.
# NOTE(review): presumably these are APIs that no longer exist in
# jax._src.core but are still referenced by type-checked callers. Confirm
# before deleting.
extract_call_jaxpr: Callable
eval_jaxpr_eqn: Callable
initial_to_final_param_rules: Dict
unit: Any
abstract_unit: AbstractValue
unitvar: Var
UnitVar: Type
# back to top