From da02018a0b3739d9903d6f2c996a53affea474bf Mon Sep 17 00:00:00 2001 From: Simon Blanke Date: Sat, 12 Oct 2024 08:59:18 +0200 Subject: [PATCH] adapt 'validate_data' future sklearn versions --- .../integrations/sklearn/hyperactive_search_cv.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py b/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py index 7de054bb..ad4e5d5e 100644 --- a/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py +++ b/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py @@ -86,6 +86,15 @@ def _refit(self, X, y=None, **fit_params): self.best_estimator_.fit(X, y, **fit_params) return self + def _check_data(self, X, y): + X, y = indexable(X, y) + if hasattr(self, "_validate_data"): + validate_data = self._validate_data + else: + from sklearn.utils.validation import validate_data + + return validate_data(X, y) + @Checks.verify_fit def fit(self, X, y, **fit_params): """ @@ -104,8 +113,7 @@ def fit(self, X, y, **fit_params): Returns the instance itself. """ - X, y = indexable(X, y) - X, y = self._validate_data(X, y) + X, y = self._check_data(X, y) fit_params = _check_method_params(X, params=fit_params) self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)