// Raw file: semiMarkovRandomWalkConstrained.wppl
// Produce a noisy reading of a position.
//
// pos: array of coordinates (one number per dimension).
// Returns an array of the same length in which each entry is a Gaussian
// sample centred on the matching coordinate with standard deviation 5.
var observe = function(pos) {
  var sensorNoiseSd = 5;
  var noisyReading = function(coordinate) {
    return gaussian(coordinate, sensorNoiseSd);
  };
  return map(noisyReading, pos);
};

// Sample the two starting states of the walk and a noisy observation of each.
//
// dim: number of coordinates per state.
// Returns { states, observations }, each an array with two entries. Every
// starting coordinate is drawn from a Gaussian with mean 200 and standard
// deviation 1; observations come from applying `observe` to each state.
var init = function(dim) {
  var sampleStartCoordinate = function() { return gaussian(200, 1); };
  // Both states are drawn before any observation is generated.
  var earlierState = repeat(dim, sampleStartCoordinate);
  var laterState = repeat(dim, sampleStartCoordinate);
  var startStates = [earlierState, laterState];
  return {
    states: startStates,
    observations: map(observe, startStates)
  };
};

// Recursively simulate the unconstrained walk.
//
// n:   recursion depth; n == 2 is the base case that starts from init(dim).
//      Values below 2 never reach the base case.
// dim: number of coordinates per state.
// Returns { states, observations }. Each level appends one state, including
// the base level, so the result holds n + 1 entries in each array.
var semiMarkovWalk = function(n, dim) {
  var history = (n == 2) ? init(dim) : semiMarkovWalk(n - 1, dim);
  var trajectory = history.states;
  // The next state depends on the two most recent states.
  var nextState = transition(last(trajectory), secondLast(trajectory));
  return {
    states: trajectory.concat([nextState]),
    observations: history.observations.concat([observe(nextState)])
  };
};

// Sample the next position from the two most recent positions.
//
// lastPos:       most recent position (array of coordinates).
// secondLastPos: position before that (same length).
// Returns a new position. Each coordinate is Gaussian with standard
// deviation 3, centred on the last coordinate plus 0.7 times the most
// recent displacement.
var transition = function(lastPos, secondLastPos) {
  var momentumWeight = .7;
  var stepNoiseSd = 3;
  var stepCoordinate = function(lastX, secondLastX) {
    var drift = (lastX - secondLastX) * momentumWeight;
    return gaussian(lastX + drift, stepNoiseSd);
  };
  return map2(stepCoordinate, lastPos, secondLastPos);
};


// Condition a position on an already-known observation.
//
// pos:     array of latent coordinates.
// trueObs: array of observed coordinates (same length as pos).
//
// For each coordinate the log-density of the observed value under a Gaussian
// with mean x and standard deviation 5 is added to the execution's score via
// factor. Returns the observed values that were conditioned on, so that the
// `observations` fields built from this result hold the data rather than
// arrays of undefined.
var observeConstrained = function(pos, trueObs) {
  return map2(
      function(x, obs) {
        factor(Gaussian({mu: x, sigma: 5}).score(obs));
        // Hand back the datum; factor itself yields no useful value.
        return obs;
      },
      pos,
      trueObs
  );
};

// Sample the two starting states, conditioning each on its known observation.
//
// dim:     number of coordinates per state.
// trueObs: array whose entries 0 and 1 are the observations for the first
//          and second starting state.
// Returns { states, observations }; observations holds whatever
// observeConstrained returns for each state.
var initConstrained = function(dim, trueObs) {
  var sampleStartCoordinate = function() { return gaussian(200, 1); };
  // Each state is scored against its observation right after it is sampled,
  // before the next state is drawn.
  var earlierState = repeat(dim, sampleStartCoordinate);
  var earlierObs = observeConstrained(earlierState, trueObs[0]);
  var laterState = repeat(dim, sampleStartCoordinate);
  var laterObs = observeConstrained(laterState, trueObs[1]);
  return {
    states: [earlierState, laterState],
    observations: [earlierObs, laterObs]
  };
};

// Recursively simulate the walk while conditioning on known observations.
//
// n:       recursion depth; n == 2 is the base case.
// dim:     number of coordinates per state.
// trueObs: observations to condition on. The last entry belongs to the state
//          added at this level; everything before it is passed down.
// Returns { states, observations }, with one entry appended at every level
// (including the base level).
var semiMarkovWalkConstrained = function(n, dim, trueObs) {
  var finalIndex = trueObs.length - 1;
  var history = (n == 2) ?
      initConstrained(dim, trueObs.slice(0, 2)) :
      semiMarkovWalkConstrained(n - 1, dim, trueObs.slice(0, finalIndex));
  var trajectory = history.states;
  // The next state depends on the two most recent states.
  var nextState = transition(last(trajectory), secondLast(trajectory));
  var nextObservation = observeConstrained(nextState, trueObs[finalIndex]);
  return {
    states: trajectory.concat([nextState]),
    observations: history.observations.concat([nextObservation])
  };
};


// Run the model with a particle filter (SMC).

// Generate synthetic observations from the unconstrained walk.
var numSteps = 80;
var trueObservations = semiMarkovWalk(numSteps, 2).observations;

// Model conditioned on the synthetic observations.
var constrainedModel = function() {
  return semiMarkovWalkConstrained(numSteps, 2, trueObservations);
};

// Try reducing the number of particles to 1!
var smcOptions = {method: 'SMC', particles: 10};
var posteriorSampler = Infer(smcOptions, constrainedModel);

// Draw one trajectory from the inferred posterior.
var inferredStates = sample(posteriorSampler).states;

// The final expression is the value of the program.
inferredStates;
// back to top