From 2c1d5bf920038edd987f8d01d2717af8d7083139 Mon Sep 17 00:00:00 2001 From: jerryzhuang Date: Fri, 20 Dec 2024 18:05:52 +1100 Subject: [PATCH] fix: don't switch current working git branch when determining model changes (#789) **Reason for Change**: don't switch current working git branch when determining model changes. this behavior is unacceptable in e2e action steps. Signed-off-by: jerryzhuang --- .../kind-cluster/determine_models.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/.github/workflows/kind-cluster/determine_models.py b/.github/workflows/kind-cluster/determine_models.py index 92b3d7a6c..b92e10e5f 100644 --- a/.github/workflows/kind-cluster/determine_models.py +++ b/.github/workflows/kind-cluster/determine_models.py @@ -22,9 +22,14 @@ def read_yaml(file_path): # Format: {falcon-7b : {model_name:falcon-7b, type:text-generation, version: #, tag: #}} MODELS = {model['name']: model for model in YAML_PR['models']} KAITO_REPO_URL = "https://github.com/kaito-project/kaito.git" +GITREMOTE_TARGET = "_ciupstream" def set_multiline_output(name, value): - with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: + if not os.getenv('GITHUB_OUTPUT'): + print(f"Not in github env, skip writing to $GITHUB_OUTPUT .") + return + + with open(os.getenv('GITHUB_OUTPUT'), 'a') as fh: delimiter = uuid.uuid1() print(f'{name}<<{delimiter}', file=fh) print(value, file=fh) @@ -51,9 +56,11 @@ def run_command(command): def get_yaml_from_branch(branch, file_path): """Read YAML from a branch""" - subprocess.run(['git', 'fetch', 'origin', branch], check=True) - subprocess.run(['git', 'checkout', 'origin/' + branch], check=True) - return read_yaml(file_path) + subprocess.run(['git', 'fetch', GITREMOTE_TARGET, branch], check=True) + subprocess.run(['git', 'checkout', f"{GITREMOTE_TARGET}/" + branch], check=True) + content = read_yaml(file_path) + subprocess.run(['git', 'checkout', '-'], check=True) + return content def detect_changes_in_yaml(yaml_main, yaml_pr): """Detecting relevant changes in support_models.yaml""" @@ -90,33 +97,27 @@ def models_to_build(files_changed): seen_model_types.add(model_info["type"]) return list(models) -def check_modified_models(pr_branch): +def check_modified_models(): """Check for modified models in the repository.""" repo_dir = Path.cwd() / "repo" if repo_dir.exists(): shutil.rmtree(repo_dir) - run_command(f"git clone {KAITO_REPO_URL} {repo_dir}") - os.chdir(repo_dir) - - run_command("git checkout --detach") - run_command("git fetch origin main:main") - run_command(f"git fetch origin {pr_branch}:{pr_branch}") - run_command(f"git checkout {pr_branch}") + run_command(f"git remote add {GITREMOTE_TARGET} {KAITO_REPO_URL}") + run_command(f"git fetch {GITREMOTE_TARGET}") - files = run_command("git diff --name-only origin/main") # Returns each file on newline + files = run_command(f"git diff --name-only {GITREMOTE_TARGET}/main") # Returns each file on newline files = files.split("\n") - os.chdir(Path.cwd().parent) + print("Files Changed: ", files) modified_models = models_to_build(files) - + print("Modified Models (Images to build): ", modified_models) return modified_models def main(): - pr_branch = os.environ.get("PR_BRANCH", "main") # If not specified default to 'main' force_run_all = os.environ.get("FORCE_RUN_ALL", "false") # If not specified default to False force_run_all_phi = os.environ.get("FORCE_RUN_ALL_PHI", "false") # If not specified default to False force_run_all_public = os.environ.get("FORCE_RUN_ALL_PUBLIC", "false") # If not specified default to False @@ -131,7 +132,7 @@ def main(): else: # Logic to determine affected models # Example: affected_models = ['model1', 'model2', 'model3'] - affected_models = check_modified_models(pr_branch) + affected_models = check_modified_models() # Convert the list of models into JSON matrix format matrix = create_matrix(affected_models)