swh:1:snp:af87cd67498ef4fe47c76ed3e7caffe5b61facaf
Raw File
Tip revision: 34a07b26a7abbc1a9628603ab01596e6a511ffba authored by Fons Rademakers on 07 October 2009, 22:25:29 UTC
tag patch release v5-24-00a.
Tip revision: 34a07b2
CostComplexityPruneTool.h
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA::DecisionTree                                                    *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation of a Decision Tree                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://mva.sourceforge.net/license.txt)                                       *
 *                                                                                *
 **********************************************************************************/

#ifndef ROOT_TMVA_CostComplexityPruneTool
#define ROOT_TMVA_CostComplexityPruneTool

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// CostComplexityPruneTool - a class to prune a decision tree using the Cost Complexity method            //
// (see "Classification and Regression Trees" by Leo Breiman et al)                                       //
//                                                                                                        //
// Some definitions:                                                                                      //
//                                                                                                        //
// T_max - the initial, usually highly overtrained tree, that is to be pruned back                        //
// R(T) - quality index (Gini, misclassification rate, or other) of a tree T                              //
// ~T - set of terminal nodes in T                                                                        //
// T' - the pruned subtree of T_max that has the best quality index R(T')                                 //
// alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha*|~T|)       //
//                                                                                                        //
// There are two running modes in CostComplexityPruneTool: (i) one may select a prune strength and prune  //
// the tree T_max until the criterion                                                                     //
//             R(T) - R(t)                                                                                //
//  alpha <    ----------                                                                                 //
//             |~T_t| - 1                                                                                 //
//                                                                                                        //
// is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points              //
// alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned    //
// subtree, defined to be the subtree with the best quality index for the validation sample.              //
////////////////////////////////////////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_GiniIndex
#include "TMVA/GiniIndex.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
#ifndef ROOT_TMVA_IPruneTool
#include "TMVA/IPruneTool.h"
#endif

namespace TMVA {

   class CostComplexityPruneTool : public IPruneTool {
   public:
      CostComplexityPruneTool( SeparationBase* qualityIndex = NULL );
      virtual ~CostComplexityPruneTool( );

      // calculate the prune sequence for a given tree
      virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = NULL, Bool_t isAutomatic = kFALSE );

   private:
      SeparationBase* fQualityIndexTool; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }

      std::vector<DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index
      std::vector<Double_t> fPruneStrengthList; //! map of alpha -> pruning index
      std::vector<Double_t> fQualityIndexList; //! map of R(T) -> pruning index

      Int_t fOptimalK; //! the optimal index of the prune sequence

   private:
      // set the meta data used for cost complexity pruning
      void InitTreePruningMetaData( DecisionTreeNode* n );

      // optimize the pruning sequence
      void Optimize( DecisionTree* dt, Double_t weights );

      mutable MsgLogger* fLogger; //! output stream to save logging information
      MsgLogger& Log() const { return *fLogger; }

   };
}


#endif
back to top