Come estrarre le regole decisionali dall'albero decisionale di scikit-learn?


157

Posso estrarre le regole decisionali sottostanti (o "percorsi decisionali") da un albero addestrato in un albero decisionale come un elenco testuale?

Qualcosa di simile a:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Grazie per l'aiuto.



Hai mai trovato una risposta a questo problema? Devo esportare le regole dell'albero decisionale in un formato di passaggio dati SAS che è quasi esattamente come lo hai elencato.
Zelazny7,

1
Puoi usare il pacchetto sklearn-porter per esportare e traspilare alberi decisionali (anche foreste casuali e alberi potenziati) in C, Java, JavaScript e altri.
Dario

Puoi controllare questo link- kdnuggets.com/2017/05/…
yogesh agrawal

Risposte:


139

Credo che questa risposta sia più corretta delle altre risposte qui:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Questo stampa una funzione Python valida. Ecco un esempio di output per un albero che sta cercando di restituire il suo input, un numero compreso tra 0 e 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Ecco alcuni ostacoli che vedo in altre risposte:

  1. Usare tree_.threshold == -2per decidere se un nodo è una foglia non è una buona idea. E se fosse un vero nodo decisionale con una soglia di -2? Invece, dovresti guardare tree.featureo tree.children_*.
  2. La linea si features = [feature_names[i] for i in tree_.feature]arresta in modo anomalo con la mia versione di sklearn, perché alcuni valori di tree.tree_.featuresono -2 (in particolare per i nodi foglia).
  3. Non è necessario avere più istruzioni if ​​nella funzione ricorsiva, solo una va bene.

1
Questo codice funziona alla grande per me. Tuttavia, ho 500+ feature_names, quindi il codice di output è quasi impossibile da comprendere per un essere umano. C'è un modo per consentirmi di inserire nella funzione solo i feature_names di cui sono curioso?
user3768495

1
Sono d'accordo con il commento precedente. IIUC, print "{}return {}".format(indent, tree_.value[node])deve essere modificato in print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))affinché la funzione restituisca l'indice di classe.
soupault,

1
@paulkernfeld Ah sì, vedo che puoi fare un giro RandomForestClassifier.estimators_, ma non sono riuscito a capire come combinare i risultati degli stimatori.
Nathan Lloyd,

6
Non riuscivo a farlo funzionare in Python 3, i bit _tree non sembrano mai funzionare e TREE_UNDEFINED non è stato definito. Questo link mi ha aiutato. Sebbene il codice esportato non sia eseguibile direttamente in Python, è simile a C e abbastanza facile da tradurre in altre lingue: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah

1
@Josiah, aggiungi () alle istruzioni di stampa per farlo funzionare in python3. ad esempio print "bla"=>print("bla")
Nir

48

Ho creato la mia funzione per estrarre le regole dagli alberi decisionali creati da sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Questa funzione inizia dapprima con i nodi (identificati da -1 negli array secondari) e quindi trova ricorsivamente i genitori. Lo chiamo un "lignaggio" di un nodo. Lungo la strada, afferro i valori di cui ho bisogno per creare la logica SAS if / then / else:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Gli insiemi di tuple sottostanti contengono tutto ciò di cui ho bisogno per creare istruzioni if ​​/ then / else SAS. Non mi piace usare i doblocchi in SAS, motivo per cui creo la logica che descrive l'intero percorso di un nodo. Il singolo numero intero dopo le tuple è l'ID del nodo terminale in un percorso. Tutte le tuple precedenti si combinano per creare quel nodo.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Output GraphViz dell'albero di esempio


questo tipo di albero è corretto perché col1 sta tornando di nuovo uno è col1 <= 0,50000 e uno col1 <= 2,5000 se sì, è questo qualsiasi tipo di ricorsione che viene usata nella libreria
jayant singh

il ramo giusto avrebbe record tra (0.5, 2.5]. Gli alberi sono realizzati con partizioni ricorsive. Non c'è nulla che impedisca a una variabile di essere selezionata più volte.
Zelazny7

