Grid Search in scikit-learn

The performance of a machine learning model depends heavily on its hyperparameter values, so hyperparameter tuning is an important step in finding good values for them. Grid search is one way to do this: it performs an exhaustive search over a specified set of parameter values and reports the combination that gives the best performance for an estimator.

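scikit-learn's GridSearchCV class loops through every combination of the predefined hyperparameter values. For each combination, it fits and evaluates the model with k-fold cross-validation on the training dataset, then selects the combination with the best mean cross-validation score. By default (refit=True), it finally refits the model with those parameters on the whole training set.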
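To make that loop concrete, here is a minimal hand-written sketch of the idea. It assumes a small illustrative grid (small_grid, not part of the original example) and is a simplification: the real GridSearchCV also handles parallel fitting, multiple metrics, failed fits and more. It uses scikit-learn's ParameterGrid to enumerate combinations and cross_val_score to get the per-fold accuracies.

```python
# Hand-rolled grid search: try every combination, keep the best mean CV score.
from sklearn.datasets import load_wine
from sklearn.model_selection import ParameterGrid, cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

# Hypothetical small grid, much smaller than the one used later in the article
small_grid = {"max_depth": [2, 3, 5], "criterion": ["gini", "entropy"]}

best_score, best_params = -1.0, None
for params in ParameterGrid(small_grid):        # every combination, as a dict
    model = DecisionTreeClassifier(random_state=0, **params)
    # Mean accuracy over 5 cross-validation folds of the training data only
    score = cross_val_score(model, x_train, y_train, cv=5, scoring="accuracy").mean()
    if score > best_score:                      # strict ">" keeps the first of any ties
        best_score, best_params = score, params

# Refit the winning configuration on the full training set (what refit=True does)
best_model = DecisionTreeClassifier(random_state=0, **best_params).fit(x_train, y_train)
print(best_params, round(best_score, 3))
```

GridSearchCV with the same estimator, grid and cv=5 should report the same best mean score as this loop, because both evaluate each candidate on the same stratified folds.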

Python Example

In Python, grid search is performed with the sklearn.model_selection.GridSearchCV class from the scikit-learn library. Here, we will use scikit-learn's wine dataset to tune the hyperparameters of our model.

The first step is to load the dataset:

from sklearn.datasets import load_wine
wine = load_wine()

This is a simple multi-class classification dataset for wine recognition. It has 178 instances, 13 numeric attributes and three classes.
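A quick way to confirm those numbers, and to see how the three classes are balanced, is to inspect the loaded Bunch object directly. This is a small sketch using only load_wine and NumPy's bincount:

```python
# Inspect the size, features and class balance of the wine dataset
import numpy as np
from sklearn.datasets import load_wine

wine = load_wine()
print(wine.data.shape)           # (178, 13): 178 samples, 13 attributes
print(len(wine.target_names))    # 3 classes
print(np.bincount(wine.target))  # [59 71 48] samples in class_0, class_1, class_2
```

The classes are not perfectly balanced, which is one reason cross-validation on a classifier uses stratified folds by default.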

Now, split the data into training and test samples:

from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(wine.data, wine.target, test_size=0.3, random_state=0)

This splits the wine dataset into training and test sets in a 70:30 ratio; random_state=0 makes the split reproducible.
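With 178 samples, 30% is 53.4, and train_test_split rounds the test size up. A short check of the resulting shapes, reusing the split from above:

```python
# Check the sizes produced by the 70:30 split
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

# ceil(0.3 * 178) = 54 test samples, leaving 124 for training
print(x_train.shape, x_test.shape)  # (124, 13) (54, 13)
```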

Now let's create a Decision Tree Classifier model to classify the wine data:

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier()

Next, we define the DecisionTreeClassifier hyperparameters that we want to try:

parameters = {'splitter': ['best', 'random'],
              'criterion': ['gini', 'entropy'],
              # 'auto' (equal to 'sqrt' for classifiers) was removed in scikit-learn 1.3; drop it on newer versions
              'max_features': ['log2', 'sqrt', 'auto'],
              'max_depth': [2, 3, 5, 10, 17],
              'min_samples_split': [2, 3, 5, 7, 9],
              'min_samples_leaf': [1, 5, 8, 11],
              'random_state': [0, 1, 2, 3, 4, 5]}

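Which hyperparameters are available depends on the model/estimator we are using. Each key in the parameters dictionary is a constructor argument of DecisionTreeClassifier, and each value is the list of candidate values to test. GridSearchCV tries every combination, so this grid contains 2 × 2 × 3 × 5 × 5 × 4 × 6 = 7200 candidates.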
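Because the grid grows multiplicatively, it is worth counting the candidates before starting a long search. scikit-learn's ParameterGrid enumerates the same combinations GridSearchCV uses and supports len(); the sketch below repeats the dictionary from above (ParameterGrid does not validate the values, so 'auto' causes no error here):

```python
# Count the combinations GridSearchCV will try for the grid above
import math
from sklearn.model_selection import ParameterGrid

parameters = {'splitter': ['best', 'random'],
              'criterion': ['gini', 'entropy'],
              'max_features': ['log2', 'sqrt', 'auto'],
              'max_depth': [2, 3, 5, 10, 17],
              'min_samples_split': [2, 3, 5, 7, 9],
              'min_samples_leaf': [1, 5, 8, 11],
              'random_state': [0, 1, 2, 3, 4, 5]}

n_candidates = len(ParameterGrid(parameters))
print(n_candidates)                                    # 7200
print(math.prod(len(v) for v in parameters.values()))  # 7200, the same product by hand
print(n_candidates * 5)                                # 36000 fits with cv=5
```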

Now let's create a GridSearchCV object:

from sklearn.model_selection import GridSearchCV

grid_search_dt = GridSearchCV(estimator=model,
                              param_grid=parameters,
                              scoring='accuracy',
                              cv=5,
                              verbose=1)

estimator is the model, defined as DecisionTreeClassifier() above. param_grid takes the dictionary of candidate parameter values. scoring sets the evaluation metric used to compare candidates, here accuracy. cv sets the number of cross-validation folds; for a classifier, an integer value is turned into stratified k-fold splitting (StratifiedKFold). verbose=1 prints progress messages while the data is being fit.
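To see what cv=5 actually becomes, scikit-learn's check_cv helper applies the same conversion rule. This sketch is only illustrative: it shows that, for classification targets, the integer becomes an unshuffled StratifiedKFold over the training data.

```python
# What cv=5 turns into for a classifier
from sklearn.datasets import load_wine
from sklearn.model_selection import StratifiedKFold, check_cv, train_test_split

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

# classifier=True plus multiclass targets -> StratifiedKFold without shuffling
cv = check_cv(5, y_train, classifier=True)
print(type(cv).__name__, cv.get_n_splits())  # StratifiedKFold 5

# Size of each validation fold taken from the 124 training samples
for _, val_idx in cv.split(x_train, y_train):
    print(len(val_idx))
```

If shuffled folds are preferred, a splitter object such as StratifiedKFold(n_splits=5, shuffle=True, random_state=0) can be passed to cv instead of the integer.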

Fit the GridSearchCV object on the training data:

grid_search_dt.fit(x_train, y_train)

This prints progress messages while the candidates are being fit. In a notebook or interactive session, the fitted object is then displayed as well. The output looks something like:

Fitting 5 folds for each of 7200 candidates, totalling 36000 fits
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 36000 out of 36000 | elapsed:   36.2s finished
GridSearchCV(cv=5, estimator=DecisionTreeClassifier(),
             param_grid={'criterion': ['gini', 'entropy'],
                         'max_depth': [2, 3, 5, 10, 17],
                         'max_features': ['log2', 'sqrt', 'auto'],
                         'min_samples_leaf': [1, 5, 8, 11],
                         'min_samples_split': [2, 3, 5, 7, 9],
                         'random_state': [0, 1, 2, 3, 4, 5],
                         'splitter': ['best', 'random']},
             scoring='accuracy', verbose=1)
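Beyond the single winner, the fitted object stores a score for every candidate in cv_results_, which converts neatly to a pandas DataFrame (pandas is an extra dependency assumed here). The sketch uses a hypothetical small grid so it runs quickly; the same code works on grid_search_dt from above.

```python
# Inspect every candidate's cross-validation results
import pandas as pd
from sklearn.datasets import load_wine
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

# Hypothetical small grid so the example finishes in a second or two
small_grid = {"max_depth": [2, 3, 5], "min_samples_leaf": [1, 5]}
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=0), small_grid,
                           scoring="accuracy", cv=5)
grid_search.fit(x_train, y_train)

# One row per candidate; rank 1 has the best mean cross-validation accuracy
results = pd.DataFrame(grid_search.cv_results_)
cols = ["params", "mean_test_score", "std_test_score", "rank_test_score"]
print(results[cols].sort_values("rank_test_score").to_string(index=False))
```

The std_test_score column is a useful sanity check: differences between top candidates that are smaller than the fold-to-fold spread are unlikely to be meaningful. For large grids like the 7200-candidate one, GridSearchCV's n_jobs=-1 option runs fits in parallel on all CPU cores.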

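After fitting, best_estimator_ holds the model with the best-scoring combination of parameters, already refit on the whole training set (refit=True is the default):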

print(grid_search_dt.best_estimator_)

Output:

DecisionTreeClassifier(max_depth=5, max_features='log2', min_samples_split=7, random_state=3)
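The printed estimator only lists parameters that differ from their defaults. To see the winning combination and its cross-validation score explicitly, use best_params_ and best_score_. A sketch on a hypothetical small grid:

```python
# Best combination, its mean CV score, and the refit model
from sklearn.datasets import load_wine
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

small_grid = {"max_depth": [2, 3, 5], "criterion": ["gini", "entropy"]}  # hypothetical grid
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=0), small_grid,
                           scoring="accuracy", cv=5).fit(x_train, y_train)

print(grid_search.best_params_)  # every searched parameter, including defaults
print(grid_search.best_score_)   # mean cross-validation accuracy on the training data
# best_estimator_ is already refit on all of x_train, so it can predict directly
print(grid_search.best_estimator_.predict(x_test[:5]))
```

Note that best_score_ is measured on the training folds, so it is not a substitute for evaluating on the held-out test set.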

To evaluate the optimal model on the test data:

print(grid_search_dt.score(x_test, y_test))

Output:

0.9629629629629629

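Because scoring='accuracy', score() returns the accuracy of the refit best model on the held-out test set: 52 of the 54 test samples are classified correctly.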
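As a cross-check, GridSearchCV.score should match computing the accuracy by hand from best_estimator_'s predictions. This sketch uses a hypothetical one-parameter grid and sklearn.metrics.accuracy_score:

```python
# Confirm that GridSearchCV.score is the test accuracy of the refit best model
from sklearn.datasets import load_wine
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0)

grid_search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                           {"max_depth": [2, 3, 5]},  # hypothetical grid
                           scoring="accuracy", cv=5)
grid_search.fit(x_train, y_train)

# score() applies the `scoring` metric to best_estimator_'s predictions
test_acc = grid_search.score(x_test, y_test)
manual_acc = accuracy_score(y_test, grid_search.best_estimator_.predict(x_test))
print(test_acc, manual_acc)  # the two values agree
# With 54 test samples, accuracy moves in steps of 1/54
print(round(test_acc * len(y_test)), "of", len(y_test), "correct")
```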

Summary

In this article, we looked at grid search in scikit-learn. In the next article, we will focus on Recursive Feature Elimination and SelectKBest feature selection.
