// ldaCollapsed.wppl
// Dirichlet-discrete with sample, observe & getCounts (could go in webppl header) // NEW
// Creates a Dirichlet-discrete object whose sufficient statistics (one count
// per outcome index) are kept in globalStore under a key that is unique to
// this instance ('DD1', 'DD2', ...).
//
// pseudocounts: array of initial counts, one entry per outcome index.
// Returns an object with three functions:
//   sample()    - draws an index from the current counts, then adds one to
//                 the count of the drawn index.
//   observe(i)  - scores index i under the normalized counts via factor(),
//                 then adds one to the count of i.
//   getCounts() - returns the counts currently stored for this instance.
var makeDirichletDiscrete = function(pseudocounts) {
  // Rebuilds `counts` with the entry at index `target` incremented.
  // `offset` is the position that counts[0] had in the full array; a `target`
  // that matches no position yields an unchanged copy.
  var bumpFrom = function(counts, target, offset) {
    return counts.length == 0 ?
        [] :
        [counts[0] + (target == offset)].concat(
            bumpFrom(counts.slice(1), target, offset + 1));
  };
  var addCount = function(counts, target) {
    return bumpFrom(counts, target, 0);
  };

  // Reserve the next instance number and store the initial counts under it.
  var previousIndex = globalStore.DDindex == undefined ? 0 : globalStore.DDindex;
  globalStore.DDindex = 1 + previousIndex;
  var storeKey = 'DD' + globalStore.DDindex;
  globalStore[storeKey] = pseudocounts;

  var ddSample = function() {
    var counts = globalStore[storeKey];
    // The raw counts are passed as weights without normalizing them first.
    var drawn = sample(discreteERP, [counts]);
    globalStore[storeKey] = addCount(counts, drawn);
    return drawn;
  };

  var ddObserve = function(val) {
    var counts = globalStore[storeKey];
    // Score against the predictive distribution (counts normalized).
    var predictive = util.normalizeArray(counts);
    factor(discreteERP.score([predictive], val));
    globalStore[storeKey] = addCount(counts, val);
  };

  var ddCounts = function() {
    return globalStore[storeKey];
  };

  return {
    sample: ddSample,
    observe: ddObserve,
    getCounts: ddCounts
  };
};
// Observes value `v` under the Dirichlet-discrete `dd`, where `vs` is the list
// of possible values: `v` is translated to its index in `vs` and that index is
// handed to dd's observe function.
var dirichletDiscreteFactor = function(vs, dd, v) { // NEW
  // The function is pulled out of the object before it is called, as is done
  // throughout this file (presumably because WebPPL treats `obj.method()` as a
  // call to a plain JavaScript method - confirm against the WebPPL docs).
  var observeIndex = dd.observe;
  observeIndex(indexOf(v, vs));
};
// Parameters
// Every word that may occur in a document.
var vocabulary = ['bear', 'wolf', 'python', 'prolog'];

// Topic names. Only the keys are used by the model; the null values are
// placeholders.
var topics = {
  topic1: null,
  topic2: null
};

// Documents, keyed by name; each one is the list of its words in order.
var docs = {
  doc1: 'bear wolf bear wolf bear wolf python wolf bear wolf'.split(' '),
  doc2: 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  doc3: 'bear wolf bear wolf bear wolf bear wolf bear wolf'.split(' '),
  doc4: 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  doc5: 'bear wolf bear python bear wolf bear wolf bear wolf'.split(' ')
};
// Constants and helper functions
// Returns an array made of n copies of 1.0 (used as uniform pseudocounts).
var ones = function(n) {
  var one = function() {
    return 1.0;
  };
  return repeat(n, one);
};
// Builds a new object with the same keys as `obj`, where each value is the
// result of fn(key, value) for the corresponding entry of `obj`.
var mapObject = function(fn, obj) {
  var transformPair = function(kv) {
    var key = kv[0];
    return [key, fn(key, kv[1])];
  };
  var mappedPairs = map(transformPair, _.pairs(obj));
  return _.object(mappedPairs);
};
// Model
// Creates a Dirichlet-discrete over vocabulary indices, starting from one
// pseudocount per vocabulary word.
var makeWordDist = function() {
  var uniformCounts = ones(vocabulary.length);
  return makeDirichletDiscrete(uniformCounts); // NEW
};
// Creates a Dirichlet-discrete over topic indices, starting from one
// pseudocount per entry of `topics`.
var makeTopicDist = function() {
  var topicCount = _.size(topics);
  return makeDirichletDiscrete(ones(topicCount)); // NEW
};
// Collapsed LDA model.
//
// Creates one Dirichlet-discrete word distribution per topic and one
// Dirichlet-discrete topic distribution per document, samples a topic for
// every word of every document, and then observes each word under the word
// distribution of its sampled topic.
//
// Returns an array holding the current counts of the word distribution of
// every topic, in the order of _.keys(topics).
var model = function() {
  // Topic names in key order; a sampled topic index refers to this list.
  var topicNames = _.keys(topics);
  var wordDistForTopic = mapObject(makeWordDist, topics);
  var topicDistForDoc = mapObject(makeTopicDist, docs);

  // Samples a topic name for one word of `docName`. The word itself is not
  // consulted: the topic comes from the document's topic distribution only.
  var makeTopicForWord = function(docName, word) {
    var sampleDD = topicDistForDoc[docName].sample; // NEW
    var i = sampleDD(); // NEW
    return topicNames[i];
  };
  var makeWordTopics = function(docName, words) {
    return map(function(word) {return makeTopicForWord(docName, word);},
               words);
  };
  var topicsForDoc = mapObject(makeWordTopics, docs);

  // Observe every word under the word distribution of its sampled topic.
  // mapObject/map2 are used here for their side effects only.
  mapObject(
      function(docName, words) {
        map2(
            function(topic, word) {
              dirichletDiscreteFactor(vocabulary, wordDistForTopic[topic], word); // NEW
            },
            topicsForDoc[docName],
            words);
      },
      docs);

  // Collect the pseudocounts of the (Dirichlet-discrete) word distributions
  // for all topics, rather than for two hard-coded topic names. // NEW
  return map(
      function(topic) {
        var getCounts = wordDistForTopic[topic].getCounts;
        return getCounts();
      },
      topicNames);
};
// Run MH inference on the model. This is the last expression of the program.
// NOTE(review): 5000 is presumably the number of MH iterations - confirm
// against the WebPPL version this file targets.
MH(model, 5000)