// guide.js
'use strict';

var assert = require('assert');
var _ = require('underscore');
var util = require('./util');
var Tensor = require('./tensor');
var ad = require('./ad');
var dists = require('./dists');
var domains = require('./domain');

var T = ad.tensor;

// Builds an independent guide distribution for a (target
// distribution, sample address) pair. When every choice is guided
// this way and the elbo is optimized, the result is mean-field
// variational inference.
function independent(targetDist, sampleAddress, env) {

  // Guide parameter names embed the distribution name. The
  // distribution seen at an address can change between calls (when it
  // depends on a random choice), and this keeps the parameters of
  // different distribution types from colliding.
  var prefix = util.relativizeAddress(env, sampleAddress) +
      '$mf$' + targetDist.meta.name + '$';

  var guideSpec = spec(targetDist);

  // Entries carrying a `param` description are turned into guide
  // parameters by makeParam; any other entry contributes its `const`
  // value unchanged.
  var instantiate = function(entry, paramName) {
    if (_.has(entry, 'param')) {
      return makeParam(entry.param, paramName, prefix, env);
    }
    return entry.const;
  };

  var guideParams = _.mapObject(guideSpec.params, instantiate);

  return new guideSpec.type(guideParams);

}

// Creates the value for one guide parameter. The parameter is
// registered under `baseName + paramName`, passed through a squishing
// function when `paramSpec.domain` is given, and reduced to a scalar
// when the requested dims are exactly [1].
function makeParam(paramSpec, paramName, baseName, env) {
  var outDims = paramSpec.dims; // e.g. [2, 1]
  var domain = paramSpec.domain; // e.g. new RealInterval(0, Infinity)

  // Without a domain the parameter is registered at its final shape.
  // With one, it is registered at the shape the squishing function
  // takes as input.
  var rawDims = outDims;
  var transform;
  if (domain) {
    var squisher = squishFn(domain, outDims);
    rawDims = squisher.dimsIn;
    transform = squisher.f;
  }

  var value = registerParam(env, baseName + paramName, rawDims);

  if (transform) {
    value = transform(value);
  }

  // A tensor with dims=[1] is handed back as its single element.
  var isScalar = outDims.length === 1 && outDims[0] === 1;
  return isScalar ? ad.tensor.get(value, 0) : value;
}

// Registers a single parameter called `name` through
// util.registerParams. The initializer passed along produces a
// one-element array holding a new Tensor of the given dims; the first
// element of whatever util.registerParams returns is handed back.
function registerParam(env, name, dims) {
  var makeInitialValues = function() {
    return [new Tensor(dims)];
  };
  var registered = util.registerParams(env, name, makeInitialValues);
  return registered[0];
}

// This function specifies an appropriate guide distribution for the
// given target distribution. This specification is abstract, given in
// terms of the distribution type, and a description of the parameters
// required to use this type as a guide. It's left to callers to
// generate suitable parameters and instantiate the distribution.

// For example:

// spec(Uniform({a: 0, b: 1}))
//
// =>
//
// {
//   type: LogitNormal,
//   params: {
//     a: {const: 0},
//     b: {const: 1},
//     mu: {param: {dims: [1]}},
//     sigma: {param: {dims: [1], domain: gt(0)}}
//   }
// }

// Note that all parameters described are tensors. If a distribution
// is parameterized by a scalar then the spec includes a tensor with
// dims=[1] for that parameter. It is the responsibility of callers to
// turn this back into a scalar before use.

function spec(targetDist) {
  // Distribution types with a dedicated guide spec, as [name in
  // `dists`, builder] pairs. They are tested in this order and the
  // first match wins.
  var specialised = [
    ['Dirichlet', dirichletSpec],
    ['TensorGaussian', tensorGaussianSpec],
    ['Uniform', uniformSpec],
    ['Gamma', gammaSpec],
    ['Beta', betaSpec],
    ['Discrete', discreteSpec]
  ];
  for (var i = 0; i < specialised.length; i++) {
    if (targetDist instanceof dists[specialised[i][0]]) {
      var build = specialised[i][1];
      return build(targetDist);
    }
  }

  // Types for which no guide is generated automatically.
  var unsupported = ['RandomInteger', 'Binomial', 'MultivariateGaussian'];
  for (var j = 0; j < unsupported.length; j++) {
    if (targetDist instanceof dists[unsupported[j]]) {
      throwAutoGuideError(targetDist);
      return;
    }
  }

  // Everything else is guided by a distribution of its own type.
  return defaultSpec(targetDist);
}

// Throws an Error stating that no guide can be generated
// automatically for distributions named `targetDist.meta.name`.
function throwAutoGuideError(targetDist) {
  var distName = targetDist.meta.name;
  throw new Error(
      'Cannot automatically generate a guide for a ' + distName + ' distribution.');
}

// Fallback spec: guide with a distribution of the same type as the
// target. Each parameter's dims are read off the target instance
// (tensor parameters keep their dims, plain numbers get dims=[1]) and
// its domain is taken from the distribution meta-data. A parameter
// that is neither a Tensor nor a number triggers the auto-guide error.
function defaultSpec(targetDist) {
  var params = {};

  _.each(targetDist.meta.params, function(paramMeta) {

    var paramName = paramMeta.name;
    var value = ad.value(targetDist.params[paramName]);

    var dims;
    if (value instanceof Tensor) {
      dims = value.dims;
    } else if (_.isNumber(value)) {
      dims = [1];
    } else {
      throwAutoGuideError(targetDist);
    }

    params[paramName] = {param: {dims: dims, domain: paramMeta.domain}};

  });

  return {type: targetDist.constructor, params: params};
}

// Guide spec for a Dirichlet target: a LogisticNormal whose mu and
// sigma have dims [d, 1], where d is one less than the length of the
// target's alpha. sigma is constrained by domains.gt(0).
function dirichletSpec(targetDist) {
  var alpha = ad.value(targetDist.params.alpha);
  var guideDim = alpha.length - 1;

  var muSpec = {param: {dims: [guideDim, 1]}};
  var sigmaSpec = {param: {dims: [guideDim, 1], domain: domains.gt(0)}};

  return {
    type: dists.LogisticNormal,
    params: {mu: muSpec, sigma: sigmaSpec}
  };
}

// Guide spec for a TensorGaussian target: a DiagCovGaussian whose mu
// and sigma both use the target's `dims` parameter. sigma is
// constrained by domains.gt(0).
function tensorGaussianSpec(targetDist) {
  var shape = targetDist.params.dims;

  var params = {};
  params.mu = {param: {dims: shape}};
  params.sigma = {param: {dims: shape, domain: domains.gt(0)}};

  return {type: dists.DiagCovGaussian, params: params};
}

// Guide spec for a Uniform target: a LogitNormal whose bounds a and b
// are constants copied from the target, with mu and sigma as dims=[1]
// parameters. sigma is constrained by domains.gt(0).
function uniformSpec(targetDist) {
  var bounds = targetDist.params;

  var params = {};
  params.a = {const: bounds.a};
  params.b = {const: bounds.b};
  params.mu = {param: {dims: [1]}};
  params.sigma = {param: {dims: [1], domain: domains.gt(0)}};

  return {type: dists.LogitNormal, params: params};
}

// Guide spec for a Beta target: a LogitNormal with constant bounds
// a=0 and b=1, and mu and sigma as dims=[1] parameters. sigma is
// constrained by domains.gt(0). `targetDist` itself is not inspected.
function betaSpec(targetDist) {
  var params = {};
  params.a = {const: 0};
  params.b = {const: 1};
  params.mu = {param: {dims: [1]}};
  params.sigma = {param: {dims: [1], domain: domains.gt(0)}};

  return {type: dists.LogitNormal, params: params};
}

// Guide spec for a Gamma target: an IspNormal with mu and sigma as
// dims=[1] parameters. sigma is constrained by domains.gt(0).
// `targetDist` itself is not inspected.
function gammaSpec(targetDist) {
  var params = {};
  params.mu = {param: {dims: [1]}};
  params.sigma = {param: {dims: [1], domain: domains.gt(0)}};

  return {type: dists.IspNormal, params: params};
}

// Guide spec for a Discrete target: another Discrete whose `ps`
// parameter has dims [d, 1], where d is the length of the target's
// ps, constrained to domains.simplex.
function discreteSpec(targetDist) {
  var targetPs = ad.value(targetDist.params.ps);
  var psSpec = {param: {dims: [targetPs.length, 1], domain: domains.simplex}};

  return {type: dists.Discrete, params: {ps: psSpec}};
}

// Computes log(exp(x) + 1) using the ad.tensor operations.
// NOTE(review): this is the direct formula, so exp(x) may overflow
// for large x -- confirm whether inputs can get that large.
function softplus(x) {
  var expX = T.exp(x);
  var shifted = T.add(expX, 1);
  return T.log(shifted);
}

// Describes how to obtain a tensor in `domain` with dimension
// `dimsOut` from a tensor of unbounded reals.

// Parameters:
// domain: Output domain (a RealInterval, or domains.simplex)
// dimsOut: Output dimension

// Returns an object with:
// dimsIn: Dimension of the unbounded input tensor
// f: Squishing function, mapping such an input into `domain`

// Throws if `domain` is of neither kind, or if the simplex is
// requested with a `dimsOut` that is not of the form [n, 1].

function squishFn(domain, dimsOut) {
  // Intervals keep the shape unchanged.
  if (domain instanceof domains.RealInterval) {
    return {dimsIn: dimsOut, f: squishToInterval(domain)};
  }

  if (domain !== domains.simplex) {
    throw new Error('Unknown domain type.');
  }

  var isColumnVector = dimsOut.length === 2 && dimsOut[1] === 1;
  if (!isColumnVector) {
    throw new Error('Can only map vectors to the probability simplex.');
  }

  // The simplex squishing function takes an input with one row fewer
  // than its output.
  return {dimsIn: [dimsOut[0] - 1, 1], f: dists.squishToProbSimplex};
}

// Returns a function that maps a tensor of unbounded reals into the
// real interval with end points `domain.a` and `domain.b`.
//
// The mapping depends on which end points are finite:
//   both infinite:  x               (nothing to squish)
//   only b finite:  b - softplus(x)
//   only a finite:  a + softplus(x)
//   both finite:    a + (b - a) * sigmoid(x)
function squishToInterval(domain) {
  var a = domain.a;
  var b = domain.b;
  if (a === -Infinity && b === Infinity) {
    // The whole real line. Without this case the branch below would
    // add Infinity to every element.
    return function(x) {
      return x;
    };
  } else if (a === -Infinity) {
    return function(x) {
      var y = softplus(x);
      return T.add(T.neg(y), b);
    };
  } else if (b === Infinity) {
    return function(x) {
      var y = softplus(x);
      return T.add(y, a);
    };
  } else {
    return function(x) {
      var y = T.sigmoid(x);
      return T.add(T.mul(y, b - a), a);
    };
  }
}

// Public interface of this module: only `independent` is exported;
// every other function in this file is internal.
module.exports = {
  independent: independent
};
// End of guide.js