diff --git a/carte_ai/src/carte_table_to_graph.py b/carte_ai/src/carte_table_to_graph.py index d17e2ff..31fa40a 100644 --- a/carte_ai/src/carte_table_to_graph.py +++ b/carte_ai/src/carte_table_to_graph.py @@ -139,22 +139,6 @@ def fit(self, X, y=None): self.num_transformer_ = PowerTransformer().set_output(transform="pandas") - if self.lm_model == "minhash": - self.name_transformer = make_pipeline( - FeatureHasher(n_features=self.n_components, input_type="string"), - PowerTransformer(), - ) - # Collect all unique names to fit the PowerTransformer - names_total = self.col_names - # Transform names using FeatureHasher - hashed_features = self.name_transformer.named_steps[ - "featurehasher" - ].transform(names_total) - hashed_features_dense = hashed_features.toarray() - # Fit the PowerTransformer - self.name_transformer.named_steps["powertransformer"].fit( - hashed_features_dense - ) # Ensure numerical columns exist before fitting the transformer if self.num_col_names: @@ -238,9 +222,7 @@ def _load_lm_model(self): self.lm_model_ = fasttext.load_model(self.fasttext_model_path) if self.n_components != 300: fasttext.util.reduce_model(self.lm_model_, self.n_components) - elif self.lm_model == "minhash": - # No need to load a model for FeatureHasher - pass + def _transform_numerical(self, X): """ @@ -277,17 +259,6 @@ def _transform_names(self, names_total): [self.lm_model_.get_sentence_vector(name) for name in names_total], dtype=np.float32, ) - elif self.lm_model == "minhash": - # Transform names using FeatureHasher - hashed_features = self.name_transformer.named_steps[ - "featurehasher" - ].transform(names_total) - hashed_features_dense = hashed_features.toarray() - # Apply PowerTransformer - transformed_features = self.name_transformer.named_steps[ - "powertransformer" - ].transform(hashed_features_dense) - return transformed_features.astype(np.float32) def _graph_construct(self, data_cat, data_num, name_attr_total, name_dict, y, idx): """ diff --git a/tests/tests_src/test_carte_table_to_graph.py b/tests/tests_src/test_carte_table_to_graph.py index a00221d..c6ff423 100644 --- a/tests/tests_src/test_carte_table_to_graph.py +++ b/tests/tests_src/test_carte_table_to_graph.py @@ -62,7 +62,7 @@ def test_invalid_fasttext_model_path(dummy_data): transformer = Table2GraphTransformer(lm_model="fasttext") transformer.fit(dummy_data) -@pytest.mark.parametrize("lm_model", ["fasttext", "minhash"]) +@pytest.mark.parametrize("lm_model", ["fasttext"]) def test_table_to_graph_with_different_lm_models(dummy_data, lm_model): """Test Table2GraphTransformer with different language models.""" transformer = Table2GraphTransformer(lm_model=lm_model, n_components=100, fasttext_model_path="path/to/fasttext.bin")