Come tracciare un albero di esempio da randomForest :: getTree ()? [chiuso]


62

Chiunque ha ricevuto suggerimenti sulla libreria o sul codice su come tracciare effettivamente un paio di alberi di esempio da:

getTree(rfobj, k, labelVar=TRUE)

(Sì, lo so che non dovresti farlo operativamente, RF è una scatola nera, ecc. Ecc. Voglio controllare visivamente la sanità mentale di un albero per vedere se qualche variabile si comporta in modo controintuitivo, ha bisogno di modificare / combinare / discretizzazione / trasformazione, controllare quanto bene funzionano i miei fattori codificati, ecc.)


Domande precedenti senza una risposta decente:

In realtà voglio tracciare un albero campione . Quindi non discuterne già con me. Non sto chiedendo informazioni su varImpPlot(Variable Importance Plot) o partialPloto MDSPlot, o su questi altri grafici , li ho già, ma non sono un sostituto per vedere un albero di esempio. Sì, posso controllare visivamente l'output di getTree(...,labelVar=TRUE).

(Immagino che un plot.rf.tree()contributo sarebbe molto ben accolto.)


6
Non vedo la necessità di essere preventivamente polemici, specialmente se stai chiedendo a qualcuno di fare volontariato per aiutarti; non si imbatte bene. CV ha una politica di etichetta - potresti voler leggere le nostre FAQ .
gung - Ripristina Monica

9
@gung: ogni domanda precedente su questo argomento è decaduta in persone insistendo sul fatto che non era necessario, e in effetti eretico, tracciare un albero campione. Leggi le citazioni che ho dato. Sto cercando uno schizzo qui su come codificare la trama di un albero RF.
smci,

3
Vedo alcune risposte in cui gli utenti stanno cercando di essere utili e rispondono alla domanda, insieme ad alcuni commenti che mettono in discussione la premessa dell'idea (che, onestamente, credo siano intesi anche in uno spirito utile). È certamente possibile riconoscere che alcune persone non saranno d'accordo senza essere testimoni.
gung - Ripristina Monica

4
Vedo zero risposte in cui qualcuno ha mai tracciato un albero, in oltre un anno. Sto cercando una risposta specifica a quella domanda specifica.
smci,

1
È possibile tracciare un singolo albero creato con cforest(nel pacchetto party ). Altrimenti, dovrai convertire l'oggetto data.framerestituito randomForest::getTreein un treeoggetto simile.
chl

Risposte:


44

Prima (e più semplice) soluzione: se non si desidera attenersi alla RF classica, come implementato in Andy Liaw's randomForest, è possibile provare il pacchetto party che fornisce un'implementazione diversa dell'algoritmo RF originale (uso di alberi condizionali e schema di aggregazione basato sulla media del peso unitario). Quindi, come riportato in questo post R-help , è possibile tracciare un singolo membro dell'elenco di alberi. Sembra funzionare senza intoppi, per quanto ne so. Di seguito è riportato un diagramma di un albero generato da cforest(Species ~ ., data=iris, controls=cforest_control(mtry=2, mincriterion=0)).

inserisci qui la descrizione dell'immagine

In secondo luogo (quasi facile) Soluzione: La maggior parte delle tecniche di tree-based in R ( tree, rpart, TWIX, ecc) offre una treestruttura -come per la stampa / plottaggio un singolo albero. L'idea sarebbe quella di convertire l'output di randomForest::getTreeun tale oggetto R, anche se non ha senso dal punto di vista statistico. Fondamentalmente, è facile accedere alla struttura ad albero da un treeoggetto, come mostrato di seguito. Si noti che differirà leggermente a seconda del tipo di attività - regressione vs. classificazione - dove in un secondo momento aggiungerà probabilità specifiche della classe come ultima colonna di obj$frame(che è una data.frame).

> library(tree)
> tr <- tree(Species ~ ., data=iris)
> tr
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

 1) root 150 329.600 setosa ( 0.33333 0.33333 0.33333 )  
   2) Petal.Length < 2.45 50   0.000 setosa ( 1.00000 0.00000 0.00000 ) *
   3) Petal.Length > 2.45 100 138.600 versicolor ( 0.00000 0.50000 0.50000 )  
     6) Petal.Width < 1.75 54  33.320 versicolor ( 0.00000 0.90741 0.09259 )  
      12) Petal.Length < 4.95 48   9.721 versicolor ( 0.00000 0.97917 0.02083 )  
        24) Sepal.Length < 5.15 5   5.004 versicolor ( 0.00000 0.80000 0.20000 ) *
        25) Sepal.Length > 5.15 43   0.000 versicolor ( 0.00000 1.00000 0.00000 ) *
      13) Petal.Length > 4.95 6   7.638 virginica ( 0.00000 0.33333 0.66667 ) *
     7) Petal.Width > 1.75 46   9.635 virginica ( 0.00000 0.02174 0.97826 )  
      14) Petal.Length < 4.95 6   5.407 virginica ( 0.00000 0.16667 0.83333 ) *
      15) Petal.Length > 4.95 40   0.000 virginica ( 0.00000 0.00000 1.00000 ) *
