Skip to content

Commit

Permalink
enhance header cleanser module with multi-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
takuyagt committed Dec 3, 2024
1 parent 4171dfa commit b30c889
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 37 deletions.
25 changes: 25 additions & 0 deletions transforms/code/header_cleanser/kfp_ray/header_cleanser_wf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,13 @@ def compute_exec_params_func(
runtime_job_id: str,
runtime_code_location: dict,
header_cleanser_contents_column_name: str,
header_cleanser_document_id_column_name: str,
header_cleanser_license: bool,
header_cleanser_copyright: bool,
header_cleanser_n_processes: int,
header_cleanser_tmp_dir: str,
header_cleanser_timeout: int,
header_cleanser_skip_timeout: bool,
) -> dict:
from runtime_utils import KFPUtils

Expand All @@ -56,8 +61,13 @@ def compute_exec_params_func(
"runtime_job_id": runtime_job_id,
"runtime_code_location": str(runtime_code_location),
"header_cleanser_contents_column_name": header_cleanser_contents_column_name,
"header_cleanser_document_id_column_name": header_cleanser_document_id_column_name,
"header_cleanser_license": header_cleanser_license,
"header_cleanser_copyright": header_cleanser_copyright,
"header_cleanser_n_processes": header_cleanser_n_processes,
"header_cleanser_tmp_dir": header_cleanser_tmp_dir,
"header_cleanser_timeout": header_cleanser_timeout,
"header_cleanser_skip_timeout": header_cleanser_skip_timeout,
}


