Skip to content

Commit

Permalink
Merge pull request #14 from soda-inria/features
Browse files Browse the repository at this point in the history
Features
  • Loading branch information
gaetanbrison authored Dec 19, 2024
2 parents 53e37ac + d4190d2 commit f7fccf9
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions carte_ai/src/carte_table_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
) # Import FeatureHasher from scikit-learn
from carte_ai.configs.directory import config_directory


def _create_edge_index(
num_nodes: int,
edge_attr: torch.Tensor,
Expand Down Expand Up @@ -66,7 +65,6 @@ def _create_edge_index(

return edge_index, edge_attr_


class Table2GraphTransformer(TransformerMixin, BaseEstimator):
"""
Transformer from tables to a list of graphs.
Expand Down Expand Up @@ -123,28 +121,24 @@ def fit(self, X, y=None):
if not hasattr(self, "lm_model_"):
self._load_lm_model()

cat_col_names = (
# Use original column names without lowercasing to avoid mismatches
self.cat_col_names = list(
X.select_dtypes(include="object")
.columns.str.replace("\n", " ", regex=True)
.str.lower()
)
self.cat_col_names = list(cat_col_names)
num_col_names = (
self.num_col_names = list(
X.select_dtypes(exclude="object")
.columns.str.replace("\n", " ", regex=True)
.str.lower()
)
self.num_col_names = list(num_col_names)
self.col_names = self.cat_col_names + self.num_col_names

self.num_transformer_ = PowerTransformer().set_output(transform="pandas")


# Ensure numerical columns exist before fitting the transformer
if self.num_col_names:
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
self.num_transformer_.fit(X[num_cols_exist])
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
self.num_transformer_.fit(X[num_cols_exist])
#print(f"Numerical columns fitted for normalization: {num_cols_exist}")

self.is_fitted_ = True
return self
Expand Down Expand Up @@ -186,10 +180,17 @@ def transform(self, X, y=None):
name_dict = {name: idx for idx, name in enumerate(names_total)}

name_attr_total = self._transform_names(names_total)
if self.num_col_names:
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
X_numerical = self._transform_numerical(X_numerical[num_cols_exist])

# Use the original numerical column names for transformation
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
X_numerical = self._transform_numerical(X[num_cols_exist])
#print(f"Transformed numerical columns: {X_numerical.head()}")
# Check mean and variance for each column
#for col in num_cols_exist:
# mean = X_numerical[col].mean()
# variance = X_numerical[col].var()
#print(f"Column: {col}, Mean: {mean:.6f}, Variance: {variance:.6f}")

data_graph = [
self._graph_construct(
Expand Down Expand Up @@ -223,7 +224,6 @@ def _load_lm_model(self):
if self.n_components != 300:
fasttext.util.reduce_model(self.lm_model_, self.n_components)


def _transform_numerical(self, X):
"""
Transform numerical columns using power transformer.
Expand Down

0 comments on commit f7fccf9

Please sign in to comment.