La foresta casuale è troppo adatta?


19

Sto sperimentando foreste casuali con scikit-learn e sto ottenendo grandi risultati dal mio set di allenamento, ma risultati relativamente scarsi sul mio set di test ...

Ecco il problema (ispirato al poker) che sto cercando di risolvere: date le carte coperte del giocatore A, le carte coperte del giocatore B e un flop (3 carte), quale giocatore ha la mano migliore? Matematicamente, si tratta di 14 input (7 carte - un valore e un seme per ciascuno) e un output (0 o 1).

Ecco alcuni dei miei risultati finora:

Training set size: 600k, test set size: 120k, number of trees: 25
Success rate in training set: 99.975%
Success rate in testing set: 90.05%

Training set size: 400k, test set size: 80k, number of trees: 100
Success rate in training set: 100%
Success rate in testing set: 89.7%

Training set size: 600k, test set size: 120k, number of trees: 5
Success rate in training set: 98.685%
Success rate in testing set: 85.69%

Ecco il codice pertinente utilizzato:

from sklearn.ensemble import RandomForestClassifier
Forest = RandomForestClassifier(n_estimators = 25) #n_estimator varies
Forest = Forest.fit(inputs[:trainingSetSize],outputs[:trainingSetSize])
trainingOutputs = Forest.predict(inputs[:trainingSetSize])
testOutputs = Forest.predict(inputs[trainingSetSize:])

Sembra che, indipendentemente dal numero di alberi utilizzati, le prestazioni sul set di allenamento sono molto migliori rispetto al set di test, nonostante un set di allenamento relativamente grande e un numero ragionevolmente piccolo di funzionalità ...


2
Non vedo una dozzina di "quadranti" per foreste casuali qui. Convalida incrociata? Priori Bayesiani? Natura del ricampionamento? Set di allenamento per ogni albero? Quale percentuale del sottoinsieme per ogni albero? ... ce ne sono molti altri che possono essere elencati, ma il punto è che hai altri input da considerare.
EngrStudent - Ripristina Monica il

1
Potresti spiegare il problema per coloro che non conoscono il poker ... esiste un semplice calcolo per il punteggio del poker? quindi è più facile capire se c'è qualcosa di fondamentalmente sbagliato nell'uso della RF ... Non conosco il poker, ma sospetto che probabilmente la RF sia l'approccio sbagliato - vale a dire il primo passo nella RF è usare solo una frazione degli input, mentre mi sembra che non ci sia modo di costruire un buon classificatore usando solo un sottoinsieme degli input - tutti gli input sono richiesti.
seanv507,

Risposte:


45

Questo è un errore rookie comune quando si usano i modelli RF (alzerò la mano come autore precedente). La foresta che costruisci usando il set di addestramento in molti casi si adatterà ai dati di allenamento quasi perfettamente (come stai scoprendo) se considerata nella totalità. Tuttavia, poiché l'algoritmo crea la foresta, ricorda l'errore di previsione out-of-bag (OOB), che è la sua migliore ipotesi dell'errore di generalizzazione.

Se si rimandano i dati di allenamento nel metodo di previsione (mentre si sta facendo) si ottiene questa previsione quasi perfetta (che è decisamente ottimistica) invece dell'errore OOB corretto. Non farlo Invece, l'oggetto Foresta addestrato avrebbe dovuto ricordare al suo interno l'errore OOB. Non ho familiarità con l'implementazione di scikit-learn, ma guardando la documentazione qui sembra che sia necessario specificare oob_score=Truequando si chiama il metodo fit, e quindi l'errore di generalizzazione verrà archiviato comeoob_score_nell'oggetto restituito. Nel pacchetto R "randomForest", la chiamata al metodo predict senza argomenti sull'oggetto restituito restituirà la previsione OOB sul set di addestramento. Ciò consente di definire l'errore utilizzando un'altra misura. L'invio dell'allenamento impostato nel metodo predict ti darà un risultato diverso, in quanto utilizzerà tutti gli alberi. Non so se l' scikit-learnimplementazione lo farà o no.

È un errore restituire i dati di allenamento nel metodo previsto per testare l'accuratezza. È un errore molto comune, quindi non preoccuparti.


1
Grazie! Tuttavia, ho ancora una preoccupazione: con 400k esempi di allenamento e 50 alberi, ho ottenuto l'89,6% corretto, mentre con la quantità di dati e il doppio di alberi ho ottenuto l'89,7% corretto ... Questo suggerisce che la RF non è una buona metodo per questo? Ho usato una rete neurale MLP in passato e ho raggiunto un'accuratezza del 98,5% sul set di test ...
Uwat

5
È possibile, anche se sembra che tu non stia usando abbastanza alberi. In genere hai bisogno di migliaia. Nota che il numero di alberi non è un parametro da sintonizzare nell'algoritmo RF, più è sempre meglio, ma una volta che hai 'abbastanza' (per essere determinato empiricamente) l'errore OOB non migliora con più alberi. Anche per piccoli insiemi di dati semplici, niente di meno di 500 alberi non è quasi sufficiente.
Bogdanovist,

1
Ci sono alcuni piccoli avvertimenti per "più è sempre meglio" per il numero di alberi, ma ho capito che hai bisogno di gazjillions di alberi prima di iniziare a fare un colpo di prestazione. Nella mia esperienza, quanti alberi hai le risorse della CPU e la pazienza per generare il meglio, anche se con rendimenti decrescenti una volta che gli altopiani della curva OBB (ntrees).
Bogdanovist,

12

Penso che la risposta sia il parametro max_features: int, string o None, parametro facoltativo (default = "auto"). fondamentalmente per questo problema dovresti impostarlo su Nessuno, in modo che ogni albero sia costruito con tutti gli input, dato che chiaramente non puoi costruire un classificatore corretto usando solo una frazione delle carte (default "auto" sta selezionando sqrt (nfeatures) input per ogni albero)


1
Questo è stato! Precisione del 95% con 50 alberi e 600k esempi di allenamento.
Uwat,

2
Nota che a questo punto difficilmente stai usando una foresta casuale, ma come hanno affermato altre risposte, non è il classificatore ideale per questo esatto problema.
Richard Rast,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.