``jax.nn.initializers`` module
==============================

.. currentmodule:: jax.nn.initializers

.. automodule:: jax.nn.initializers


Initializers
------------

This module provides common neural network layer initializers,
consistent with definitions used in Keras and Sonnet.
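
To illustrate the Keras-style definitions, the sketch below assumes that ``glorot_uniform`` is the ``variance_scaling`` initializer with ``scale=1.0``, ``mode="fan_avg"`` and ``distribution="uniform"`` (the Glorot/Xavier uniform scheme). Under that assumption, both initializers draw from the same distribution. Given the same key, they should return identical arrays whose entries lie within :math:`\pm\sqrt{6 / (n_\text{in} + n_\text{out})}`, where the fans are read from the last two axes by default.

.. code-block:: python

    import math

    import jax
    import jax.numpy as jnp
    from jax.nn import initializers

    key = jax.random.PRNGKey(0)
    shape = (3, 4)  # by default, fan_in = shape[-2] = 3, fan_out = shape[-1] = 4

    # Assumption: glorot_uniform() is variance_scaling(1.0, "fan_avg", "uniform").
    w_glorot = initializers.glorot_uniform()(key, shape, jnp.float32)
    w_scaled = initializers.variance_scaling(1.0, "fan_avg", "uniform")(
        key, shape, jnp.float32)

    # Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
    limit = math.sqrt(6.0 / (3 + 4))
    print(bool(jnp.array_equal(w_glorot, w_scaled)))         # expected: True
    print(bool(jnp.all(jnp.abs(w_glorot) <= limit + 1e-6)))  # expected: True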

An initializer is a function that takes three arguments,
``(key, shape, dtype)``, and returns an array of shape ``shape`` and
data type ``dtype``. The ``key`` argument is a PRNG key, as created by
:func:`jax.random.PRNGKey`, used to generate the random numbers that
initialize the array.
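
The calling contract above can be sketched as follows. Most names in this module, such as ``normal``, are factories that return an initializer, while ``zeros`` and ``ones`` already have the ``(key, shape, dtype)`` signature and are called directly. The ``scaled_ones`` helper is hypothetical. It is shown only to illustrate that any function honoring the same signature can be used wherever an initializer is expected.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.nn import initializers

    key = jax.random.PRNGKey(42)

    # normal(stddev=...) returns an initializer; calling it fills the array.
    normal_init = initializers.normal(stddev=0.02)
    w = normal_init(key, (2, 3), jnp.float32)

    # zeros is itself an initializer, so it is called directly.
    b = initializers.zeros(key, (3,), jnp.float32)

    # Hypothetical custom initializer following the same (key, shape, dtype) contract.
    def scaled_ones(scale):
        def init(key, shape, dtype=jnp.float32):
            del key  # deterministic, so the random key is unused
            return jnp.full(shape, scale, dtype)
        return init

    c = scaled_ones(0.5)(key, (2, 2), jnp.float32)
    print(w.shape, w.dtype)  # (2, 3) float32
    print(b)                 # [0. 0. 0.]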

.. autosummary::
    :toctree: _autosummary

    constant
    delta_orthogonal
    glorot_normal
    glorot_uniform
    he_normal
    he_uniform
    lecun_normal
    lecun_uniform
    normal
    ones
    orthogonal
    uniform
    variance_scaling
    zeros
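
The sketch below shows one way the listed initializers might be combined to set up the parameters of a single dense layer. It assumes the default ``in_axis=-2`` / ``out_axis=-1`` convention, so weight matrices have shape ``(fan_in, fan_out)``. The ``init_dense`` function and its dictionary layout are hypothetical and not part of JAX. ``jax.random.split`` gives each parameter its own key, and a final check confirms that ``orthogonal`` returns a matrix with orthonormal columns when there are at least as many rows as columns.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.nn import initializers

    def init_dense(key, n_in, n_out, dtype=jnp.float32):
        """Hypothetical helper: builds parameters for y = x @ w + b."""
        w_key, b_key = jax.random.split(key)
        # he_normal targets ReLU layers; lecun_normal or glorot_* are common alternatives.
        w = initializers.he_normal()(w_key, (n_in, n_out), dtype)
        b = initializers.constant(0.1)(b_key, (n_out,), dtype)
        return {"w": w, "b": b}

    params = init_dense(jax.random.PRNGKey(0), 5, 3)
    print(params["w"].shape, params["b"])

    # orthogonal() yields orthonormal columns when rows >= columns, so q.T @ q ~ I.
    q = initializers.orthogonal()(jax.random.PRNGKey(1), (5, 3), jnp.float32)
    print(bool(jnp.allclose(q.T @ q, jnp.eye(3), atol=1e-4)))  # expected: True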