okay puoi spiegare la parte ricorsiva cosa succede xactly perché l'ho usato nel mio codice e si vede un risultato simile
jayant singh

38

Ho modificato il codice inviato da Zelazny7 per stampare alcuni pseudocodici:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

se chiami get_code(dt, df.columns)sullo stesso esempio otterrai:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
Puoi dire cosa significa esattamente [[1. 0.]] nell'istruzione return nell'output sopra. Non sono un ragazzo di Python, ma sto lavorando allo stesso genere di cose. Quindi sarà bene per me se per favore prova alcuni dettagli in modo che sia più facile per me.
Subhradip Bose,

1
@ user3156186 Significa che esiste un oggetto nella classe '0' e zero oggetti nella classe '1'
Daniele

1
@Daniele, sai come vengono ordinate le lezioni? Immagino alfanumerico, ma non ho trovato conferma da nessuna parte.
IanS,

Grazie! Per lo scenario del caso limite in cui il valore di soglia è in realtà -2, potrebbe essere necessario passare (threshold[node] != -2)a ( left[node] != -1)(simile al metodo seguente per ottenere gli ID dei nodi figlio)
tlingf,

@Daniele, hai idea di come rendere la tua funzione "get_code" "restituire" un valore e non "stamparlo", perché devo inviarlo a un'altra funzione?
RoyaumeIX,

17

Scikit Learn ha introdotto un delizioso nuovo metodo chiamato export_textnella versione 0.21 (maggio 2019) per estrarre le regole da un albero. Documentazione qui . Non è più necessario creare una funzione personalizzata.

Dopo aver adattato il tuo modello, hai solo bisogno di due righe di codice. Innanzitutto, importa export_text:

from sklearn.tree.export import export_text

Secondo, crea un oggetto che conterrà le tue regole. Per rendere le regole più leggibili, utilizzare l' feature_namesargomento e passare un elenco dei nomi delle funzionalità. Ad esempio, se il tuo modello viene chiamato modele le tue caratteristiche sono denominate in un dataframe chiamato X_train, puoi creare un oggetto chiamato tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Quindi basta stampare o salvare tree_rules. L'output sarà simile al seguente:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

C'è un nuovo DecisionTreeClassifiermetodo decision_path, nella versione 0.18.0 . Gli sviluppatori forniscono un ampio (ben documentata) progressione .

La prima sezione di codice nella procedura dettagliata che stampa la struttura ad albero sembra essere OK. Tuttavia, ho modificato il codice nella seconda sezione per interrogare un campione. I miei cambiamenti sono indicati con# <--

Modifica Le modifiche contrassegnate dal # <--codice seguente sono state aggiornate nel collegamento dettagliato dopo che gli errori sono stati segnalati nelle richieste pull # 8653 e # 10951 . È molto più facile seguirlo ora.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Modificare il sample_idper visualizzare i percorsi decisionali per altri campioni. Non ho chiesto agli sviluppatori di questi cambiamenti, mi è sembrato più intuitivo quando ho lavorato sull'esempio.


tu amico mio sei una leggenda! qualche idea su come tracciare l'albero decisionale per quel campione specifico? molto aiuto è apprezzato

1
Grazie Victor, probabilmente è meglio porre questa domanda come una domanda separata poiché i requisiti di stampa possono essere specifici per le esigenze dell'utente. Probabilmente otterrai una buona risposta se fornisci un'idea di come vuoi che sia l'output.
Kevin,

Hey kevin, ho creato la questione stackoverflow.com/questions/48888893/...

saresti così gentile da dare un'occhiata a: stackoverflow.com/questions/52654280/...
Alexander Chervov

Puoi per favore spiegare la parte chiamata node_index, non ottenere quella parte. Che cosa fa?
Anindya Sankar Dey,

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Puoi vedere un albero digraph. Quindi, clf.tree_.featuree clf.tree_.valuesono array di nodi che dividono rispettivamente funzionalità e array di valori di nodi. Puoi fare riferimento a maggiori dettagli da questa fonte github .


1
Sì, so come disegnare l'albero - ma ho bisogno della versione più testuale - le regole. qualcosa del tipo: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman,

