-
Notifications
You must be signed in to change notification settings - Fork 55
/
run_kubeflow_pipeline.py
47 lines (40 loc) · 1.45 KB
/
run_kubeflow_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
from pipelines.data_analysis_pipeline import data_analysis_pipeline
from pipelines.training_pipelines import training_pipeline
from steps.data_process import drop_cols, encode_cat_cols
from steps.data_splitter import data_splitter
from steps.evaluation import evaluation
from steps.ingest_data import ingest_data
from steps.trainer import model_trainer
from steps.visualizer import (
visualize_statistics,
visualize_train_test_statistics,
)
def analyze_pipeline():
    """Run the data analysis pipeline and then the statistics visualizers.

    Builds ``data_analysis_pipeline`` from an ingestion step and a data
    splitting step, runs it with the settings in
    ``analysis_pipeline_config.yaml``, and finally calls the two
    visualization helpers.
    """
    # Instantiate the steps first, in the same order they are wired in.
    ingest_step = ingest_data()
    splitter_step = data_splitter()

    analysis = data_analysis_pipeline(
        ingest_data=ingest_step,
        data_splitter=splitter_step,
    )
    analysis.run(config_path="analysis_pipeline_config.yaml")

    # Visualizers are invoked only after the pipeline run has returned.
    visualize_statistics()
    visualize_train_test_statistics()
def training_pipeline_run():
    """Assemble the training pipeline and run it.

    The pipeline is wired with steps for ingestion, categorical column
    encoding, column dropping, data splitting, model training and
    evaluation, and is run with the settings in
    ``training_pipeline_config.yaml``.
    """
    # Keyword name -> step instance; the keys are the parameter names that
    # ``training_pipeline`` is called with.
    step_instances = {
        "ingest_data": ingest_data(),
        "encode_cat_cols": encode_cat_cols(),
        "drop_cols": drop_cols(),
        "data_splitter": data_splitter(),
        "model_trainer": model_trainer(),
        "evaluator": evaluation(),
    }
    training_pipeline(**step_instances).run(
        config_path="training_pipeline_config.yaml"
    )
if __name__ == "__main__":
    # Map each accepted CLI value to the function that launches that pipeline.
    launchers = {
        "analyze": analyze_pipeline,
        "train": training_pipeline_run,
    }

    cli = argparse.ArgumentParser()
    cli.add_argument("pipeline", type=str, choices=["analyze", "train"])
    selected = cli.parse_args().pipeline

    # argparse has already rejected anything outside ``choices``, so the
    # lookup cannot miss.
    launchers[selected]()