https://github.com/estherjulien/HybridML
Raw File
Tip revision: 9985e6d930e8b98eb03330a964c2c3fc8788630c authored by estherjulien on 01 August 2022, 11:54:59 UTC
HybridCode deleted from test_data_gen.py
Tip revision: 9985e6d
learn_multi_class.py
import os

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

import pandas as pd
import joblib

'''
Code for training a random forest for predicting cherries
'''


def _subset_accuracy(rf, X_val, Y_val, label):
    """Return ``rf.score`` on the validation rows whose ``label`` column is 1.

    Returns ``nan`` when no validation row carries that label, because
    scoring an empty selection would otherwise raise inside scikit-learn.
    """
    mask = Y_val[label] == 1
    if not mask.any():
        return float("nan")
    return rf.score(X_val[mask], Y_val[mask])


def train_cherries_rf(X, Y, name=None):
    """Train, evaluate and save a random forest for predicting cherries.

    The data is split 90/10 into a training and a validation part, a
    ``RandomForestClassifier`` with default settings is fitted on the
    training part, and validation scores are printed together with the
    feature importances. The fitted model is written with ``joblib`` to
    ``RFModels/rf_cherries.joblib`` or, when ``name`` is given, to
    ``RFModels/rf_cherries_{name}.joblib``.

    Args:
        X: Feature table; must expose ``.columns`` (used to label the
            feature importances), e.g. a pandas DataFrame.
        Y: Targets, indexed by the column labels 0, 1, 2 and 3, which are
            compared against 1 to select the "no cherry", "cherry",
            "ret cherry" and "no ret cherry" rows respectively.
            NOTE(review): presumably a DataFrame of 0/1 indicator columns;
            confirm against the caller.
        name: Optional suffix for the saved model file name.

    Returns:
        A list ``[overall, no_cherry, cherry, ret_cherry, no_ret_cherry]``
        of validation scores as returned by ``rf.score``. A per-class entry
        is ``nan`` if the validation split contains no row of that class.
    """
    if name is None:
        model_name = "RFModels/rf_cherries.joblib"
    else:
        model_name = f"RFModels/rf_cherries_{name}.joblib"

    # split data in train and validation
    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.1)

    # train model
    rf = RandomForestClassifier()
    rf.fit(X_train, Y_train)

    # evaluation: overall score, then one score per target column
    score_rf = rf.score(X_val, Y_val)
    score_rf_no_cher = _subset_accuracy(rf, X_val, Y_val, 0)
    score_rf_cher = _subset_accuracy(rf, X_val, Y_val, 1)
    score_rf_ret_cher = _subset_accuracy(rf, X_val, Y_val, 2)
    score_rf_no_ret_cher = _subset_accuracy(rf, X_val, Y_val, 3)

    print(f"\nRF overall validation accuracy = {score_rf}")
    print(f"RF no cherry validation accuracy = {score_rf_no_cher}")
    print(f"RF cherry validation accuracy = {score_rf_cher}")
    print(f"RF ret cherry validation accuracy = {score_rf_ret_cher}")
    print(f"RF no ret cherry validation accuracy = {score_rf_no_ret_cher}")

    # feature importance
    feature_importance = pd.Series(rf.feature_importances_, index=X.columns)
    print("Feature importance:\n")
    print(feature_importance)

    # save; create the target directory first so a finished training run is
    # not lost just because the output folder does not exist yet
    os.makedirs(os.path.dirname(model_name), exist_ok=True)
    joblib.dump(rf, model_name)

    return [score_rf, score_rf_no_cher, score_rf_cher, score_rf_ret_cher, score_rf_no_ret_cher]
back to top