Skip to content

Commit

Permalink
removing minhash🧹
Browse files: browse the repository at this point in the history
  • Loading branch information
gaetanbrison committed Nov 4, 2024
1 parent 93c77d2 commit 86c94b2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 31 deletions.
31 changes: 1 addition & 30 deletions carte_ai/src/carte_table_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,6 @@ def fit(self, X, y=None):

self.num_transformer_ = PowerTransformer().set_output(transform="pandas")

if self.lm_model == "minhash":
self.name_transformer = make_pipeline(
FeatureHasher(n_features=self.n_components, input_type="string"),
PowerTransformer(),
)
# Collect all unique names to fit the PowerTransformer
names_total = self.col_names
# Transform names using FeatureHasher
hashed_features = self.name_transformer.named_steps[
"featurehasher"
].transform(names_total)
hashed_features_dense = hashed_features.toarray()
# Fit the PowerTransformer
self.name_transformer.named_steps["powertransformer"].fit(
hashed_features_dense
)

# Ensure numerical columns exist before fitting the transformer
if self.num_col_names:
Expand Down Expand Up @@ -238,9 +222,7 @@ def _load_lm_model(self):
self.lm_model_ = fasttext.load_model(self.fasttext_model_path)
if self.n_components != 300:
fasttext.util.reduce_model(self.lm_model_, self.n_components)
elif self.lm_model == "minhash":
# No need to load a model for FeatureHasher
pass


def _transform_numerical(self, X):
"""
Expand Down Expand Up @@ -277,17 +259,6 @@ def _transform_names(self, names_total):
[self.lm_model_.get_sentence_vector(name) for name in names_total],
dtype=np.float32,
)
elif self.lm_model == "minhash":
# Transform names using FeatureHasher
hashed_features = self.name_transformer.named_steps[
"featurehasher"
].transform(names_total)
hashed_features_dense = hashed_features.toarray()
# Apply PowerTransformer
transformed_features = self.name_transformer.named_steps[
"powertransformer"
].transform(hashed_features_dense)
return transformed_features.astype(np.float32)

def _graph_construct(self, data_cat, data_num, name_attr_total, name_dict, y, idx):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_src/test_carte_table_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_invalid_fasttext_model_path(dummy_data):
transformer = Table2GraphTransformer(lm_model="fasttext")
transformer.fit(dummy_data)

@pytest.mark.parametrize("lm_model", ["fasttext", "minhash"])
@pytest.mark.parametrize("lm_model", ["fasttext"])
def test_table_to_graph_with_different_lm_models(dummy_data, lm_model):
"""Test Table2GraphTransformer with different language models."""
transformer = Table2GraphTransformer(lm_model=lm_model, n_components=100, fasttext_model_path="path/to/fasttext.bin")
Expand Down

0 comments on commit 86c94b2

Please sign in to comment.