Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features #14

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions carte_ai/src/carte_table_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
) # Import FeatureHasher from scikit-learn
from carte_ai.configs.directory import config_directory


def _create_edge_index(
num_nodes: int,
edge_attr: torch.Tensor,
Expand Down Expand Up @@ -66,7 +65,6 @@ def _create_edge_index(

return edge_index, edge_attr_


class Table2GraphTransformer(TransformerMixin, BaseEstimator):
"""
Transformer from tables to a list of graphs.
Expand Down Expand Up @@ -123,28 +121,24 @@ def fit(self, X, y=None):
if not hasattr(self, "lm_model_"):
self._load_lm_model()

cat_col_names = (
# Use original column names without lowercasing to avoid mismatches
self.cat_col_names = list(
X.select_dtypes(include="object")
.columns.str.replace("\n", " ", regex=True)
.str.lower()
)
self.cat_col_names = list(cat_col_names)
num_col_names = (
self.num_col_names = list(
X.select_dtypes(exclude="object")
.columns.str.replace("\n", " ", regex=True)
.str.lower()
)
self.num_col_names = list(num_col_names)
self.col_names = self.cat_col_names + self.num_col_names

self.num_transformer_ = PowerTransformer().set_output(transform="pandas")


# Ensure numerical columns exist before fitting the transformer
if self.num_col_names:
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
self.num_transformer_.fit(X[num_cols_exist])
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
self.num_transformer_.fit(X[num_cols_exist])
#print(f"Numerical columns fitted for normalization: {num_cols_exist}")

self.is_fitted_ = True
return self
Expand Down Expand Up @@ -186,10 +180,17 @@ def transform(self, X, y=None):
name_dict = {name: idx for idx, name in enumerate(names_total)}

name_attr_total = self._transform_names(names_total)
if self.num_col_names:
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
X_numerical = self._transform_numerical(X_numerical[num_cols_exist])

# Use the original numerical column names for transformation
num_cols_exist = [col for col in self.num_col_names if col in X.columns]
if num_cols_exist:
X_numerical = self._transform_numerical(X[num_cols_exist])
#print(f"Transformed numerical columns: {X_numerical.head()}")
# Check mean and variance for each column
#for col in num_cols_exist:
# mean = X_numerical[col].mean()
# variance = X_numerical[col].var()
#print(f"Column: {col}, Mean: {mean:.6f}, Variance: {variance:.6f}")

data_graph = [
self._graph_construct(
Expand Down Expand Up @@ -223,7 +224,6 @@ def _load_lm_model(self):
if self.n_components != 300:
fasttext.util.reduce_model(self.lm_model_, self.n_components)


def _transform_numerical(self, X):
"""
Transform numerical columns using power transformer.
Expand Down
Loading