compactem.oracles package
Various oracle learners.
Submodules
compactem.oracles.oracle_learners module
- get_calibrated_gbm(X, y, num_boosting_rounds=200, max_depth=5, learning_rate=0.1, early_stopping_rounds=3, val_pct=0.2, cal_pct=0.2, calibration_method='sigmoid', random_state=None)
Creates a calibrated gradient boosting model (GBM) using LightGBM, intended for use as an oracle model. Many of the parameters below map directly to LightGBM parameters.
- Parameters
X – 2D array of data to construct the oracle on.
y – corresponding labels.
num_boosting_rounds – maximum number of boosting rounds to be used to construct the classifier.
max_depth – maximum depth of the tree built in each boosting round.
learning_rate – learning rate for LightGBM’s gbdt implementation.
early_stopping_rounds – stop model construction if the validation error does not decrease for this many consecutive rounds; this takes precedence over num_boosting_rounds, so training may end before that many rounds are built.
val_pct – validation set size as a fraction of the overall data (e.g. 0.2 for 20%).
cal_pct – calibration set size as a fraction of the overall data (e.g. 0.2 for 20%).
calibration_method – calibration method name as accepted by scikit-learn, i.e. 'sigmoid' or 'isotonic'.
random_state – random seed.
- Returns
calibrated GBM classifier.
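Both val_pct and cal_pct describe hold-out sizes relative to the full dataset. A minimal sketch of how such a three-way split could be computed follows; `split_indices` is a hypothetical helper, and the assumption that both fractions are taken from the overall row count (with the remainder used for training) follows the parameter descriptions above rather than the compactem source.

```python
import random
from typing import List, Optional, Tuple


def split_indices(
    n_rows: int,
    val_pct: float = 0.2,
    cal_pct: float = 0.2,
    random_state: Optional[int] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: partition row indices into train/validation/calibration.

    Both fractions are relative to the overall data (assumption); whatever
    remains after carving out the validation and calibration sets is training data.
    """
    if not (0 <= val_pct < 1 and 0 <= cal_pct < 1 and val_pct + cal_pct < 1):
        raise ValueError("val_pct and cal_pct must be in [0, 1) and sum to less than 1")
    indices = list(range(n_rows))
    random.Random(random_state).shuffle(indices)  # seeded shuffle for reproducibility
    n_val = round(n_rows * val_pct)
    n_cal = round(n_rows * cal_pct)
    val = indices[:n_val]
    cal = indices[n_val:n_val + n_cal]
    train = indices[n_val + n_cal:]
    return train, val, cal


train, val, cal = split_indices(1000, val_pct=0.2, cal_pct=0.2, random_state=0)
print(len(train), len(val), len(cal))  # 600 200 200
```

With the defaults, only 60% of the rows are left for fitting the boosted trees, which is worth keeping in mind on small datasets.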
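The interaction between num_boosting_rounds and early_stopping_rounds can be illustrated with a small simulation of patience-based early stopping. This is a simplified model that assumes behavior similar to LightGBM's early stopping (track the best round so far and stop after early_stopping_rounds rounds without improvement); `rounds_used` and the error values are hypothetical.

```python
from typing import Sequence, Tuple


def rounds_used(
    val_errors: Sequence[float],
    num_boosting_rounds: int = 200,
    early_stopping_rounds: int = 3,
) -> Tuple[int, int]:
    """Simulate patience-based early stopping over a sequence of validation errors.

    val_errors[i] is the validation error after boosting round i + 1.
    Returns (best_round, rounds_trained), both 1-based.
    """
    best_error = float("inf")
    best_round = 0
    rounds_trained = 0
    # Never train past the num_boosting_rounds cap.
    for round_no, err in enumerate(val_errors[:num_boosting_rounds], start=1):
        rounds_trained = round_no
        if err < best_error:
            best_error, best_round = err, round_no
        elif round_no - best_round >= early_stopping_rounds:
            # No improvement for early_stopping_rounds consecutive rounds: stop.
            break
    return best_round, rounds_trained


errors = [0.30, 0.25, 0.24, 0.245, 0.243, 0.26, 0.20]
print(rounds_used(errors, num_boosting_rounds=200, early_stopping_rounds=3))  # (3, 6)
```

Here training stops after round 6 with round 3 as the best model, so the lower error at round 7 is never observed; a small early_stopping_rounds value trades some accuracy for shorter training.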
- get_calibrated_rf(X, y, params_range=None, base_est_params=None, cv_folds=3, cal_pct=0.2, calibration_method='sigmoid', random_state=None)
Creates a calibrated Random Forest oracle using scikit-learn’s implementation.
- Parameters
X – 2D array of data to construct the oracle on.
y – corresponding labels.
params_range (Optional[dict]) – model selection parameter range for scikit-learn Random Forests, in the format that grid search accepts. For example, {'max_depth': [1, 2, 3, 4, 5], 'n_estimators': [5, 10, 50]}.
base_est_params (Optional[dict]) – parameters used to initialize the base Random Forest that is passed to the cross-validation function. These should be parameters that the Random Forest classifier accepts.
cv_folds – number of cross-validation folds to use.
cal_pct – calibration set size as a fraction of the overall data (e.g. 0.2 for 20%).
calibration_method – calibration method name as accepted by scikit-learn, i.e. 'sigmoid' or 'isotonic'.
random_state – random seed.
- Returns
calibrated Random Forest classifier.
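To gauge the cost of model selection, a grid-search style params_range can be expanded into its candidate combinations, similar to what scikit-learn's ParameterGrid does. The sketch assumes every candidate is fit once per cross-validation fold; `expand_grid` is a hypothetical helper, not part of compactem.

```python
from itertools import product
from typing import Any, Dict, List


def expand_grid(params_range: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand a grid-search style dict into every parameter combination."""
    keys = sorted(params_range)  # fixed key order for reproducible output
    return [
        dict(zip(keys, values))
        for values in product(*(params_range[k] for k in keys))
    ]


params_range = {'max_depth': [1, 2, 3, 4, 5], 'n_estimators': [5, 10, 50]}
cv_folds = 3
candidates = expand_grid(params_range)
print(len(candidates))             # 15
print(len(candidates) * cv_folds)  # 45
print(candidates[0])               # {'max_depth': 1, 'n_estimators': 5}
```

For the example grid this means 15 candidates and 45 Random Forest fits with cv_folds=3, before any final refit on the selected parameters, so the grid size grows multiplicatively with each added parameter.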
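With calibration_method='sigmoid', scikit-learn applies Platt scaling: a one-dimensional logistic regression, fit on the held-out calibration set, that maps the classifier's scores to probabilities. The sketch below fits that mapping with plain gradient descent; `fit_platt`, `calibrate`, and the toy scores are hypothetical, and scikit-learn's own implementation differs in details (for example, it uses smoothed targets and a different optimizer).

```python
import math
from typing import Sequence, Tuple


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def fit_platt(
    scores: Sequence[float],
    labels: Sequence[int],
    lr: float = 0.1,
    n_iter: int = 5000,
) -> Tuple[float, float]:
    """Fit p(y=1 | s) = sigmoid(a * s + b) by gradient descent on the mean log loss."""
    a, b = 0.0, 0.0
    n = len(scores)
    for _ in range(n_iter):
        grad_a = grad_b = 0.0
        for s, y in zip(scores, labels):
            err = _sigmoid(a * s + b) - y  # derivative of log loss w.r.t. the logit
            grad_a += err * s
            grad_b += err
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b


def calibrate(score: float, a: float, b: float) -> float:
    """Map a raw classifier score to a calibrated probability."""
    return _sigmoid(a * score + b)


# Toy calibration set: raw scores from some classifier and the true labels.
cal_scores = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]
cal_labels = [0, 0, 1, 0, 1, 1, 1]
a, b = fit_platt(cal_scores, cal_labels)
print([round(calibrate(s, a, b), 2) for s in (-2.0, 0.0, 2.0)])
```

Because the sigmoid has only two parameters, it needs relatively little calibration data; 'isotonic' is more flexible but generally benefits from a larger cal_pct.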