// headerUtils.js
'use strict';

var _ = require('lodash');
var serialize = require('./util').serialize
var Tensor = require('./tensor');
var LRU = require('lru-cache');
var ad = require('./ad');
var assert = require('assert');
var runThunk = require('./guide').runThunk;

module.exports = function(env) {

  // Prints the ad-unwrapped form of x to the console, then resumes the
  // continuation k with the store s and whatever console.log returned.
  function display(s, k, a, x) {
    var unwrapped = ad.valueRec(x);
    var logResult = console.log(unwrapped);
    return k(s, logResult);
  }

  var dp = {};

  // Memoization for a wppl function f.
  //
  // Results are stored in an LRU cache keyed by the serialized
  // arguments of the call (everything after the store, continuation
  // and address). maxSize is handed straight to the LRU constructor.
  //
  // Caution: the cache is shared by every use of f, including uses on
  // different execution paths, so a non-deterministic f can produce
  // surprising results.
  dp.cache = function(f, maxSize) {
    var store = LRU(maxSize);

    var memoized = function(s, k, a) {
      // Strip the CPS bookkeeping arguments (s, k, a).
      var fnArgs = Array.prototype.slice.call(arguments, 3);

      var hasLiftedArg = _.some(fnArgs, ad.isLifted);
      if (hasLiftedArg) {
        throw new Error('Cannot cache functions of scalar/tensor arguments ' +
                        'when performing automatic differentiation.');
      }

      var key = serialize(fnArgs);

      // Hit: resume immediately with the stored value.
      if (store.has(key)) {
        return k(s, store.get(key));
      }

      // Miss: run f with a continuation that records the result before
      // passing it on to the caller's continuation.
      var recordAndContinue = function(s, result) {
        if (store.has(key)) {
          // A recursive f can fill this entry while the outer call is
          // still in flight.
          console.log('Already in cache:', key);
          var previous = serialize(store.get(key));
          if (previous !== serialize(result)) {
            console.log('OLD AND NEW CACHE VALUE DIFFER!');
            console.log('Old value:', store.get(key));
            console.log('New value:', result);
          }
        }
        store.set(key, result);
        // Only warn about growth when the caller gave no size limit.
        if (!maxSize && store.length === 1e4) {
          console.log(store.length + ' function calls have been cached.');
          console.log('The size of the cache can be limited by calling cache(f, maxSize).');
        }
        return k(s, result);
      };

      return f.apply(this, [s, recordAndContinue, a].concat(fnArgs));
    };

    // Expose the underlying cache so callers can inspect it, e.g. to
    // check the complexity of algorithms.
    memoized.cache = store;
    return memoized;
  };

  // wppl-callable wrapper around dp.cache: builds the caching version
  // of f and passes it to the continuation k.
  function cache(s, k, a, f, maxSize) {
    var cachedFn = dp.cache(f, maxSize);
    return k(s, cachedFn);
  }

  // Calls wpplFn with the store, continuation and address followed by
  // the entries of args, using the global object as `this`.
  function apply(s, k, a, wpplFn, args) {
    var cpsPrefix = [s, k, a];
    var callArgs = cpsPrefix.concat(args);
    return wpplFn.apply(global, callArgs);
  }

  // Builds a function that always throws, reporting that the operation
  // `fn` may not be used inside `name`.
  function notAllowed(fn, name) {
    var refuse = function() {
      var message = fn + ' is not allowed in ' + name + '.';
      throw new Error(message);
    };
    return refuse;
  }

  // Builds a coroutine object whose sample and factor handlers throw.
  // It reuses the default coroutine's incrementalize and remembers the
  // coroutine that is current at construction time as oldCoroutine.
  function makeDeterministicCoroutine(name) {
    var coroutine = {};
    coroutine.sample = notAllowed('sample', name);
    coroutine.factor = notAllowed('factor', name);
    coroutine.incrementalize = env.defaultCoroutine.incrementalize;
    coroutine.oldCoroutine = env.coroutine;
    return coroutine;
  }

  // Applies wpplFn as a deterministic function: while it runs, the
  // active coroutine is one whose sample and factor throw. The
  // previous coroutine is reinstated when wpplFn calls its
  // continuation, just before k is resumed.
  function applyd(s, k, a, wpplFn, args, name) {
    var savedCoroutine = env.coroutine;

    var restoreAndContinue = function(s, val) {
      env.coroutine = savedCoroutine;
      return k(s, val);
    };

    env.coroutine = makeDeterministicCoroutine(name);
    return apply(s, restoreAndContinue, a, wpplFn, args);
  }

  // Function tagging: records on a function object its lexical id, a
  // fresh sequential unique id, and the values of its free variables.
  var __uniqueid = 0;
  var _Fn = {};
  _Fn.tag = function(fn, lexid, freevarvals) {
    var id = __uniqueid;
    __uniqueid += 1;
    fn.__lexid = lexid;
    fn.__uniqueid = id;
    fn.__freeVarVals = freevarvals;
    return fn;
  };

  // Used by compiled code to stash the current address: the address is
  // written to the `value` property of the container `obj`.
  var _addr = {};
  _addr.save = function(obj, address) {
    obj.value = address;
  };

  // Passes a freshly constructed Tensor with dimensions dims to k.
  var zeros = function(s, k, a, dims) {
    var tensor = new Tensor(dims);
    return k(s, tensor);
  };

  // Constructs a Tensor with dimensions dims, fills it with 1, and
  // passes the result of that fill call to k.
  var ones = function(s, k, a, dims) {
    var tensor = new Tensor(dims);
    var filled = tensor.fill(1);
    return k(s, filled);
  };

  // It is the responsibility of individual coroutines to implement
  // data sub-sampling and to make use of the conditional independence
  // information mapData provides. To do so, coroutines can implement
  // one or more of the following methods:

  // mapDataFetch: Called when mapData is entered, providing an
  // opportunity to perform book-keeping etc. The method should return
  // an object with data, ix and (optional) address properties.

  //   data: The array that will be mapped over.

  //   ix: An array of integers of the same length as data, where each
  //   entry indicates the position at which the corresponding entry
  //   in data can be found in the original data array. This is used
  //   to ensure that corresponding data items and stack addresses are
  //   used when applying the observation function. For convenience,
  //   null can be returned as shorthand for _.range(data.length).

  //   address: When present, mapData behaves as though it was called
  //   from this address.

  // mapDataEnter/mapDataLeave: Called before/after every application
  // of the observation function.

  // mapDataFinal: Called once all data have been mapped over.

  // When the current coroutine doesn't provide specific handling the
  // behavior is equivalent to regular `map`.

  // This is still somewhat experimental. The interface may change in
  // the future.

  // Maps obsFn over opts.data, giving the current coroutine the chance
  // to substitute the data, indices and address via mapDataFetch and
  // to be notified via mapDataFinal once the mapping is complete.
  //
  // Throws if opts.data is not an array. The continuation k receives
  // the array of results only when the coroutine left the data array
  // untouched; otherwise it receives undefined.
  function mapData(s, k, a, opts, obsFn) {
    opts = opts || {};

    var data = opts.data;
    if (!_.isArray(data)) {
      throw new Error('mapData: No data given.');
    }

    // Without a mapDataFetch hook, map over the data exactly as given.
    var fetched;
    if (env.coroutine.mapDataFetch) {
      fetched = env.coroutine.mapDataFetch(data, opts, a);
    } else {
      fetched = {data: data, ix: null};
    }

    var indices = fetched.ix;
    var items = fetched.data;
    var address = fetched.address || a;

    var indicesAreValid = indices === null ||
        (_.isArray(indices) && (indices.length === items.length));
    assert.ok(indicesAreValid, 'Unexpected value returned by mapDataFetch.');

    // Results are only handed back when the full, original data array
    // was mapped over (i.e. no sub-sampling etc. took place).
    var returnResults = items === data;

    var finish = function(s, v) {
      if (env.coroutine.mapDataFinal) {
        env.coroutine.mapDataFinal(a);
      }
      return k(s, returnResults ? v : undefined);
    };

    return cpsMapData(s, finish, address, items, indices, obsFn);
  }

  // CPS map used by mapData. Applies f to each entry of data in order,
  // extending the address a with '_$$' plus the entry's index, and
  // finally passes the array of results to k.
  //
  // indices is either null (entry i has index i) or an array giving the
  // index to use for each entry. acc and i carry the results so far
  // and the current position; callers normally omit them.
  //
  // The current coroutine's mapDataEnter/mapDataLeave hooks, when
  // present, are invoked around each application of f.
  function cpsMapData(s, k, a, data, indices, f, acc, i) {
    if (i === undefined) {
      i = 0;
    }
    if (acc === undefined) {
      acc = [];
    }

    var hasIndices = indices !== null;
    var total = hasIndices ? indices.length : data.length;

    if (i === total) {
      return k(s, acc);
    }

    var ix = hasIndices ? indices[i] : i;

    if (env.coroutine.mapDataEnter) {
      env.coroutine.mapDataEnter();
    }

    var afterItem = function(s, v) {
      if (env.coroutine.mapDataLeave) {
        env.coroutine.mapDataLeave();
      }

      // Return a thunk rather than recursing directly; a fresh array is
      // built so that acc itself is never modified.
      return function() {
        return cpsMapData(s, k, a, data, indices, f, acc.concat([v]), i + 1);
      };
    };

    return f(s, afterItem, a.concat('_$$' + ix), data[i], ix);
  }

  // Runs the guide thunk only when the current coroutine sets
  // guideRequired. The thunk's value is discarded: k is resumed with
  // the resulting store alone. Otherwise k is resumed straight away
  // with the incoming store.
  function guide(s, k, a, thunk) {
    if (!env.coroutine.guideRequired) {
      return k(s);
    }

    var dropValue = function(s2, val) {
      return k(s2);
    };
    return runThunk(thunk, env, s, a, dropValue);
  }

  // The object returned for this env: it exposes the functions and
  // helper objects defined above under their own names. Helpers not
  // listed here (e.g. notAllowed, makeDeterministicCoroutine,
  // cpsMapData) stay private to this closure.
  return {
    display: display,
    cache: cache,
    dp: dp,
    apply: apply,
    applyd: applyd,
    _Fn: _Fn,
    _addr: _addr,
    zeros: zeros,
    ones: ones,
    mapData: mapData,
    guide: guide
  };

};
// End of headerUtils.js