> tr$frame
            var   n        dev       yval splits.cutleft splits.cutright yprob.setosa yprob.versicolor yprob.virginica
1  Petal.Length 150 329.583687     setosa          <2.45           >2.45   0.33333333       0.33333333      0.33333333
2        <leaf>  50   0.000000     setosa                                  1.00000000       0.00000000      0.00000000
3   Petal.Width 100 138.629436 versicolor          <1.75           >1.75   0.00000000       0.50000000      0.50000000
6  Petal.Length  54  33.317509 versicolor          <4.95           >4.95   0.00000000       0.90740741      0.09259259
12 Sepal.Length  48   9.721422 versicolor          <5.15           >5.15   0.00000000       0.97916667      0.02083333
24       <leaf>   5   5.004024 versicolor                                  0.00000000       0.80000000      0.20000000
25       <leaf>  43   0.000000 versicolor                                  0.00000000       1.00000000      0.00000000
13       <leaf>   6   7.638170  virginica                                  0.00000000       0.33333333      0.66666667
7  Petal.Length  46   9.635384  virginica          <4.95           >4.95   0.00000000       0.02173913      0.97826087
14       <leaf>   6   5.406735  virginica                                  0.00000000       0.16666667      0.83333333
15       <leaf>  40   0.000000  virginica                                  0.00000000       0.00000000      1.00000000

Quindi, ci sono metodi per stampare e stampare graziosamente quegli oggetti. Le funzioni chiave sono un tree:::plot.treemetodo generico (ho messo una tripla :che consente di visualizzare direttamente il codice in R) basandosi su tree:::treepl(visualizzazione grafica) e tree:::treeco(calcolare le coordinate dei nodi). Queste funzioni prevedono la obj$framerappresentazione dell'albero. Altre questioni sottili: (1) l'argomento type = c("proportional", "uniform")nel metodo di stampa predefinito tree:::plot.tree, aiuta a gestire la distanza verticale tra i nodi ( proportionalsignifica che è proporzionale alla devianza, uniformsignifica che è stato risolto); (2) devi integrare plot(tr)una chiamata per text(tr)aggiungere etichette di testo a nodi e divisioni, il che in questo caso significa che dovrai anche dare un'occhiata tree:::text.tree.