Expand Down Expand Up @@ -119,8 +129,13 @@ def header_cleanser(
runtime_code_location: dict = {'github': 'github', 'commit_hash': '12345', 'path': 'path'},
# header cleanser parameters
header_cleanser_contents_column_name: str = "contents",
header_cleanser_document_id_column_name: str = "document_id",
header_cleanser_license: bool = True,
header_cleanser_copyright: bool = True,
header_cleanser_n_processes: int = 5,
header_cleanser_tmp_dir: str = "",
header_cleanser_timeout: int = 300,
header_cleanser_skip_timeout: bool = False,
# additional parameters
additional_params: str = '{"wait_interval": 2, "wait_cluster_ready_tmout": 800, "wait_cluster_up_tmout": 300, "wait_job_ready_tmout": 400, "wait_print_tmout": 30, "http_retries": 5, "delete_cluster_delay_minutes": 0}',
):
Expand Down Expand Up @@ -157,8 +172,13 @@ def header_cleanser(
:param runtime_actor_options - actor options
:param runtime_pipeline_id - pipeline id
:param contents_column_name - Name of the column that holds the data to process
:param document_id_column_name - Name of the column that holds the document id
:param license - Hold value true or false to delete/remove license or not.
:param copyright - Hold value true or false to delete/remove copyright or not.
:param n_processes - num processes to scan codes in parallel
:param tmp_dir - Path to tmp dir for codes
:param timeout - Value of timeout to scan codes
:param skip_timeout - Hold value true or false to skip removing the copyright/license header (instead of failing) when the scan times out.
:return: None
"""
# create clean_up task
Expand All @@ -177,8 +197,13 @@ def header_cleanser(
runtime_job_id=run_id,
runtime_code_location=runtime_code_location,
header_cleanser_contents_column_name=header_cleanser_contents_column_name,
header_cleanser_document_id_column_name=header_cleanser_document_id_column_name,
header_cleanser_license=header_cleanser_license,
header_cleanser_copyright=header_cleanser_copyright,
header_cleanser_n_processes=header_cleanser_n_processes,
header_cleanser_tmp_dir=header_cleanser_tmp_dir,
header_cleanser_timeout=header_cleanser_timeout,
header_cleanser_skip_timeout=header_cleanser_skip_timeout,
)

ComponentUtils.add_settings_to_component(compute_exec_params, ONE_HOUR_SEC * 2)
Expand Down
5 changes: 5 additions & 0 deletions transforms/code/header_cleanser/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ When running the transform with the Ray launcher (i.e. TransformLauncher),
the following command line arguments are available in addition to
the [python launcher](../../../../data-processing-lib/doc/python-launcher-options.md).
* --header_cleanser_contents_column_name - set the contents_column_name configuration key.
* --header_cleanser_document_id_column_name - set the document_id_column_name configuration key.
* --header_cleanser_license - set the license configuration key.
* --header_cleanser_copyright - set the copyright configuration key.
* --header_cleanser_n_processes - set the n_processes configuration key.
* --header_cleanser_tmp_dir - set the tmp_dir configuration key.
* --header_cleanser_timeout - set the timeout configuration key.
* --header_cleanser_skip_timeout - set the skip_timeout configuration key.

### Running the samples
To run the samples, use the following `make` targets
Expand Down
1 change: 1 addition & 0 deletions transforms/code/header_cleanser/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
data-prep-toolkit==0.2.3.dev0
scancode-toolkit==32.1.0 ; platform_system != 'Darwin'
timeout-timer==0.2.0

200 changes: 163 additions & 37 deletions transforms/code/header_cleanser/python/src/header_cleanser_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,73 @@

################################################################################

import concurrent.futures
from functools import partial
import logging
import math
import multiprocessing
import os
import tempfile
from argparse import ArgumentParser, Namespace
from typing import Any
import warnings

import pyarrow as pa
from data_processing.runtime.pure_python.runtime_configuration import (
PythonTransformRuntimeConfiguration,
)
from data_processing.transform import AbstractTableTransform, TransformConfiguration
from data_processing.utils import CLIArgumentProvider, get_logger, str2bool
import timeout_timer
from scancode import api


logger = get_logger(__name__)
logging.getLogger('bs4').setLevel(logging.ERROR)
logging.getLogger('timeout_timer').setLevel(logging.ERROR)
warnings.simplefilter('ignore', DeprecationWarning)

# Transform identity; every CLI flag of this transform starts with cli_prefix.
short_name = "header_cleanser"
cli_prefix = short_name + "_"

# Keys under which each option is stored in the transform configuration dict.
COLUMN_KEY = "contents_column_name"
DOCUMENT_ID_COLUMN_KEY = "document_id_column_name"
LICENSE_KEY = "license"
COPYRIGHT_KEY = "copyright"
N_PROCESSES_KEY = "n_processes"
TMP_DIR_KEY = "tmp_dir"
TIMEOUT_KEY = "timeout"
SKIP_TIMEOUT_KEY = "skip_timeout"

# Command-line parameter names: the configuration key with the CLI prefix applied.
column_cli_params = f"{cli_prefix}{COLUMN_KEY}"
document_id_column_cli_params = f"{cli_prefix}{DOCUMENT_ID_COLUMN_KEY}"
license_cli_params = f"{cli_prefix}{LICENSE_KEY}"
copyright_cli_params = f"{cli_prefix}{COPYRIGHT_KEY}"
n_processes_cli_params = f"{cli_prefix}{N_PROCESSES_KEY}"
tmp_dir_cli_params = f"{cli_prefix}{TMP_DIR_KEY}"
timeout_cli_params = f"{cli_prefix}{TIMEOUT_KEY}"
skip_timeout_cli_params = f"{cli_prefix}{SKIP_TIMEOUT_KEY}"

# Default values used when an option is not supplied.
DEFAULT_COLUMN = "contents"
# Single definition: an earlier duplicate assignment ("doc_id_column_name")
# was always overwritten by this value and has been removed.
DEFAULT_DOCUMENT_ID_COLUMN = "document_id"
DEFAULT_LICENSE = True
DEFAULT_COPYRIGHT = True
DEFAULT_N_PROCESSES = 5
DEFAULT_TIMEOUT = 300
DEFAULT_SKIP_TIMEOUT = False
# os.getenv returns a str whenever the variable is set, so convert explicitly:
# the value is multiplied and compared against lengths when chunking work.
DEFAULT_CHUNK_SIZE = int(os.getenv("DEFAULT_CHUNK_SIZE", "100"))


def file_generate(content, tmp_dir=None):
    """
    Write ``content`` to a temporary ``.txt`` file so that it can be passed to scancode-toolkit.

    The file is created with ``delete=False``; the caller is responsible for
    removing it once the scan is done.

    :param content: text to write; it is encoded as UTF-8.
    :param tmp_dir: directory in which to create the file, or None to use the
        platform default temporary directory.
    :return: path of the created file.
    :raises Exception: re-raises whatever error prevented the file from being
        created or written, after logging it.
    """
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", dir=tmp_dir) as temp_file:
            # Record the path before writing so a failed write can still be cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(content.encode("utf-8"))
    except Exception as e:
        logger.error(f"Failed to create file : {e}")
        # Do not leave a half-written file behind (delete=False disables auto-removal).
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        # Surface the real cause instead of failing later on an unbound path variable.
        raise
    return temp_file_path


Expand All @@ -70,14 +97,14 @@ def fetch_index(dict_data):
Extract license and copyright start and end lines from dictionary
"""
ignore_lines = []
if dict_data.get("license_detections") != None:
if dict_data.get("license_detections", None) != None:
for licenses in dict_data.get("license_detections"):
for match in licenses.get("matches"):
start_line = match["start_line"] - 1
end_line = match["end_line"] - 1
ignore_lines.extend([i for i in range(start_line, end_line + 1)])

if dict_data.get("copyrights") != None:
if dict_data.get("copyrights", None) != None:
for copyrights in dict_data.get("copyrights"):
start_line = copyrights.get("start_line") - 1
end_line = copyrights.get("end_line") - 1
Expand Down Expand Up @@ -110,12 +137,21 @@ def check_empty_comment(code, ignore_lines):
return ignore_lines


def remove_copyright(code):
def remove_copyright(id_code: tuple[Any, str], tmp_dir=None, timeout=-1, skip_timeout=False):
"""
Use scancode.api functions to detect and remove copyright.
"""
file_path = file_generate(content=code)
copyright_dict = api.get_copyrights(file_path)
doc_id, code = id_code
file_path = file_generate(content=code, tmp_dir=tmp_dir)
try:
with timeout_timer.timeout(timeout, timer="signal"):
copyright_dict = api.get_copyrights(file_path)
except timeout_timer.TimeoutInterrupt:
if skip_timeout:
logger.warning(f"Skipping removing copyrights due to timeout: {doc_id}")
copyright_dict = {}
else:
raise Exception(f"Timeout during copyright scan: {doc_id}")
os.remove(file_path)
ignore_lines = fetch_index(copyright_dict)
if ignore_lines != []:
Expand All @@ -125,12 +161,21 @@ def remove_copyright(code):
return code, False


def remove_license(code):
def remove_license(id_code: tuple[Any, str], tmp_dir=None, timeout=-1, skip_timeout=False):
"""
Use scancode.api functions to detect and remove license.
"""
file_path = file_generate(content=code)
license_dict = api.get_licenses(file_path)
doc_id, code = id_code
file_path = file_generate(content=code, tmp_dir=tmp_dir)
try:
with timeout_timer.timeout(timeout, timer="signal"):
license_dict = api.get_licenses(file_path)
except timeout_timer.TimeoutInterrupt:
if skip_timeout:
logger.warning(f"Skipping removing licenses due to timeout: {doc_id}")
license_dict = {}
else:
raise Exception(f"Timeout during license scan: {doc_id}")
os.remove(file_path)
ignore_lines = fetch_index(license_dict)
if ignore_lines != []:
Expand All @@ -140,11 +185,27 @@ def remove_license(code):
return code, False


def remove_license_copyright(code):

file_path = file_generate(code)
copyright_dict = api.get_copyrights(file_path)
license_dict = api.get_licenses(file_path)
def remove_license_copyright(id_code: tuple[Any, str], tmp_dir=None, timeout=-1, skip_timeout=False):
doc_id, code = id_code
file_path = file_generate(code, tmp_dir=tmp_dir)
try:
with timeout_timer.timeout(timeout, timer="signal"):
copyright_dict = api.get_copyrights(file_path)
except timeout_timer.TimeoutInterrupt:
if skip_timeout:
logger.warning(f"Skipping removing copyrights due to timeout: {doc_id}")
copyright_dict = {}
else:
raise Exception(f"Timeout during copyright scan: {doc_id}")
try:
with timeout_timer.timeout(timeout, timer="signal"):
license_dict = api.get_licenses(file_path)
except timeout_timer.TimeoutInterrupt:
if skip_timeout:
logger.warning(f"Skipping removing licenses due to timeout: {doc_id}")
license_dict = {}
else:
raise Exception(f"Timeout during license scan: {doc_id}")
os.remove(file_path)
ignore_lines_license = fetch_index(license_dict)
ignore_lines_copyright = fetch_index(copyright_dict)
Expand All @@ -164,33 +225,63 @@ def __init__(self, config: dict):
self.column_name = config.get(COLUMN_KEY, DEFAULT_COLUMN)
self.license_remove = config.get(LICENSE_KEY, DEFAULT_LICENSE)
self.copyright_remove = config.get(COPYRIGHT_KEY, DEFAULT_COPYRIGHT)
self.document_id_column_name = config.get(DOCUMENT_ID_COLUMN_KEY, DEFAULT_DOCUMENT_ID_COLUMN)
n_processes = config.get(N_PROCESSES_KEY, DEFAULT_N_PROCESSES)
self.n_processes = (
max(1, multiprocessing.cpu_count() - 1)
if n_processes < 0 or n_processes > (multiprocessing.cpu_count() - 1)
else n_processes
)
logger.info(f"Running process: {self.n_processes}")
tmp_dir = config.get(TMP_DIR_KEY, None)
self.tmp_dir = tmp_dir if tmp_dir else None
if self.tmp_dir:
logger.info(f"Using for tmp dir: {self.tmp_dir}")
self.timeout = config.get(TIMEOUT_KEY, DEFAULT_TIMEOUT)
logger.info(f"Processing timeout: {self.timeout}")
self.skip_timeout = config.get(SKIP_TIMEOUT_KEY, DEFAULT_SKIP_TIMEOUT)
if self.skip_timeout:
logger.info("Skip processing records when timeout occurs")

def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Table], dict]:

