Skip to content

Commit

Permalink
feat(parse): update parse script
Browse files Browse the repository at this point in the history
  • Loading branch information
terryyz committed May 4, 2024
1 parent 41d258e commit 43ef655
Showing 1 changed file with 38 additions and 33 deletions.
71 changes: 38 additions & 33 deletions script/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,25 @@ def remove_trailing_comments(text):
# Return the text up to and including the last non-comment line, joined back into a single string
return '\n'.join(lines[:i+1])


def clean_data(text):
    """
    Removes full-line comments and blank lines from the given text.

    A line is dropped when, after stripping surrounding whitespace, it is
    empty or starts with '#'. Inline comments that follow code on the same
    line are kept, and surviving lines keep their original indentation.

    Parameters:
    text (str): The text to clean.

    Returns:
    str: The remaining lines joined with newline characters.
    """
    # NOTE: the check is purely textual, so a line that starts with '#'
    # inside a multi-line string literal is dropped as well.
    kept_lines = []
    for line in text.splitlines():
        stripped = line.strip()  # strip once and reuse for both checks
        if stripped and not stripped.startswith('#'):
            kept_lines.append(line)
    return '\n'.join(kept_lines)


def evaluate_test_class(code):
exec_globals = {}
exec(code, exec_globals)
Expand Down Expand Up @@ -237,47 +256,33 @@ def extract_test(file_contents, function_name):
except Exception as e:
return f"Error processing the script: {e}"

def get_test_case_names(code):
# Parse the code to AST
root = ast.parse(code)

# Define a class to collect test case names
class TestCaseNameCollector(ast.NodeVisitor):
def __init__(self):
self.test_case_names = []

def visit_FunctionDef(self, node):
if node.name.startswith("test"):
self.test_case_names.append(node.name)
self.generic_visit(node)

# Create a collector object and visit nodes
collector = TestCaseNameCollector()
collector.visit(root)

# Return the list of test case names found
return collector.test_case_names
# Default name fragments removed by replace_pii().
_PII_NAMES = (
    "chien", "jenny", "wenhao", "niklas", "hanhu",
    "ratna", "ming", "junda", "haolan", "xiaohenng",
)


def replace_pii(content, names=None):
    """
    Removes every occurrence of the given name fragments from the content.

    Matching is a plain, case-sensitive substring replacement applied for
    each name in turn.

    Parameters:
    content (str): The text to scrub.
    names (iterable of str, optional): Fragments to remove. Defaults to the
        built-in list of names in _PII_NAMES.

    Returns:
    str: The content with every occurrence of each fragment deleted.
    """
    # NOTE(review): because this is substring matching, short fragments also
    # hit ordinary words (e.g. "ming" turns "programming" into "program").
    # Left unchanged here since narrowing the match could let PII through.
    if names is None:
        names = _PII_NAMES
    for name in names:
        content = content.replace(name, "")
    return content

def extract_content(file_path, rename_id=None):
data = {"task_id": file_path.split("/")[-1]}
with open(file_path, 'r') as file:
for line in file:
line = line.strip()
if line.startswith('def'):
# Extract the function name
start = line.find('def') + 4 # 'def ' has 4 characters
end = line.find('(')
if end != -1:
function_name = line[start:end].strip()
if function_name.startswith("f_"):
data["entry_point"] = function_name
break
lines = file.read().splitlines()
for line in lines:
line = line.strip()
if line.startswith('def'):
# Extract the function name
start = line.find('def') + 4 # 'def ' has 4 characters
end = line.find('(')
if end != -1:
function_name = line[start:end].strip()
if function_name.startswith("f_"):
data["entry_point"] = function_name
break
with open(file_path, "r", encoding="utf-8") as f:
if not rename_id:
rename_id = data["entry_point"]
data["entry_point"] = rename_id
content = f.read().strip("\n").replace("AxesSubplot", "Axes").replace("matplotlib.axes._subplots", "matplotlib.axes._axes")
content = content.replace(function_name, rename_id)
content = replace_pii(content)

function_name = rename_id
# Extracting the docstring if present
Expand Down Expand Up @@ -309,8 +314,8 @@ def extract_content(file_path, rename_id=None):
data["canonical_solution"] = "\n".join(lines[function_start_line:function_end_line])
else:
data["canonical_solution"] = ""
data["clean_canonical_solution"] = clean_data(data["canonical_solution"])
data["test"] = extract_test(content,function_name).strip()
data["test_case"] = get_test_case_names(data["test"])
data["apis"] = extract_apis(data["prompt"] + "\n" + data["canonical_solution"])
data["libs"] = list(set([api.split(".")[0] for api in data["apis"]]))
_, unused_imports = filter_unused_imports(data["prompt"], data["libs"])
Expand Down Expand Up @@ -401,7 +406,7 @@ def parse_docstring(docstring):
return sections

def reconstruct_problem(data):
    """
    Rebuilds the full problem source from its parsed parts.

    The result is the prompt, then the solution body, a blank line, and the
    test code, ending with a newline.

    Parameters:
    data (dict): Must contain "prompt" and "test", plus
        "clean_canonical_solution" and/or "canonical_solution".

    Returns:
    str: The reassembled source text.

    Raises:
    KeyError: If a required key is missing.
    """
    # The block previously had two consecutive return statements, so the
    # second one (using the cleaned solution) could never run. Prefer the
    # comment-stripped solution and fall back to the raw one for data that
    # has no cleaned variant.
    if "clean_canonical_solution" in data:
        solution = data["clean_canonical_solution"]
    else:
        solution = data["canonical_solution"]
    return data["prompt"] + "\n" + solution + "\n\n" + data["test"] + "\n"

def get_instruction_prompt(data):
base = "Write a function called " + f'`{data["signature"]}` to: ' + " ".join(data["doc"]["description"])
Expand Down

0 comments on commit 43ef655

Please sign in to comment.