compactem.oracles package

Various oracle learners.

Submodules

compactem.oracles.oracle_learners module

get_calibrated_gbm(X, y, num_boosting_rounds=200, max_depth=5, learning_rate=0.1, early_stopping_rounds=3, val_pct=0.2, cal_pct=0.2, calibration_method='sigmoid', random_state=None)

Creates a calibrated gradient boosting model (GBM) with LightGBM, intended for use as an oracle model. Many of the parameters map directly to LightGBM parameters.

Parameters
  • X – 2D array of data to construct oracle on.

  • y – corresponding labels.

  • num_boosting_rounds – maximum number of boosting rounds used to construct the classifier.

  • max_depth – maximum depth of the tree built in each boosting round.

  • learning_rate – learning rate for LightGBM’s gbdt implementation.

  • early_stopping_rounds – stop model construction if the validation error does not decrease for this many consecutive rounds; training can therefore end before num_boosting_rounds is reached.

  • val_pct – validation set size as a fraction of the overall data (e.g., 0.2 for 20%).

  • cal_pct – calibration set size as a fraction of the overall data (e.g., 0.2 for 20%).

  • calibration_method – a calibration method accepted by scikit-learn’s CalibratedClassifierCV (e.g., 'sigmoid' or 'isotonic').

  • random_state – random seed.

Returns

calibrated GBM classifier.
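The split and early-stopping parameters are easiest to see on concrete numbers. The sketch below is hypothetical and uses only NumPy: three_way_split assumes val_pct and cal_pct are fractions of the overall data carved from one random permutation, with the remainder used for training, and stopping_round mimics the rule described for early_stopping_rounds by returning the best round that would be kept. Neither function is part of compactem, and the library's actual split strategy (for example, whether it stratifies by label) may differ.

```python
import numpy as np


def three_way_split(n_rows, val_pct=0.2, cal_pct=0.2, random_state=None):
    """Hypothetical helper: partition row indices into train/validation/calibration.

    Assumes val_pct and cal_pct are fractions of the overall data; the
    remaining rows are used for training.
    """
    if val_pct < 0 or cal_pct < 0 or val_pct + cal_pct >= 1:
        raise ValueError("need val_pct >= 0, cal_pct >= 0 and val_pct + cal_pct < 1")
    rng = np.random.default_rng(random_state)
    idx = rng.permutation(n_rows)  # shuffle once, then slice into three parts
    n_val = int(round(val_pct * n_rows))
    n_cal = int(round(cal_pct * n_rows))
    val_idx = idx[:n_val]
    cal_idx = idx[n_val:n_val + n_cal]
    train_idx = idx[n_val + n_cal:]
    return train_idx, val_idx, cal_idx


def stopping_round(val_losses, num_boosting_rounds=200, early_stopping_rounds=3):
    """Hypothetical helper: return the (1-based) best round that would be kept.

    Training halts once early_stopping_rounds rounds pass without a new best
    validation loss, or after num_boosting_rounds rounds, whichever comes first.
    """
    best_loss, best_round = float("inf"), 0
    for r, loss in enumerate(val_losses[:num_boosting_rounds], start=1):
        if loss < best_loss:
            best_loss, best_round = loss, r
        elif r - best_round >= early_stopping_rounds:
            break  # no improvement for early_stopping_rounds rounds
    return best_round


train_idx, val_idx, cal_idx = three_way_split(1000, val_pct=0.2, cal_pct=0.2, random_state=0)
print(len(train_idx), len(val_idx), len(cal_idx))  # 600 200 200

# Validation loss stops improving after round 3; rounds 4-6 trigger the stop,
# so the better loss at round 7 is never reached.
print(stopping_round([0.5, 0.4, 0.35, 0.36, 0.37, 0.38, 0.30], early_stopping_rounds=3))  # 3
```

The second call shows the trade-off of a small early_stopping_rounds value: training ends sooner, at the risk of missing a later improvement.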

get_calibrated_rf(X, y, params_range=None, base_est_params=None, cv_folds=3, cal_pct=0.2, calibration_method='sigmoid', random_state=None)

Creates a Random Forest oracle using scikit-learn’s implementation.

Parameters
  • X – 2D array of data to construct oracle on.

  • y – corresponding labels.

  • params_range (Optional[dict]) – hyperparameter grid for model selection over scikit-learn Random Forests, in the format that scikit-learn’s grid search accepts. For example, {'max_depth': [1, 2, 3, 4, 5], 'n_estimators': [5, 10, 50]}.

  • base_est_params (Optional[dict]) – parameters used to initialize the base Random Forest that is passed to the cross-validation search. These must be parameters that scikit-learn’s RandomForestClassifier accepts.

  • cv_folds – number of cross-validation folds to use.

  • cal_pct – calibration set size as a fraction of the overall data (e.g., 0.2 for 20%).

  • calibration_method – a calibration method accepted by scikit-learn’s CalibratedClassifierCV (e.g., 'sigmoid' or 'isotonic').

  • random_state – random seed.

Returns

calibrated Random Forest classifier.
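The entry above does not spell out how params_range, base_est_params, cv_folds and cal_pct interact. The following is a minimal sketch under stated assumptions, not compactem’s code: it assumes a stratified cal_pct share of the rows is held out for calibration, the rest goes to scikit-learn’s GridSearchCV with a RandomForestClassifier initialized from base_est_params, and the refit best forest is then calibrated on the held-out rows. The name calibrated_rf_sketch is hypothetical.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


def calibrated_rf_sketch(X, y, params_range=None, base_est_params=None, cv_folds=3,
                         cal_pct=0.2, calibration_method="sigmoid", random_state=None):
    """Hypothetical re-implementation for illustration; not compactem's code."""
    # Hold out cal_pct of the rows (stratified by label) for calibration only.
    X_fit, X_cal, y_fit, y_cal = train_test_split(
        X, y, test_size=cal_pct, stratify=y, random_state=random_state)

    # Base forest from base_est_params; an explicit random_state there wins.
    rf_params = {"random_state": random_state, **(base_est_params or {})}
    base = RandomForestClassifier(**rf_params)

    if params_range:
        # Grid search over params_range with cv_folds folds; refit=True (the
        # default) retrains the best configuration on all of X_fit.
        search = GridSearchCV(base, params_range, cv=cv_folds)
        model = search.fit(X_fit, y_fit).best_estimator_
    else:
        model = base.fit(X_fit, y_fit)

    # Calibrate the already-fitted forest on the held-out rows only.
    try:
        from sklearn.frozen import FrozenEstimator  # scikit-learn >= 1.6
        calibrator = CalibratedClassifierCV(FrozenEstimator(model), method=calibration_method)
    except ImportError:  # older scikit-learn: fall back to cv="prefit"
        calibrator = CalibratedClassifierCV(model, method=calibration_method, cv="prefit")
    return calibrator.fit(X_cal, y_cal)


X, y = make_classification(n_samples=500, random_state=0)
clf = calibrated_rf_sketch(
    X, y,
    params_range={"max_depth": [2, 4], "n_estimators": [10, 50]},
    random_state=0,
)
proba = clf.predict_proba(X[:5])  # calibrated class probabilities, shape (5, 2)
```

Keeping a separate calibration split means the calibrator is never fit on rows the forest was tuned on, which would otherwise tend to make the calibrated probabilities overconfident.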