contents = table.column(self.column_name).to_pylist()
if self.document_id_column_name in table.column_names:
ids = table.column(self.document_id_column_name).to_pylist()
else:
ids = list(range(len(contents)))
ids_contents = list(zip(ids, contents))

if self.license_remove and self.copyright_remove:
f = remove_license_copyright

elif self.copyright_remove:
f = remove_copyright

elif self.license_remove:
f = remove_license

else:
return [table], {"Removed code count": 0}

func = partial(f, tmp_dir=self.tmp_dir, timeout=self.timeout, skip_timeout=self.skip_timeout)
updated_content = []
remove_code_count = 0
for content in contents:
if self.license_remove and self.copyright_remove:
new_content, detect = remove_license_copyright(content)
if detect:
remove_code_count += 1
updated_content.append(new_content)

elif self.copyright_remove:
new_content, detect = remove_copyright(content)
if detect:
remove_code_count += 1
updated_content.append(new_content)

elif self.license_remove:
new_content, detect = remove_license(content)
if detect:
remove_code_count += 1
updated_content.append(new_content)

else:
return [table], {"Removed code count": remove_code_count}
with concurrent.futures.ProcessPoolExecutor(max_workers=self.n_processes) as executor:
logger.debug(f"Start processing with {self.n_processes} executors")
chunksize = 1
if self.n_processes == 1:
chunksize = len(contents)
elif len(contents) > self.n_processes * DEFAULT_CHUNK_SIZE:
chunksize = DEFAULT_CHUNK_SIZE
elif len(contents) > self.n_processes * 2:
chunksize = len(contents) // self.n_processes
logger.debug(f"Breaking {len(contents)} contents into {math.ceil(len(contents) / chunksize)} chunks (size: {chunksize})")
results = executor.map(func, ids_contents, chunksize=chunksize)
for c, d in results:
updated_content.append(c)
remove_code_count += int(d)
logger.debug(f"End processing: {len(updated_content)} ({remove_code_count} removed)")

