// ldaCollapsed.wppl
// Dirichlet-discrete with sample, observe & getCounts (could go in webppl header) // NEW

var makeDirichletDiscrete = function(pseudocounts) {
  // Return a copy of the count vector a with the count at position i
  // incremented by 1 (j tracks the current index during the recursion).
  var addCount = function(a, i, j) {
    var j = j === undefined ? 0 : j;
    if (a.length === 0) {
      return [];
    } else {
      return [a[0] + (i === j)].concat(addCount(a.slice(1), i, j + 1));
    }
  };
  // Allocate a fresh slot in the global store to hold this distribution's
  // sufficient statistics (its pseudocount vector).
  globalStore.DDindex = 1 + (globalStore.DDindex == undefined ? 0 : globalStore.DDindex);
  var ddname = 'DD' + globalStore.DDindex;
  globalStore[ddname] = pseudocounts;
  var ddSample = function() {
    var pc = globalStore[ddname];  // get current sufficient stats
    var val = sample(discreteERP, [pc]);  // sample from predictive. (doesn't need to be normalized.)
    globalStore[ddname] = addCount(pc, val); // update sufficient stats
    return val;
  };
  var ddObserve = function(val) {
    var pc = globalStore[ddname];  // get current sufficient stats
    factor(discreteERP.score([util.normalizeArray(pc)], val));
    // score based on predictive distribution (normalize counts)
    globalStore[ddname] = addCount(pc, val); // update sufficient stats
  };
  var ddCounts = function() {
    return globalStore[ddname];
  };
  return {
    'sample': ddSample,
    'observe': ddObserve,
    'getCounts': ddCounts
  };
};
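
// Usage sketch (illustrative only; the model below follows the same pattern):
//
//   var dd = makeDirichletDiscrete([1, 1, 1]);
//   var ddSample = dd.sample;
//   var ddObserve = dd.observe;
//   var ddCounts = dd.getCounts;
//   var x = ddSample();  // draw from the posterior predictive; adds 1 to the sampled category
//   ddObserve(0);        // factor on category 0 under the predictive; adds 1 to its count
//   ddCounts();          // updated pseudocounts, e.g. [3, 1, 1] if x happened to be 0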

// Observe value v under the Dirichlet-discrete dd by scoring its index in the
// list of possible values vs.
var dirichletDiscreteFactor = function(vs, dd, v) { // NEW
  var i = vs.indexOf(v);
  var observe = dd.observe;
  observe(i);
};
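
// For instance (illustrative): dirichletDiscreteFactor(['a', 'b', 'c'], dd, 'b')
// observes index 1 of dd's predictive distribution.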


// Parameters

var vocabulary = ['bear', 'wolf', 'python', 'prolog'];

// Topic labels; only the keys are used, the values are placeholders.
var topics = {
  'topic1': null,
  'topic2': null
};

var docs = {
  'doc1': 'bear wolf bear wolf bear wolf python wolf bear wolf'.split(' '),
  'doc2': 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  'doc3': 'bear wolf bear wolf bear wolf bear wolf bear wolf'.split(' '),
  'doc4': 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  'doc5': 'bear wolf bear python bear wolf bear wolf bear wolf'.split(' ')
};


// Constants and helper functions

// A vector of n ones, used as a uniform Dirichlet pseudocount prior.
var ones = function(n) {
  return repeat(n, function() { return 1.0; });
};

// Map fn over the (key, value) pairs of obj, returning a new object with the
// same keys and transformed values.
var mapObject = function(fn, obj) {
  return _.object(
      map(
          function(kv) {
            return [kv[0], fn(kv[0], kv[1])];
          },
          _.pairs(obj)));
};
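
// For example (illustrative):
//   mapObject(function(k, v) { return v + 1; }, {a: 1, b: 2})  // => {a: 2, b: 3}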


// Model

var makeWordDist = function() {
  return makeDirichletDiscrete(ones(vocabulary.length)); // NEW
};

var makeTopicDist = function() {
  return makeDirichletDiscrete(ones(_.size(topics))); // NEW
};
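
// Both priors are symmetric Dirichlet(1, ..., 1): one over the vocabulary for
// each topic's word distribution, and one over the topics for each document.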

var model = function() {

  var wordDistForTopic = mapObject(makeWordDist, topics);
  var topicDistForDoc = mapObject(makeTopicDist, docs);
  // Sample a topic index for one word position from the document's topic
  // distribution and return the corresponding topic label.
  var makeTopicForWord = function(docName, word) {
    var sampleDD = topicDistForDoc[docName].sample; // NEW
    var i = sampleDD(); // NEW
    return _.keys(topics)[i];
  };
  var makeWordTopics = function(docName, words) {
    return map(function(word) {return makeTopicForWord(docName, word);},
               words);
  };
  var topicsForDoc = mapObject(makeWordTopics, docs);

  // Condition each topic's word distribution on the observed words, using the
  // per-word topic assignments sampled above.
  mapObject(
      function(docName, words) {
        map2(
            function(topic, word) {
              dirichletDiscreteFactor(vocabulary, wordDistForTopic[topic], word); // NEW
            },
            topicsForDoc[docName],
            words);
      },
      docs);

  // Return the pseudocounts of the (Dirichlet-discrete) word distributions for the two topics // NEW
  var getCounts1 = wordDistForTopic['topic1'].getCounts;
  var getCounts2 = wordDistForTopic['topic2'].getCounts;
  var counts = [getCounts1(), getCounts2()];
  // console.log(counts);
  return counts;
};

MH(model, 5000)
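
// The expression above returns an ERP over the model's return value (the two
// topics' pseudocount vectors). Qualitatively (an expectation, not a logged
// result): one topic's counts should concentrate on 'bear'/'wolf' and the
// other's on 'python'/'prolog', since the documents are dominated by those
// two word pairs.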