From 1a07ffa9111ac5ccea6ad527e8950b084aa4fc79 Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 20 Aug 2024 19:33:12 -0700 Subject: [PATCH] fix(components): Use instance.target_field_name format for text-bison models only, use target_field_name for gemini models Signed-off-by: Googler PiperOrigin-RevId: 665638487 --- .../llm_evaluation_preprocessor/component.py | 6 ++---- .../evaluation_llm_text_generation_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py index f102fe541ac..f31f28823d7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py @@ -101,10 +101,8 @@ def evaluation_dataset_preprocessor_internal( f'--gcs_source_uris={gcs_source_uris}', f'--input_field_name={input_field_name}', f'--role_field_name={role_field_name}', - ( - f'--target_field_name={target_field_name}' - f'--model_name={model_name}' - ), + f'--target_field_name={target_field_name}', + f'--model_name={model_name}', f'--output_dirs={output_dirs}', '--executor_input={{$.json_escape[1]}}', ], diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py index 9b613ee8eb3..e9022932463 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py @@ -152,7 +152,7 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul project=project, location=location, evaluation_task=evaluation_task, - target_field_name=f'instance.{target_field_name}', + target_field_name=target_field_name, predictions_format=batch_predict_predictions_format, enable_row_based_metrics=enable_row_based_metrics, joined_predictions_gcs_source=batch_predict_task.outputs[