Il getTreemetodo da randomForestrestituisce una struttura diversa, che è documentata nella guida in linea. Di seguito viene mostrato un tipico output, con i nodi terminali indicati dal statuscodice (-1). (Anche in questo caso, l'output differirà in base al tipo di attività, ma solo nelle colonne statuse prediction.)

> library(randomForest)
> rf <- randomForest(Species ~ ., data=iris)
> getTree(rf, 1, labelVar=TRUE)
   left daughter right daughter    split var split point status prediction
1              2              3 Petal.Length        4.75      1       <NA>
2              4              5 Sepal.Length        5.45      1       <NA>
3              6              7  Sepal.Width        3.15      1       <NA>
4              8              9  Petal.Width        0.80      1       <NA>
5             10             11  Sepal.Width        3.60      1       <NA>
6              0              0         <NA>        0.00     -1  virginica
7             12             13  Petal.Width        1.90      1       <NA>
8              0              0         <NA>        0.00     -1     setosa
9             14             15  Petal.Width        1.55      1       <NA>
10             0              0         <NA>        0.00     -1 versicolor
11             0              0         <NA>        0.00     -1     setosa
12            16             17 Petal.Length        5.40      1       <NA>
13             0              0         <NA>        0.00     -1  virginica
14             0              0         <NA>        0.00     -1 versicolor
15             0              0         <NA>        0.00     -1  virginica
16             0              0         <NA>        0.00     -1 versicolor
17             0              0         <NA>        0.00     -1  virginica

Se si riesce a convertire la tabella di cui sopra a quello generato da tree, si sarà probabilmente in grado di personalizzare tree:::treepl, tree:::treecoe tree:::text.treein base alle proprie esigenze, anche se non ho un esempio di questo approccio. In particolare, probabilmente vorrai sbarazzarti dell'uso della devianza, delle probabilità di classe, ecc. Che non sono significative in RF. Tutto quello che vuoi è impostare le coordinate dei nodi e dividere i valori. Potresti usarlo fixInNamespace()per questo, ma, ad essere sincero, non sono sicuro che sia la strada giusta da percorrere.

Terza (e sicuramente intelligente) soluzione: Scrivi una vera as.treefunzione di aiuto che allevierà tutte le "patch" sopra. È quindi possibile utilizzare i metodi di stampa di R o, probabilmente meglio, Klimt (direttamente da R) per visualizzare i singoli alberi.


40

Sono in ritardo di quattro anni, ma se vuoi davvero attenermi al randomForestpacchetto (e ci sono alcuni buoni motivi per farlo) e vuoi effettivamente visualizzare l'albero, puoi usare il pacchetto reprtree .

Il pacchetto non è ben documentato (puoi trovare i documenti qui ), ma tutto è abbastanza semplice. Per installare il pacchetto, fare riferimento a initialize.R nel repository, quindi eseguire semplicemente quanto segue:

options(repos='http://cran.rstudio.org')
have.packages <- installed.packages()
cran.packages <- c('devtools','plotrix','randomForest','tree')
to.install <- setdiff(cran.packages, have.packages[,1])
if(length(to.install)>0) install.packages(to.install)

library(devtools)
if(!('reprtree' %in% installed.packages())){
  install_github('araastat/reprtree')
}
for(p in c(cran.packages, 'reprtree')) eval(substitute(library(pkg), list(pkg=p)))

Quindi vai avanti e crea il tuo modello e albero:

library(randomForest)
library(reprtree)

model <- randomForest(Species ~ ., data=iris, importance=TRUE, ntree=500, mtry = 2, do.trace=100)

reprtree:::plot.getTree(model)

E il gioco è fatto! Bello e semplice.

albero generato da plot.getTree (modello)

Puoi controllare il repository github per conoscere gli altri metodi nel pacchetto. In effetti, se controlli plot.getTree.R , noterai che l'autore usa la propria implementazione di as.tree()cui chl ♦ ha suggerito che potresti costruirti nella sua risposta. Questo significa che potresti farlo:

tree <- getTree(model, k=1, labelVar=TRUE)
realtree <- reprtree:::as.tree(tree, model)

E quindi potenzialmente utilizzare realtreecon altri pacchetti per la stampa di alberi come tree .


Grazie mille, sto ancora accettando felicemente le risposte, questa sembra essere un'area in cui le persone non sono distinte dalle offerte. Immagino che anche la nuova novità sarebbe di supportare xgboost.
smci,

6
nessun problema. Mi ci sono volute ore per trovare la libreria / pacchetto, quindi ho pensato che se non fosse utile per te, sarebbe stato per altre persone che cercavano di disegnare alberi mentre rimanevano ancora sul randomForestpacchetto.
jgozal,

2
Fantastica scoperta. Nota: Traccia l'albero rappresentativo, in un certo senso, l'albero dell'ensemble che sono in media il "più vicino" a tutti gli altri alberi dell'ensemble
Chris

2
@Chris La funzione plot.getTree()traccia un singolo albero. La funzione plot.reprtree()in quel pacchetto traccia un albero rappresentativo.
Chun Li

1
ho preso il modello dal cursore e voglio inserire un reptree con reprtree:::plot.getTree(mod_rf_1$finalModel), tuttavia, c'è un "Errore in data.frame (var = fr $ var, splits = as.character (gTree [," split point "]),: gli argomenti implicano differenze numero di file: 2631, 0 "
HappyCoding

15

Ho creato alcune funzioni per estrarre le regole di un albero.

#**************************
#return the rules of a tree
#**************************
getConds<-function(tree){
  #store all conditions into a list
  conds<-list()
  #start by the terminal nodes and find previous conditions
  id.leafs<-which(tree$status==-1)
	  j<-0
	  for(i in id.leafs){
		j<-j+1
		prevConds<-prevCond(tree,i)
		conds[[j]]<-prevConds$cond
		while(prevConds$id>1){
		  prevConds<-prevCond(tree,prevConds$id)
		  conds[[j]]<-paste(conds[[j]]," & ",prevConds$cond)
        }
		if(prevConds$id==1){
			conds[[j]]<-paste(conds[[j]]," => ",tree$prediction[i])
    }
    }

  }

  return(conds)
}

#**************************
#find the previous conditions in the tree
#**************************
prevCond<-function(tree,i){
  if(i %in% tree$right_daughter){
		id<-which(tree$right_daughter==i)
		cond<-paste(tree$split_var[id],">",tree$split_point[id])
	  }
	  if(i %in% tree$left_daughter){
    id<-which(tree$left_daughter==i)
		cond<-paste(tree$split_var[id],"<",tree$split_point[id])
  }

  return(list(cond=cond,id=id))
}

#remove spaces in a word
collapse<-function(x){
  x<-sub(" ","_",x)

  return(x)
}


data(iris)
require(randomForest)
mod.rf <- randomForest(Species ~ ., data=iris)
tree<-getTree(mod.rf, k=1, labelVar=TRUE)
#rename the name of the column
colnames(tree)<-sapply(colnames(tree),collapse)
rules<-getConds(tree)
print(rules)
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.