Skip to content

Commit

Permalink
adapt 'validate_data' future sklearn versions
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBlanke committed Oct 12, 2024
1 parent 4ee3c2b commit da02018
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/hyperactive/integrations/sklearn/hyperactive_search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def _refit(self, X, y=None, **fit_params):
self.best_estimator_.fit(X, y, **fit_params)
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)

@Checks.verify_fit
def fit(self, X, y, **fit_params):
"""
Expand All @@ -104,8 +113,7 @@ def fit(self, X, y, **fit_params):
Returns the instance itself.
"""

X, y = indexable(X, y)
X, y = self._validate_data(X, y)
X, y = self._check_data(X, y)

fit_params = _check_method_params(X, params=fit_params)
self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
Expand Down

0 comments on commit da02018

Please sign in to comment.