Skip to content
Snippets Groups Projects
Commit 260b048a authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

example of uploading grid search results

parent 4333f61b
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id:9824f12e tags:
``` python
# code
```
%% Cell type:code id:fdc05dfa tags:
``` python
parameters = {'max_depth':(5,10,15)}
rf = RandomForestRegressor(n_estimators=15, try_features = 'third')
searcher = GridSearchCV(rf, parameters, cv=3)
```
%% Cell type:code id:4f464219 tags:
``` python
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
```
%% Cell type:code id:25a04ec5 tags:
``` python
X, y = make_regression(n_features=1, n_informative=1, random_state=0, shuffle=False, noise=10)
```
%% Cell type:code id:24b05b14 tags:
``` python
plt.scatter(X, y);
```
%% Output
%% Cell type:code id:8532c5eb tags:
``` python
rf = RandomForestRegressor()
```
%% Cell type:code id:528b3452 tags:
``` python
searcher = GridSearchCV(estimator=rf, param_grid={'max_depth':(5,10,15)}, cv= 3)
```
%% Cell type:code id:ec7bef2c tags:
``` python
searcher.fit(X, y)
```
%% Output
GridSearchCV(cv=3, estimator=RandomForestRegressor(),
param_grid={'max_depth': (5, 10, 15)})
%% Cell type:code id:a20fe377 tags:
``` python
searcher.cv_results_
```
%% Output
{'mean_fit_time': array([0.08331521, 0.07211264, 0.0718383 ]),
'std_fit_time': array([0.01629742, 0.00011223, 0.00013308]),
'mean_score_time': array([0.00523392, 0.00518481, 0.00502459]),
'std_score_time': array([1.56304468e-04, 1.14600801e-04, 5.38271269e-05]),
'param_max_depth': masked_array(data=[5, 10, 15],
mask=[False, False, False],
fill_value='?',
dtype=object),
'params': [{'max_depth': 5}, {'max_depth': 10}, {'max_depth': 15}],
'split0_test_score': array([0.92873148, 0.92620142, 0.92348871]),
'split1_test_score': array([0.90228693, 0.90530648, 0.90144398]),
'split2_test_score': array([0.89852812, 0.8944555 , 0.89516853]),
'mean_test_score': array([0.90984884, 0.90865446, 0.9067004 ]),
'std_test_score': array([0.01343993, 0.01317466, 0.01214443]),
'rank_test_score': array([1, 2, 3], dtype=int32)}
%% Cell type:code id:f58df0df tags:
``` python
for i, p in enumerate(searcher.cv_results_['params']):
for parname, parvalue in p.items():
print(parname, parvalue)
print(searcher.cv_results_['mean_test_score'][i])
print(searcher.cv_results_['mean_fit_time'][i])
```
%% Output
max_depth 5
0.9098488414768675
0.08331521352132161
max_depth 10
0.908654464483059
0.07211263974507649
max_depth 15
0.9067004042768522
0.0718382994333903
%% Cell type:code id:f44a0964 tags:
``` python
```
%% Cell type:markdown id:6168dc37 tags:
## Log search results to model repo
%% Cell type:code id:1c8a9237 tags:
``` python
import mlflow
```
%% Cell type:code id:19952b7e tags:
``` python
mlflow.set_tracking_uri('https://modelrepository.eflows4hpc.eu/')
```
%% Cell type:code id:c5f64c94 tags:
``` python
mlflow.set_experiment('gridsearch-example')
```
%% Output
2024/01/15 10:53:31 INFO mlflow.tracking.fluent: Experiment with name 'gridsearch-example' does not exist. Creating a new experiment.
<Experiment: artifact_location='mlflow-artifacts:/2', creation_time=1705312411739, experiment_id='2', last_update_time=1705312411739, lifecycle_stage='active', name='gridsearch-example', tags={}>
%% Cell type:code id:544a9c35 tags:
``` python
metrics=['mean_test_score', 'mean_fit_time']
for i, p in enumerate(searcher.cv_results_['params']):
with mlflow.start_run():
for parname, parvalue in p.items():
mlflow.log_param(parname, value=parvalue)
for m in metrics:
mlflow.log_metric(m, searcher.cv_results_[m][i])
```
%% Cell type:code id:899abf9c tags:
``` python
```
%% Cell type:code id:8afb1656 tags:
``` python
```
%% Cell type:code id:9ace8a2d tags:
``` python
```
%% Cell type:markdown id:a9c1680f tags:
## Serialize results
%% Cell type:code id:a874f869 tags:
``` python
import pandas as pd
```
%% Cell type:code id:74f4b85f tags:
``` python
df = pd.DataFrame.from_dict(searcher.cv_results_)
```
%% Cell type:code id:68f37d48 tags:
``` python
df
```
%% Output
mean_fit_time std_fit_time mean_score_time std_score_time \
0 0.083315 0.016297 0.005234 0.000156
1 0.072113 0.000112 0.005185 0.000115
2 0.071838 0.000133 0.005025 0.000054
param_max_depth params split0_test_score split1_test_score \
0 5 {'max_depth': 5} 0.928731 0.902287
1 10 {'max_depth': 10} 0.926201 0.905306
2 15 {'max_depth': 15} 0.923489 0.901444
split2_test_score mean_test_score std_test_score rank_test_score
0 0.898528 0.909849 0.013440 1
1 0.894456 0.908654 0.013175 2
2 0.895169 0.906700 0.012144 3
%% Cell type:code id:630a4d63 tags:
``` python
df.to_csv('search.results')
```
%% Cell type:code id:a604829e tags:
``` python
! ls search.results
```
%% Output
search.results
%% Cell type:code id:7b090b0b tags:
``` python
```
%% Cell type:code id:c4c2addd tags:
``` python
```
%% Cell type:code id:79b51d03 tags:
``` python
import json
import pandas as pd
```
%% Cell type:code id:a533b1fb tags:
``` python
df = pd.read_csv('search.results', index_col=0)
```
%% Cell type:code id:09989aa2 tags:
``` python
dct = df.to_dict()
```
%% Cell type:code id:7ffcd844 tags:
``` python
mlflow.set_experiment('serializedgridsearch-example')
```
%% Output
2024/01/15 11:19:33 INFO mlflow.tracking.fluent: Experiment with name 'serializedgridsearch-example' does not exist. Creating a new experiment.
<Experiment: artifact_location='mlflow-artifacts:/3', creation_time=1705313973974, experiment_id='3', last_update_time=1705313973974, lifecycle_stage='active', name='serializedgridsearch-example', tags={}>
%% Cell type:code id:1f5a6864 tags:
``` python
metrics=['mean_test_score', 'mean_fit_time']
for i, p in enumerate(dct['params'].values()):
with mlflow.start_run():
p = json.loads(p.replace('\'', '"'))
for parname, parvalue in p.items():
mlflow.log_param(parname, value=parvalue)
for m in metrics:
mlflow.log_metric(m, searcher.cv_results_[m][i])
```
%% Cell type:code id:976d5e87 tags:
``` python
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment