diff --git a/carte_ai/src/carte_table_to_graph.py b/carte_ai/src/carte_table_to_graph.py index 31fa40a..604bf5a 100644 --- a/carte_ai/src/carte_table_to_graph.py +++ b/carte_ai/src/carte_table_to_graph.py @@ -14,7 +14,6 @@ ) # Import FeatureHasher from scikit-learn from carte_ai.configs.directory import config_directory - def _create_edge_index( num_nodes: int, edge_attr: torch.Tensor, @@ -66,7 +65,6 @@ def _create_edge_index( return edge_index, edge_attr_ - class Table2GraphTransformer(TransformerMixin, BaseEstimator): """ Transformer from tables to a list of graphs. @@ -123,28 +121,24 @@ def fit(self, X, y=None): if not hasattr(self, "lm_model_"): self._load_lm_model() - cat_col_names = ( + # Use original column names without lowercasing to avoid mismatches + self.cat_col_names = list( X.select_dtypes(include="object") .columns.str.replace("\n", " ", regex=True) - .str.lower() ) - self.cat_col_names = list(cat_col_names) - num_col_names = ( + self.num_col_names = list( X.select_dtypes(exclude="object") .columns.str.replace("\n", " ", regex=True) - .str.lower() ) - self.num_col_names = list(num_col_names) self.col_names = self.cat_col_names + self.num_col_names self.num_transformer_ = PowerTransformer().set_output(transform="pandas") - # Ensure numerical columns exist before fitting the transformer - if self.num_col_names: - num_cols_exist = [col for col in self.num_col_names if col in X.columns] - if num_cols_exist: - self.num_transformer_.fit(X[num_cols_exist]) + num_cols_exist = [col for col in self.num_col_names if col in X.columns] + if num_cols_exist: + self.num_transformer_.fit(X[num_cols_exist]) + #print(f"Numerical columns fitted for normalization: {num_cols_exist}") self.is_fitted_ = True return self @@ -186,10 +180,17 @@ def transform(self, X, y=None): name_dict = {name: idx for idx, name in enumerate(names_total)} name_attr_total = self._transform_names(names_total) - if self.num_col_names: - num_cols_exist = [col for col in self.num_col_names if col in X.columns] - if num_cols_exist: - X_numerical = self._transform_numerical(X_numerical[num_cols_exist]) + + # Use the original numerical column names for transformation + num_cols_exist = [col for col in self.num_col_names if col in X.columns] + if num_cols_exist: + X_numerical = self._transform_numerical(X[num_cols_exist]) + #print(f"Transformed numerical columns: {X_numerical.head()}") + # Check mean and variance for each column + #for col in num_cols_exist: + # mean = X_numerical[col].mean() + # variance = X_numerical[col].var() + #print(f"Column: {col}, Mean: {mean:.6f}, Variance: {variance:.6f}") data_graph = [ self._graph_construct( @@ -223,7 +224,6 @@ def _load_lm_model(self): if self.n_components != 300: fasttext.util.reduce_model(self.lm_model_, self.n_components) - def _transform_numerical(self, X): """ Transform numerical columns using power transformer.