diff --git a/evadb/executor/drop_object_executor.py b/evadb/executor/drop_object_executor.py index c4f108052..d27012328 100644 --- a/evadb/executor/drop_object_executor.py +++ b/evadb/executor/drop_object_executor.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil import pandas as pd +from evadb.configuration.constants import TRAINING_FRAMEWORKS from evadb.database import EvaDBDatabase from evadb.executor.abstract_executor import AbstractExecutor from evadb.executor.executor_utils import ExecutorError, handle_vector_store_params @@ -24,6 +27,7 @@ from evadb.plan_nodes.drop_object_plan import DropObjectPlan from evadb.storage.storage_engine import StorageEngine from evadb.third_party.vector_stores.utils import VectorStoreFactory +from evadb.utils.generic_utils import string_comparison_case_insensitive from evadb.utils.logging_manager import logger @@ -94,19 +98,58 @@ def _handle_drop_function(self, function_name: str, if_exists: bool): function_entry = self.catalog().get_function_catalog_entry_by_name( function_name ) - for cache in function_entry.dep_caches: - self.catalog().drop_function_cache_catalog_entry(cache) - - # todo also delete the indexes associated with the table - - self.catalog().delete_function_catalog_entry_by_name(function_name) - - return Batch( - pd.DataFrame( - {f"Function {function_name} successfully dropped"}, - index=[0], + # training framework model cleanup on drop function + err_msg = ( + f"Error removing {function_entry.type} model for function {function_name}." + ) + try: + if function_entry.type.lower() in [x.lower() for x in TRAINING_FRAMEWORKS]: + filtered_metadata = list( + filter(lambda x: x.key == "model_path", function_entry.metadata) ) + if len(filtered_metadata) > 0: + model_path = os.path.abspath(filtered_metadata[0].value) + """For 'Forecasting' the entire function catalog of forecasting functions + is checked to see if the model path is shared""" + if string_comparison_case_insensitive( + function_entry.type, "Forecasting" + ): + forecasting_function_entries = ( + self.catalog().get_function_catalog_entries_by_type( + function_entry.type + ) + ) + functions_using_same_model = sum( + 1 + for entry in forecasting_function_entries + if any( + x.key == "model_path" + and os.path.abspath(x.value) == model_path + for x in entry.metadata + ) + ) + if functions_using_same_model == 1: + dir_path = os.path.abspath(os.path.dirname(model_path)) + if os.path.exists(dir_path): + shutil.rmtree(dir_path) + else: + if os.path.exists(model_path): + os.remove(model_path) + + except Exception as e: + raise RuntimeError(f"{err_msg}\n{e}") + + for cache in function_entry.dep_caches: + self.catalog().drop_function_cache_catalog_entry(cache) + + # todo also delete the indexes associated with the table + self.catalog().delete_function_catalog_entry_by_name(function_name) + return Batch( + pd.DataFrame( + {f"Function {function_name} successfully dropped"}, + index=[0], ) + ) def _handle_drop_index(self, index_name: str, if_exists: bool): index_obj = self.catalog().get_index_catalog_entry_by_name(index_name)