-
Notifications
You must be signed in to change notification settings - Fork 55
/
run_pipeline.py
166 lines (147 loc) · 4.81 KB
/
run_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import argparse
from datetime import date, datetime, timedelta
from pipelines.data_analysis_pipeline import data_analysis_pipeline
from pipelines.prediction_pipeline import inference_pipeline
from pipelines.training_pipeline import training_pipeline
from steps.analyzer import analyze_drift
from steps.discord_bot import discord_alert, discord_post_prediction
from steps.drift_reporter import evidently_drift_detector
from steps.encoder import data_encoder, encode_columns_and_clean
from steps.evaluator import tester
from steps.feature_engineer import feature_engineer
from steps.importer import (
game_data_importer_offline,
import_season_schedule_offline,
)
from steps.model_picker import model_picker
from steps.post_processor import data_post_processor
from steps.predictor import predictor
from steps.splitter import (
SklearnSplitterConfig,
SplitConfig,
TimeWindowConfig,
TrainingSplitConfig,
date_based_splitter,
get_coming_week_data,
reference_data_splitter,
sklearn_splitter,
)
from steps.trainer import random_forest_trainer
from utils.kubeflow_helper import get_kubeflow_settings
from zenml.config import DockerSettings
from zenml.integrations.constants import (
AWS,
EVIDENTLY,
KUBEFLOW,
MLFLOW,
SKLEARN,
)
from zenml.pipelines import Schedule
# One week before today. NOTE(review): appears unused anywhere in this
# file — candidate for removal; confirm no external importer relies on it.
last_week = date.today() - timedelta(days=7)
# Cut-off date for the offline game data set — presumably the last date
# covered by the bundled data files; TODO confirm against the importer step.
LAST_DATA_DATE = "2022-04-10"
# Reference split date used for FG3M (3-pointers made) drift analysis —
# presumably the famous 2016-02-27 Curry deep game-winner; verify intent.
CURRY_FROM_DOWNTOWN = "2016-02-27"
# Docker build settings shared by all pipeline runs: the ZenML integrations
# the steps need, plus the nba-api package for data access.
docker_settings = DockerSettings(
    required_integrations=[EVIDENTLY, SKLEARN, AWS, KUBEFLOW, MLFLOW],
    requirements=["nba-api"],
)
def run_analysis():
    """Run the exploratory drift-analysis pipeline once on Kubeflow.

    Wires the offline importer into a date-based split at
    CURRY_FROM_DOWNTOWN on the FG3M column, runs Evidently drift
    detection, and analyzes the resulting report.
    """
    # Assemble the pipeline from its step instances.
    pipeline_instance = data_analysis_pipeline(
        importer=game_data_importer_offline(),
        drift_splitter=date_based_splitter(
            SplitConfig(date_split=CURRY_FROM_DOWNTOWN, columns=["FG3M"])
        ),
        drift_detector=evidently_drift_detector,
        drift_analyzer=analyze_drift(),
    )
    # Orchestrator + image settings for this run.
    run_settings = {
        "orchestrator.kubeflow": get_kubeflow_settings(),
        "docker": docker_settings,
    }
    pipeline_instance.run(
        config_path="data_analysis_config.yaml",
        settings=run_settings,
    )
def run_training(schedule: bool):
    """Create a training pipeline run.

    Trains and evaluates a random-forest model on an offline train/test/
    validation split, then runs Evidently drift detection against a fixed
    reference window and posts a Discord alert on drift.

    Args:
        schedule: If true, run the pipeline on a recurring schedule
            (every 600 seconds) instead of a single immediate run.
    """
    # Initialize the pipeline
    p = training_pipeline(
        importer=game_data_importer_offline(),
        # Train Model
        feature_engineer=feature_engineer(),
        encoder=data_encoder(),
        ml_splitter=sklearn_splitter(
            SklearnSplitterConfig(
                ratios={"train": 0.6, "test": 0.2, "validation": 0.2}
            )
        ),
        trainer=random_forest_trainer(),
        tester=tester(),
        # Drift detection: compare data newer than LAST_DATA_DATE against
        # the fixed reference window [CURRY_FROM_DOWNTOWN, 2019-02-27].
        drift_splitter=reference_data_splitter(
            TrainingSplitConfig(
                new_data_split_date=LAST_DATA_DATE,
                start_reference_time_frame=CURRY_FROM_DOWNTOWN,
                end_reference_time_frame="2019-02-27",
                columns=["FG3M"],
            )
        ),
        drift_detector=evidently_drift_detector,
        drift_alert=discord_alert(),
    )
    # FIX: the scheduled branch previously omitted the kubeflow/docker
    # settings that the ad-hoc branch passed, so scheduled runs lost the
    # orchestrator configuration. Build one kwargs dict used by both paths.
    run_kwargs = {
        "settings": {
            "orchestrator.kubeflow": get_kubeflow_settings(),
            "docker": docker_settings,
        },
    }
    if schedule:
        # Re-run every 10 minutes starting now.
        run_kwargs["schedule"] = Schedule(
            start_time=datetime.now(), interval_second=600
        )
    p.run(**run_kwargs)
def run_inference():
    """Run the prediction pipeline once on Kubeflow.

    Imports the offline season schedule, cleans and encodes it, extracts
    the coming 7-day window of games, picks a model, predicts outcomes,
    post-processes the results, and posts them to Discord.
    """
    # Orchestrator + image settings for this run.
    run_settings = {
        "orchestrator.kubeflow": get_kubeflow_settings(),
        "docker": docker_settings,
    }
    # Assemble the pipeline from its step instances.
    pipeline_instance = inference_pipeline(
        importer=import_season_schedule_offline(),
        preprocessor=encode_columns_and_clean(),
        extract_next_week=get_coming_week_data(TimeWindowConfig(time_window=7)),
        model_picker=model_picker(),
        predictor=predictor(),
        post_processor=data_post_processor(),
        prediction_poster=discord_post_prediction(),
    )
    pipeline_instance.run(settings=run_settings)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"pipeline", type=str, choices=["drift", "train", "infer"]
)
parser.add_argument("-s", "--schedule", type=bool)
args = parser.parse_args()
if args.pipeline == "drift":
run_analysis()
elif args.pipeline == "train":
run_training(args.schedule)
elif args.pipeline == "infer":
run_inference()