# Learner description for conditional inference trees from the 'party'
# package. The hyperparameters below mirror the arguments of
# party::ctree_control(); `nresample` is only active for Monte Carlo testing.
# NOTE: each `#' @export` tag must stay on its own line directly above its
# function -- a roxygen `#'` comment runs to end of line, so code placed after
# it on the same line is silently commented out.
#' @export
makeRLearner.classif.ctree = function() {
  makeRLearnerClassif(
    cl = "classif.ctree",
    package = "party",
    par.set = makeParamSet(
      makeDiscreteLearnerParam(id = "teststat", default = "quad", values = c("quad", "max")),
      makeDiscreteLearnerParam(id = "testtype", default = "Bonferroni",
        values = c("Bonferroni", "MonteCarlo", "Univariate", "Teststatistic")),
      makeNumericLearnerParam(id = "mincriterion", default = 0.95, lower = 0, upper = 1),
      makeIntegerLearnerParam(id = "minsplit", default = 20L, lower = 1L),
      makeIntegerLearnerParam(id = "minbucket", default = 7L, lower = 1L),
      makeLogicalLearnerParam(id = "stump", default = FALSE),
      # Number of Monte Carlo replications; meaningless for the other test types.
      makeIntegerLearnerParam(id = "nresample", default = 9999L, lower = 1L,
        requires = quote(testtype == "MonteCarlo")),
      makeIntegerLearnerParam(id = "maxsurrogate", default = 0L, lower = 0L),
      makeIntegerLearnerParam(id = "mtry", default = 0L, lower = 0L),
      makeLogicalLearnerParam(id = "savesplitstats", default = TRUE, tunable = FALSE),
      makeIntegerLearnerParam(id = "maxdepth", default = 0L, lower = 0L)
    ),
    properties = c("twoclass", "multiclass", "missings", "numerics", "factors",
      "ordered", "prob", "weights"),
    name = "Conditional Inference Trees",
    short.name = "ctree",
    note = "See `?ctree_control` for possible breakage for nominal features with missingness.",
    callees = c("ctree", "ctree_control")
  )
}

# Fit a party::ctree on the task's training subset.
# The named hyperparameters are bundled into a ctree_control() object via
# learnerArgsToControl() (presumably passing only those that were set -- see
# that helper). Any remaining `...` arguments go straight to party::ctree().
# `.weights` is forwarded as ctree's `weights`; NULL means unweighted.
# NOTE(review): party documents that only non-negative integer case weights
# are allowed -- confirm callers respect this.
#' @export
trainLearner.classif.ctree = function(.learner, .task, .subset, .weights = NULL, teststat,
  testtype, mincriterion, minsplit, minbucket, stump, nresample, maxsurrogate, mtry,
  savesplitstats, maxdepth, ...) {
  ctrl = learnerArgsToControl(party::ctree_control, teststat, testtype, mincriterion,
    minsplit, minbucket, stump, nresample, maxsurrogate, mtry, savesplitstats, maxdepth)
  f = getTaskFormula(.task)
  party::ctree(f, data = getTaskData(.task, .subset), controls = ctrl, weights = .weights, ...)
}

# Predict with a fitted ctree model.
# - predict.type "prob": returns a numeric matrix with one row per observation
#   in `.newdata` and one column per class level, built from
#   party::treeresponse(). Column names are the target's levels as stored in
#   the fitted tree's response slot.
# - otherwise: returns whatever predict() yields for the fitted tree.
#' @export
predictLearner.classif.ctree = function(.learner, .model, .newdata, ...) {
  m = .model$learner.model
  if (.learner$predict.type == "prob") {
    probs = party::treeresponse(m, newdata = .newdata, ...)
    lvls = m@responses@levels[[.model$task.desc$target]]
    # treeresponse() gives one probability vector per observation; stack them
    # row-wise. Building the matrix directly (rather than do.call(rbind, .)
    # followed by colnames<-) also yields a well-formed 0-row matrix if the
    # list is empty, where colnames<- on the NULL result would error.
    matrix(as.numeric(unlist(probs, use.names = FALSE)), ncol = length(lvls),
      byrow = TRUE, dimnames = list(NULL, lvls))
  } else {
    predict(m, newdata = .newdata, ...)
  }
}