swh:1:snp:ffdd0a7d2c8ea15ad41d45b3b178f668bd942287
Raw File
Tip revision: 65dce028095b06ee5ed19673e3045548876c41d3 authored by Derek Young on 30 December 2007, 00:00:00 UTC
version 0.3.3
Tip revision: 65dce02
multmixEM.r
# Fit a finite mixture of multinomial distributions with the EM algorithm.
#
# Arguments:
#   y        Numeric matrix of counts: one observation per row, one category
#            per column (at least two columns). All row sums must be equal.
#   lambda   Optional starting mixing proportions; forwarded to multmix.init().
#   theta    Optional starting category probabilities (one row per component);
#            forwarded to multmix.init().
#   k        Number of components; forwarded to multmix.init(). The value of k
#            that multmix.init() returns is the one used from then on.
#   maxit    Maximum number of EM iterations per attempt.
#   epsilon  Iteration stops once the log-likelihood increases by no more
#            than this amount.
#   verb     If TRUE, print a progress line after every iteration.
#
# Value: a list with elements
#   y          the data as supplied,
#   lambda     final mixing proportions,
#   theta      final k x p matrix of category probabilities,
#   loglik     final log-likelihood,
#   posterior  n x k matrix of posterior component probabilities,
#   all.loglik log-likelihoods recorded at the start of each iteration
#              (the final value is in loglik, not in this vector),
#   restarts   number of restarts that were needed,
#   ft         the string "multmixEM".
#
# If the log-likelihood becomes non-finite or decreases, the fit is restarted
# from fresh starting values obtained by multmix.init(y = y, k = k); after 15
# restarts the function stops with an error.
multmixEM <- function (y, lambda = NULL, theta = NULL, k = 2, maxit = 10000, 
    epsilon = 1e-08, verb = FALSE) {
  p <- ncol(y)
  if (is.null(p) || p < 2) {
    stop("y must be a matrix with at least two columns")
  }
  r <- rowSums(y)
  if (any(r != r[1])) stop("All row sums of y must be equal")
  # Log multinomial coefficient of every observation.
  mncoeffs <- lgamma(1 + r[1]) - rowSums(lgamma(1 + y))

  # E-step: posterior probabilities and log-likelihood for the given
  # parameters. Works on the log scale and subtracts each row's maximum before
  # exponentiating (log-sum-exp), so very small densities do not underflow to
  # zero and trigger spurious restarts.
  estep <- function(theta, lambda) {
    logjoint <- y %*% t(log(pmax(theta, 1e-100))) +
      outer(mncoeffs, log(lambda), "+")
    rowmax <- apply(logjoint, 1, max)
    scaled <- exp(logjoint - rowmax)
    denom <- rowSums(scaled)
    list(postz = scaled / denom, loglik = sum(rowmax + log(denom)))
  }

  tmp <- multmix.init(y = y, lambda = lambda, theta = theta, k = k)
  lambda <- tmp$lambda
  theta <- tmp$theta
  k <- tmp$k
  restarts <- 0

  while (restarts < 15) {
    es <- estep(theta, lambda)
    postz <- es$postz
    newll <- es$loglik
    ll <- NULL
    iter <- 0
    # A non-finite starting log-likelihood is treated like any other
    # numerical problem: skip the EM loop and restart.
    mustrestart <- !is.finite(newll)
    lldiff <- Inf
    while (!mustrestart && iter < maxit && lldiff > epsilon) {
      iter <- iter + 1
      oldll <- newll
      ll <- c(ll, oldll)

      # M-step: theta rows are posterior-weighted category proportions; the
      # last column is recomputed as the complement of the others.
      zty <- crossprod(postz, y)
      theta <- zty / rowSums(zty)
      # drop = FALSE keeps the matrix shape when k == 1 or p == 2.
      theta[, p] <- 1 - rowSums(theta[, -p, drop = FALSE])
      lambda <- colMeans(postz)

      es <- estep(theta, lambda)
      postz <- es$postz
      newll <- es$loglik
      lldiff <- newll - oldll
      if (!is.finite(newll) || lldiff < 0) {
        mustrestart <- TRUE
        break
      }
      if (verb) {
        # ll[iter] is the log-likelihood from before this iteration's update.
        cat("iteration=", iter, "diff=", lldiff, "log-likelihood", 
            ll[iter], "lambda", lambda, "\n") 
      }
    }
    if (mustrestart) {
      cat("Restarting due to numerical problem.\n")
      restarts <- restarts + 1
      tmp <- multmix.init(y = y, k = k)
      lambda <- tmp$lambda
      theta <- tmp$theta
      k <- tmp$k
    } else {
      if (iter == maxit) {
        cat("Warning: Not convergent after", maxit, "iterations\n")
      }
      theta[, p] <- 1 - rowSums(theta[, -p, drop = FALSE])
      colnames(theta) <- c(paste("theta", ".", 1:p, sep = ""))
      rownames(theta) <- c(paste("comp", ".", 1:k, sep = ""))
      colnames(postz) <- c(paste("comp", ".", 1:k, sep = ""))
      cat("number of iterations=", iter, "\n")
      return(list(y = y, lambda = lambda, 
                  theta = theta, loglik = newll, posterior = postz, 
                  all.loglik = ll, restarts = restarts, ft = "multmixEM"))
    }
  }
  stop("Too many restarts!")
}
back to top