4

Solo perché tutti sono stati così utili, aggiungerò semplicemente una modifica alle meravigliose soluzioni di Zelazny7 e Daniele. Questo è per Python 2.7, con le schede per renderlo più leggibile:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

I codici di seguito sono il mio approccio con anaconda python 2.7 più un nome di pacchetto "pydot-ng" per creare un file PDF con regole di decisione. Spero sia utile

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

uno spettacolo al graphy sugli alberi qui


3

Ho passato questo, ma avevo bisogno che le regole fossero scritte in questo formato

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Quindi ho adattato la risposta di @paulkernfeld (grazie) che puoi personalizzare in base alle tue esigenze

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

Ecco un modo per tradurre l'intero albero in un'unica espressione (non necessariamente troppo leggibile dall'uomo) usando la libreria SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Questo si basa sulla risposta di @paulkernfeld. Se hai un frame di dati X con le tue caratteristiche e un frame di dati di destinazione y con le tue risonanze e vuoi avere un'idea di quale valore y sia finito in quale nodo (e anche una formica per tracciarlo di conseguenza) puoi fare quanto segue:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

non la versione più elegante ma fa il lavoro ...


1
Questo è un buon approccio quando si desidera restituire le righe di codice anziché semplicemente stamparle.
Hajar Homayouni,

3

Questo è il codice che ti serve

Ho modificato il codice che mi piaceva di più per rientrare correttamente in un notebook jupyter python 3

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

Ecco una funzione, che stampa le regole di un albero decisionale scikit-learning in Python 3 e con offset per blocchi condizionali per rendere più leggibile la struttura:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

Puoi anche renderlo più informativo distinguendolo a quale classe appartiene o anche menzionandone il valore di output.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

inserisci qui la descrizione dell'immagine


2

Ecco il mio approccio per estrarre le regole di decisione in una forma che può essere utilizzata direttamente in sql, in modo che i dati possano essere raggruppati per nodo. (Basato sugli approcci dei precedenti poster.)

Il risultato saranno CASEclausole successive che possono essere copiate in un'istruzione sql, ad es.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Ora puoi usare export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Un esempio completo di [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Codice di Zelazny7 modificato per recuperare SQL dall'albero decisionale.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Apparentemente molto tempo fa qualcuno aveva già deciso di provare ad aggiungere la seguente funzione alle funzioni di esportazione dell'albero dello scikit ufficiale (che sostanzialmente supporta solo export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Ecco il suo impegno completo:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Non sono sicuro di cosa sia successo a questo commento. Ma potresti anche provare a usare quella funzione.

Penso che ciò meriti una seria richiesta di documentazione alle brave persone di scikit-learning per documentare correttamente l' sklearn.tree.TreeAPI che è la struttura ad albero sottostante che DecisionTreeClassifierespone come suo attributo tree_.


0

Usa semplicemente la funzione di sklearn.tree in questo modo

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Quindi cerca nella cartella del tuo progetto il file tree.dot , copia TUTTO il contenuto e incollalo qui http://www.webgraphviz.com/ e genera il tuo grafico :)


0

Grazie per la meravigliosa soluzione di @paulkerfeld. In cima la sua soluzione, per tutti coloro che vogliono avere una versione serializzata di alberi, basta usare tree.threshold, tree.children_left, tree.children_right, tree.featuree tree.value. Dal momento che le foglie non hanno divisioni e quindi senza nomi di funzioni e figli, il loro segnaposto è tree.featuree tree.children_***sono _tree.TREE_UNDEFINEDe _tree.TREE_LEAF. A ogni divisione viene assegnato un indice univoco da depth first search.
Si noti che tree.valueè di forma[n, 1, 1]


0

Ecco una funzione che genera il codice Python da un albero decisionale convertendo l'output di export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Esempio di utilizzo:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Uscita campione:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

L'esempio sopra è generato con names = ['f'+str(j+1) for j in range(NUM_FEATURES)] .

Una caratteristica utile è che può generare file di dimensioni inferiori con una spaziatura ridotta. Basta impostare spacing=2.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.