-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from soda-inria/data
New method to load data and new release 0.0.23
- Loading branch information
Showing
138 changed files
with
18,567 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from carte_ai.src import * | ||
from carte_ai.configs import * | ||
from carte_ai.data import * | ||
from carte_ai.scripts import * | ||
from .src import CARTERegressor, CARTEClassifier, Table2GraphTransformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from carte_ai.configs.carte_configs import * | ||
from carte_ai.configs.directory import * | ||
from carte_ai.configs.model_parameters import * | ||
from carte_ai.configs.visuailization import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
"""Specific configurations for the CARTE paper.""" | ||
|
||
## Dataset names | ||
carte_datalist = [ | ||
"anime_planet", | ||
"babies_r_us", | ||
"beer_ratings", | ||
"bikedekho", | ||
"bikewale", | ||
"buy_buy_baby", | ||
"cardekho", | ||
"chocolate_bar_ratings", | ||
"clear_corpus", | ||
"coffee_ratings", | ||
"company_employees", | ||
"employee_remuneration", | ||
"employee_salaries", | ||
"fifa22_players", | ||
"filmtv_movies", | ||
"journal_jcr", | ||
"journal_sjr", | ||
"jp_anime", | ||
"k_drama", | ||
"michelin", | ||
"mlds_salaries", | ||
"movies", | ||
"museums", | ||
"mydramalist", | ||
"nba_draft", | ||
"prescription_drugs", | ||
"ramen_ratings", | ||
"roger_ebert", | ||
"rotten_tomatoes", | ||
"spotify", | ||
"us_accidents_counts", | ||
"us_accidents_severity", | ||
"us_presidential", | ||
"used_cars_24", | ||
"used_cars_benz_italy", | ||
"used_cars_dot_com", | ||
"used_cars_pakistan", | ||
"used_cars_saudi_arabia", | ||
"videogame_sales", | ||
"whisky", | ||
"wikiliq_beer", | ||
"wikiliq_spirit", | ||
"wina_pl", | ||
"wine_dot_com_prices", | ||
"wine_dot_com_ratings", | ||
"wine_enthusiasts_prices", | ||
"wine_enthusiasts_ratings", | ||
"wine_vivino_price", | ||
"wine_vivino_rating", | ||
"yelp", | ||
"zomato", | ||
] | ||
|
||
## Dictionary of baseline methods | ||
carte_singletable_baselines = dict() | ||
carte_singletable_baselines["full"] = [ | ||
"carte-gnn", | ||
"catboost", | ||
"sentence-llm-concat-num_histgb", | ||
"sentence-llm-concat-num_xgb", | ||
"sentence-llm-embed-num_histgb", | ||
"sentence-llm-embed-num_xgb", | ||
"tablevectorizer-fasttext_histgb", | ||
"tablevectorizer-fasttext_xgb", | ||
"tablevectorizer-llm_histgb", | ||
"tablevectorizer-llm_xgb", | ||
"tablevectorizer_histgb", | ||
"tablevectorizer_logistic", | ||
"tablevectorizer_mlp", | ||
"tablevectorizer_randomforest", | ||
"tablevectorizer_resnet", | ||
"tablevectorizer_ridge", | ||
"tablevectorizer_xgb", | ||
"tablevectorizer_tabpfn", | ||
"target-encoder_histgb", | ||
"target-encoder_logistic", | ||
"target-encoder_mlp", | ||
"target-encoder_randomforest", | ||
"target-encoder_resnet", | ||
"target-encoder_ridge", | ||
"target-encoder_xgb", | ||
"target-encoder_tabpfn", | ||
] | ||
|
||
carte_singletable_baselines["reduced"] = [ | ||
"carte-gnn", | ||
"catboost", | ||
"sentence-llm-concat-num_xgb", | ||
"sentence-llm-embed-num_xgb", | ||
"tablevectorizer_logistic", | ||
"tablevectorizer_mlp", | ||
"tablevectorizer_randomforest", | ||
"tablevectorizer_resnet", | ||
"tablevectorizer_ridge", | ||
"tablevectorizer_xgb", | ||
"target-encoder_tabpfn", | ||
] | ||
|
||
carte_multitable_baselines = [ | ||
"original_carte-multitable", | ||
"matched_carte-multitable", | ||
"original_catboost-multitable", | ||
"matched_catboost-multitable", | ||
"original-sentence-llm_histgb-multitable", | ||
"matched-sentence-llm_histgb-multitable", | ||
] | ||
|
||
|
||
## Dictionary of method mapping | ||
carte_singletable_baseline_mapping = dict() | ||
carte_singletable_baseline_mapping["carte-gnn"] = "CARTE" | ||
|
||
# Preprocessings | ||
carte_singletable_baseline_mapping["tablevectorizer_"] = "TabVec-" | ||
carte_singletable_baseline_mapping["tablevectorizer-"] = "TabVec-" | ||
carte_singletable_baseline_mapping["target-encoder_"] = "TarEnc-" | ||
carte_singletable_baseline_mapping["fasttext_"] = "FT-" | ||
carte_singletable_baseline_mapping["llm_"] = "LLM-" | ||
carte_singletable_baseline_mapping["sentence-llm-concat-num_"] = "S-LLM-CN-" | ||
carte_singletable_baseline_mapping["sentence-llm-embed-num_"] = "S-LLM-EN-" | ||
|
||
# Estimators | ||
carte_singletable_baseline_mapping["catboost"] = "CatBoost" | ||
carte_singletable_baseline_mapping["xgb"] = "XGB" | ||
carte_singletable_baseline_mapping["histgb"] = "HGB" | ||
carte_singletable_baseline_mapping["randomforest"] = "RF" | ||
carte_singletable_baseline_mapping["ridge"] = "Ridge" | ||
carte_singletable_baseline_mapping["logistic"] = "Logistic" | ||
carte_singletable_baseline_mapping["mlp"] = "MLP" | ||
carte_singletable_baseline_mapping["resnet"] = "ResNet" | ||
carte_singletable_baseline_mapping["tabpfn"] = "TabPFN" | ||
|
||
# Bagging | ||
carte_singletable_baseline_mapping["bagging"] = "Bagging" | ||
|
||
## Colors for visualization | ||
carte_singletable_color_palette = dict() | ||
carte_singletable_color_palette["CARTE"] = "C3" | ||
carte_singletable_color_palette["CatBoost"] = "C0" | ||
carte_singletable_color_palette["TabVec-XGB"] = "C1" | ||
carte_singletable_color_palette["TabVec-RF"] = "C2" | ||
carte_singletable_color_palette["TabVec-Ridge"] = "C4" | ||
carte_singletable_color_palette["TabVec-Logistic"] = "C5" | ||
carte_singletable_color_palette["S-LLM-CN-XGB"] = "C6" | ||
carte_singletable_color_palette["S-LLM-EN-XGB"] = "C7" | ||
carte_singletable_color_palette["TabVec-ResNet"] = "C8" | ||
carte_singletable_color_palette["TabVec-MLP"] = "C9" | ||
carte_singletable_color_palette["TarEnc-TabPFN"] = "#A9561E" | ||
|
||
## Markers for visualization | ||
carte_singletable_markers = dict() | ||
carte_singletable_markers["CARTE"] = "o" | ||
carte_singletable_markers["TabVec-XGB"] = (4, 0, 45) | ||
carte_singletable_markers["TabVec-RF"] = "P" | ||
carte_singletable_markers["CatBoost"] = "X" | ||
carte_singletable_markers["S-LLM-CN-XGB"] = (4, 0, 0) | ||
carte_singletable_markers["S-LLM-EN-XGB"] = "d" | ||
carte_singletable_markers["TabVec-Ridge"] = "v" | ||
carte_singletable_markers["TabVec-Logistic"] = "v" | ||
carte_singletable_markers["TabVec-ResNet"] = "^" | ||
carte_singletable_markers["TabVec-MLP"] = "p" | ||
carte_singletable_markers["TarEnc-TabPFN"] = (5, 1, 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from pathlib import Path | ||
|
||
# Get the base path relative to this file's location | ||
base_path = Path(__file__).resolve().parent.parent # This gives '/home/infres/gbrison/carte/carte_ai' | ||
|
||
config_directory = dict() | ||
config_directory["base_path"] = base_path | ||
|
||
config_directory["data"] = str(base_path / "data/") | ||
config_directory["pretrained_model"] = str(base_path / "data/etc/kg_pretrained.pt") # Correct path | ||
config_directory["data_raw"] = str(base_path / "data/data_raw/") | ||
config_directory["data_singletable"] = str(base_path / "data/data_singletable/") | ||
config_directory["data_yago"] = str(base_path / "data/data_yago/") | ||
config_directory["etc"] = str(base_path / "data/etc/") | ||
|
||
config_directory["results"] = str(base_path / "results/") | ||
config_directory["compiled_results"] = str(base_path / "results/compiled_results/") | ||
config_directory["visualization"] = str(base_path / "visualization/") | ||
|
||
# Specify the directory in which you have downloaded each | ||
config_directory["fasttext"] = str(base_path / "data/etc/cc.en.300.bin") | ||
config_directory["ken_embedding"] = str(base_path / "data/etc/ken_embedding.parquet") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
""" | ||
Parameter distributions for hyperparameter optimization | ||
""" | ||
|
||
import numpy as np | ||
from scipy.stats import loguniform, randint, uniform, norm | ||
import copy | ||
|
||
|
||
class loguniform_int: | ||
"""Integer valued version of the log-uniform distribution""" | ||
|
||
def __init__(self, a, b): | ||
self._distribution = loguniform(a, b) | ||
|
||
def rvs(self, *args, **kwargs): | ||
"""Random variable sample""" | ||
return self._distribution.rvs(*args, **kwargs).astype(int) | ||
|
||
|
||
class norm_int: | ||
"""Integer valued version of the normal distribution""" | ||
|
||
def __init__(self, a, b): | ||
self._distribution = norm(a, b) | ||
|
||
def rvs(self, *args, **kwargs): | ||
"""Random variable sample""" | ||
if self._distribution.rvs(*args, **kwargs).astype(int) < 1: | ||
return 1 | ||
else: | ||
return self._distribution.rvs(*args, **kwargs).astype(int) | ||
|
||
|
||
param_distributions_total = dict() | ||
|
||
# carte-gnn | ||
param_distributions = dict() | ||
lr_grid = [1e-4, 2.5e-4, 5e-4, 7.5e-4, 1e-3] | ||
param_distributions["learning_rate"] = lr_grid | ||
param_distributions_total["carte-gnn"] = param_distributions | ||
|
||
# histgb | ||
param_distributions = dict() | ||
param_distributions["learning_rate"] = loguniform(1e-2, 10) | ||
param_distributions["max_depth"] = [None, 2, 3, 4] | ||
param_distributions["max_leaf_nodes"] = norm_int(31, 5) | ||
param_distributions["min_samples_leaf"] = norm_int(20, 2) | ||
param_distributions["l2_regularization"] = loguniform(1e-6, 1e3) | ||
param_distributions_total["histgb"] = param_distributions | ||
|
||
# catboost | ||
param_distributions = dict() | ||
param_distributions["max_depth"] = randint(2, 11) | ||
param_distributions["learning_rate"] = loguniform(1e-5, 1) | ||
param_distributions["bagging_temperature"] = uniform(0, 1) | ||
param_distributions["l2_leaf_reg"] = loguniform(1, 10) | ||
param_distributions["iterations"] = randint(400, 1001) | ||
param_distributions["one_hot_max_size"] = randint(2, 26) | ||
param_distributions_total["catboost"] = param_distributions | ||
|
||
# xgb | ||
param_distributions = dict() | ||
param_distributions["n_estimators"] = randint(50, 1001) | ||
param_distributions["max_depth"] = randint(2, 11) | ||
param_distributions["min_child_weight"] = loguniform(1, 100) | ||
param_distributions["subsample"] = uniform(0.5, 1 - 0.5) | ||
param_distributions["learning_rate"] = loguniform(1e-5, 1) | ||
param_distributions["colsample_bylevel"] = uniform(0.5, 1 - 0.5) | ||
param_distributions["colsample_bytree"] = uniform(0.5, 1 - 0.5) | ||
param_distributions["gamma"] = loguniform(1e-8, 7) | ||
param_distributions["lambda"] = loguniform(1, 4) | ||
param_distributions["alpha"] = loguniform(1e-8, 100) | ||
param_distributions_total["xgb"] = param_distributions | ||
|
||
# RandomForest | ||
param_distributions = dict() | ||
param_distributions["n_estimators"] = randint(50, 250) | ||
param_distributions["max_depth"] = [None, 2, 3, 4] | ||
param_distributions["max_features"] = [ | ||
"sqrt", | ||
"log2", | ||
None, | ||
0.1, | ||
0.2, | ||
0.3, | ||
0.4, | ||
0.5, | ||
0.6, | ||
0.7, | ||
0.8, | ||
0.9, | ||
] | ||
param_distributions["min_samples_leaf"] = loguniform_int(0.5, 50.5) | ||
param_distributions["bootstrap"] = [True, False] | ||
param_distributions["min_impurity_decrease"] = [0.0, 0.01, 0.02, 0.05] | ||
param_distributions_total["randomforest"] = param_distributions | ||
|
||
|
||
# resnet | ||
param_distributions = dict() | ||
param_distributions["normalization"] = ["batchnorm", "layernorm"] | ||
param_distributions["num_layers"] = randint(1, 9) | ||
param_distributions["hidden_dim"] = randint(32, 513) | ||
param_distributions["hidden_factor"] = randint(1, 3) | ||
param_distributions["hidden_dropout_prob"] = uniform(0.0, 0.5) | ||
param_distributions["residual_dropout_prob"] = uniform(0.0, 0.5) | ||
param_distributions["learning_rate"] = loguniform(1e-5, 1e-2) | ||
param_distributions["weight_decay"] = loguniform(1e-8, 1e-2) | ||
param_distributions["batch_size"] = [16, 32] | ||
param_distributions_total["resnet"] = param_distributions | ||
|
||
# mlp | ||
param_distributions = dict() | ||
param_distributions["hidden_dim"] = [2**x for x in range(4, 11)] | ||
param_distributions["num_layers"] = randint(1, 5) | ||
param_distributions["dropout_prob"] = uniform(0.0, 0.5) | ||
param_distributions["learning_rate"] = loguniform(1e-5, 1e-2) | ||
param_distributions["weight_decay"] = loguniform(1e-8, 1e-2) | ||
param_distributions["batch_size"] = [16, 32] | ||
param_distributions_total["mlp"] = param_distributions | ||
|
||
# ridge regression | ||
param_distributions = dict() | ||
param_distributions["solver"] = ["svd", "cholesky", "lsqr", "sag"] | ||
param_distributions["alpha"] = loguniform(1e-5, 100) | ||
param_distributions_total["ridge"] = param_distributions | ||
|
||
# logistic regression | ||
param_distributions = dict() | ||
param_distributions["solver"] = ["newton-cg", "lbfgs", "liblinear"] | ||
param_distributions["penalty"] = ["none", "l1", "l2", "elasticnet"] | ||
param_distributions["C"] = loguniform(1e-5, 100) | ||
param_distributions_total["logistic"] = param_distributions | ||
|
||
# tabpfn | ||
param_distributions = dict() | ||
param_distributions_total["tabpfn"] = param_distributions | ||
|
||
# catboost-multitable | ||
param_distributions = copy.deepcopy(param_distributions_total["catboost"]) | ||
param_distributions["source_fraction"] = uniform(0, 1) | ||
param_distributions_total["catboost-multitable"] = param_distributions | ||
|
||
# histgb-multitable | ||
param_distributions = copy.deepcopy(param_distributions_total["histgb"]) | ||
param_distributions["source_fraction"] = uniform(0, 1) | ||
param_distributions_total["histgb-multitable"] = param_distributions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
""" | ||
Visualization configurations | ||
""" | ||
|
||
# Main models | ||
model_color_palette = dict() | ||
model_color_palette["CARTE"] = "C3" | ||
model_color_palette["CatBoost"] = "C0" | ||
model_color_palette["TabVec-XGB"] = "C1" | ||
model_color_palette["TabVec-RF"] = "C2" | ||
model_color_palette["TabVec-Ridge"] = "C4" | ||
model_color_palette["TabVec-Logistic"] = "C5" | ||
model_color_palette["S-LLM-CN-XGB"] = "C6" # "" | ||
model_color_palette["S-LLM-EN-XGB"] = "C7" # "C7" "#C875C4" mediumorchid | ||
model_color_palette["ResNet"] = "C8" | ||
model_color_palette["MLP"] = "C9" | ||
model_color_palette["TabPFN"] = "#A9561E" | ||
|
||
model_color_palette["TabVec-RandomForest"] = "C2" | ||
model_color_palette["TabVec-ResNet"] = "C8" | ||
model_color_palette["TabVec-MLP"] = "C9" | ||
model_color_palette["TarEnc-TabPFN"] = "#A9561E" | ||
|
||
|
||
# model_color_palette["CARTE-B"] = "C3" | ||
# model_color_palette["CatBoost-B"] = "C0" | ||
# model_color_palette["TabVec-XGB-B"] = "C1" | ||
# model_color_palette["TabVec-RF-B"] = "C2" | ||
# model_color_palette["TabVec-Ridge-B"] = "C4" | ||
# model_color_palette["TabVec-Logistic-B"] = "C5" | ||
# model_color_palette["S-LLM-CN-XGB-B"] = "C6" | ||
# model_color_palette["S-LLM-EN-XGB-B"] = "C7" | ||
# model_color_palette["ResNet-B"] = "C8" | ||
# model_color_palette["MLP-B"] = "C9" | ||
# model_color_palette["TabPFN-B"] = "#A9561E" | ||
|
||
|
||
# model_color_palette["TabVec-HGB"] = "#650021" | ||
# model_color_palette["TabVec-TabPFN"] = "#650021" | ||
# model_color_palette["TabVec-FT-XGB"] = "#650021" | ||
# model_color_palette["TabVec-FT-HGB"] = "#650021" | ||
|
||
# model_color_palette["TabLLM"] = "#653700" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from carte_ai.data.load_data import * |
Oops, something went wrong.