name: custom-sklearn-estimator
description: Build scikit-learn compatible custom estimators by following the official “rolling your own estimator” rules for init, fit/predict, validation, learned attributes, tags, and estimator checks; prerequisite for autogluon-sklearn-wrapper or any sklearn-facing wrappers.
Custom scikit-learn Estimator
Purpose
Create scikit-learn compatible estimators that work with pipelines, model selection, and validation tooling. This skill codifies the required API patterns for `__init__`, `fit`, prediction/transform methods, learned attributes, and estimator checks.
Usage
- “rolling your own estimator”
- “custom scikit-learn estimator”
- “build sklearn-compatible class”
Instructions
- Choose the estimator type and mixins
  - Use `ClassifierMixin`, `RegressorMixin`, `TransformerMixin`, or `ClusterMixin` as needed, with `BaseEstimator` last in the inheritance list.
  - For meta-estimators, ensure sub-estimator params are exposed through `get_params`/`set_params` (handled by `BaseEstimator`).
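
The mixin ordering described in this step can be sketched as follows. `MeanCenterer` is a hypothetical class invented for illustration, and it skips input validation to keep the focus on the inheritance list.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class MeanCenterer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer: subtracts the per-column mean learned in fit."""

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        self.mean_ = X.mean(axis=0)  # learned attribute, trailing underscore
        return self

    def transform(self, X):
        return np.asarray(X, dtype=float) - self.mean_


X = np.array([[1.0, 2.0], [3.0, 6.0]])
# fit_transform is not defined above; TransformerMixin supplies it.
centered = MeanCenterer().fit_transform(X)
```

Listing the mixin first places it ahead of `BaseEstimator` in the method resolution order, so the mixin's methods are found first.
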
- Implement a minimal `__init__`
  - Keyword args with defaults; no validation or logic.
  - Assign each parameter to an attribute with the exact same name.
  - Avoid mutable defaults; do not set attributes with a trailing `_` here.
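
A minimal sketch of an `__init__` that follows these rules, using a hypothetical `ThresholdRule` estimator with made-up parameter names. Because each argument is stored unchanged under its own name, `get_params`, `set_params`, and `clone` work without extra code.

```python
from sklearn.base import BaseEstimator, clone


class ThresholdRule(BaseEstimator):
    """Hypothetical estimator used only to illustrate the __init__ contract."""

    def __init__(self, threshold=0.5, feature_weights=None):
        # Store each argument unchanged under the same name; no validation here.
        self.threshold = threshold
        # None stands in for "no weights" instead of a mutable default such as [].
        self.feature_weights = feature_weights


est = ThresholdRule(threshold=0.3)
params = est.get_params()       # introspected from the __init__ signature
copy = clone(est)               # rebuilt from params, so it must round-trip
est.set_params(threshold=0.9)   # how grid search updates hyperparameters
```

Any validation of `threshold` or `feature_weights` would belong in `fit`, since `set_params` can change them after construction.
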
- Implement `fit`
  - Signature: `fit(self, X, y=None, **kwargs)`; accept `y=None` even for unsupervised estimators.
  - Validate inputs using `validate_data`/`check_array`; ensure `X.shape[0] == y.shape[0]` when supervised.
  - Set learned attributes with a trailing `_` (e.g., `coef_`, `classes_`).
  - Return `self` and overwrite learned attributes on every call unless `warm_start=True`.
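
The `fit` rules above might look like this in a small regressor. `LeastSquaresSketch` and its `fit_intercept` parameter are hypothetical. The sketch validates with `check_X_y`, a long-standing helper that also checks that `X` and `y` have the same number of samples; on scikit-learn releases that provide `validate_data`, that function could be used instead.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y


class LeastSquaresSketch(RegressorMixin, BaseEstimator):
    """Hypothetical regressor; only fit is shown."""

    def __init__(self, fit_intercept=True):
        self.fit_intercept = fit_intercept

    def fit(self, X, y=None, **kwargs):
        # Validates both arrays and raises ValueError if their lengths differ.
        X, y = check_X_y(X, y, y_numeric=True)
        self.n_features_in_ = X.shape[1]
        if self.fit_intercept:
            A = np.hstack([X, np.ones((X.shape[0], 1))])
        else:
            A = X
        solution, *_ = np.linalg.lstsq(A, y, rcond=None)
        # Learned state gets a trailing underscore and is overwritten on each call.
        if self.fit_intercept:
            self.coef_, self.intercept_ = solution[:-1], float(solution[-1])
        else:
            self.coef_, self.intercept_ = solution, 0.0
        return self


est = LeastSquaresSketch()
fitted = est.fit([[0.0], [1.0], [2.0]], [1.0, 3.0, 5.0])
```

Returning `self` is what allows chained calls such as `LeastSquaresSketch().fit(X, y).coef_`.
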
- Implement prediction/transform methods
  - Call `check_is_fitted` and validate inputs with `validate_data(..., reset=False)`.
  - Classifiers must use `self.classes_` and return labels, not indices.
  - Transformers must preserve sample count and order.
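
As an illustration of the prediction rules, here is a sketch of a hypothetical nearest-centroid classifier, `CentroidSketch`. It uses `check_array` plus a manual feature-count check as a stand-in for `validate_data(..., reset=False)`, which is an assumption made so the example does not depend on a recent release.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class CentroidSketch(ClassifierMixin, BaseEstimator):
    """Hypothetical nearest-centroid classifier."""

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # np.unique returns the sorted labels plus each sample's index into them.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        self.centroids_ = np.array(
            [X[y_idx == k].mean(axis=0) for k in range(len(self.classes_))]
        )
        self.n_features_in_ = X.shape[1]
        return self

    def predict(self, X):
        check_is_fitted(self)  # raises NotFittedError if fit was never called
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("X has a different number of features than seen in fit.")
        diffs = X[:, None, :] - self.centroids_[None, :, :]
        nearest = (diffs ** 2).sum(axis=2).argmin(axis=1)
        # argmin yields indices; map them back to the original labels.
        return self.classes_[nearest]


X_train = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
y_train = ["cat", "cat", "dog", "dog"]
pred = CentroidSketch().fit(X_train, y_train).predict([[0.0, 0.5], [10.0, 10.5]])
```

Indexing `self.classes_` with the winning indices keeps string or non-contiguous integer labels intact in the output.
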
- Handle randomness correctly
  - Accept `random_state=None` in `__init__`, store it unmodified.
  - In `fit`, use `check_random_state` and store RNG in `random_state_` if needed later.
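
The randomness handling above can be sketched with a hypothetical `RandomProjectionSketch` estimator. The class name and `n_components` parameter are illustrative, and input validation is omitted for brevity.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_random_state


class RandomProjectionSketch(BaseEstimator):
    """Hypothetical estimator that draws a random projection matrix in fit."""

    def __init__(self, n_components=2, random_state=None):
        self.n_components = n_components
        self.random_state = random_state  # stored as given: None, int, or RandomState

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        # Build the generator in fit so the constructor argument is never mutated.
        self.random_state_ = check_random_state(self.random_state)
        self.components_ = self.random_state_.normal(
            size=(X.shape[1], self.n_components)
        )
        return self


X = np.zeros((5, 3))
first = RandomProjectionSketch(random_state=0).fit(X)
second = RandomProjectionSketch(random_state=0).fit(X)
```

With an integer seed, two separate fits draw the same matrix, while `random_state` itself still holds the original integer for `get_params` and `clone`.
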
- Optional: tags and `set_output`
  - Implement `__sklearn_tags__` if default tags are not appropriate.
  - For transformers, consider `get_feature_names_out` and `set_output` compatibility.
- Validate with estimator checks
  - Run `check_estimator` or `parametrize_with_checks` when possible.
  - Use the response checklist in `./templates/estimator-checklist.md` to confirm compliance.
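
One way to see what the estimator checks report is to run them against a class that breaks a rule on purpose. In this sketch, `RuleBreaker` and the `first_violation` helper are hypothetical; only `check_estimator` comes from scikit-learn. Which check fails first, and the wording of its message, may vary between releases.

```python
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import check_estimator


class RuleBreaker(BaseEstimator):
    """Hypothetical estimator that breaks the __init__ rule on purpose."""

    def __init__(self, alpha=1.0):
        # Logic in __init__: the stored value no longer matches the argument.
        self.alpha = alpha * 2

    def fit(self, X, y=None):
        self.fitted_ = True  # no input validation either
        return self


def first_violation(estimator):
    """Return None if every check passes, else the first failure as a string."""
    try:
        check_estimator(estimator)
    except Exception as exc:  # checks raise AssertionError, ValueError, and others
        return f"{type(exc).__name__}: {exc}"
    return None


report = first_violation(RuleBreaker())
print(report)
```

In a pytest suite, `parametrize_with_checks` is usually more convenient because each check shows up as its own test case rather than stopping at the first failure.
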