updated_content = pa.array(updated_content)

Expand Down Expand Up @@ -225,6 +316,41 @@ def add_input_params(self, parser: ArgumentParser) -> None:
default=f"{DEFAULT_COPYRIGHT}",
help="Set False if copyright should not be removed ",
)
parser.add_argument(
f"--{n_processes_cli_params}",
required=False,
type=int,
default=f"{DEFAULT_N_PROCESSES}",
help="Number of processes to scan codes in parallel",
)
parser.add_argument(
f"--{tmp_dir_cli_params}",
required=False,
type=str,
default=None,
help="Set a path if tmp directory should be specified",
)
parser.add_argument(
f"--{timeout_cli_params}",
required=False,
type=int,
default=f"{DEFAULT_TIMEOUT}",
help="Timeout in seconds for code scan",
)
parser.add_argument(
f"--{skip_timeout_cli_params}",
required=False,
type=lambda x: bool(str2bool(x)),
default=f"{DEFAULT_SKIP_TIMEOUT}",
help="Set True if records should be skipped when timeout occurrs during scanning",
)
parser.add_argument(
f"--{document_id_column_cli_params}",
required=False,
type=str,
default=f"{DEFAULT_DOCUMENT_ID_COLUMN}",
help="Name of the column holds the document id",
)

def apply_input_params(self, args: Namespace) -> bool:
captured = CLIArgumentProvider.capture_parameters(args, cli_prefix, False)
Expand Down
Loading

0 comments on commit b30c889

Please sign in to comment.