Skip to content

Commit

Permalink
ast: Add new arguments to print_code_snippet
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Povišer <[email protected]>
  • Loading branch information
povik committed Feb 26, 2024
1 parent 348cb96 commit 431497d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 25 deletions.
40 changes: 22 additions & 18 deletions fold/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import string
import sys
import io
from contextlib import contextmanager


Expand All @@ -28,15 +29,15 @@ def markers_str(markers):
return "%s:%d:%d" % (a.name, a.line, a.col)


def print_line_span(lines, a, b, highlight=False):
def print_line_span(f, lines, a, b, highlight=False):
hl = '\033[38;5;126m|\033[0;0m' if highlight else ' '
for i in range(a - 1, b):
if i < 0 or i >= len(lines):
continue
print(" \033[2;37m{:4d}\033[0;0m{}{}".format(i + 1, hl, lines[i].rstrip()), file=sys.stderr)
print(" \033[2;37m{:4d}\033[0;0m{}{}".format(i + 1, hl, lines[i].rstrip()), file=f)


def print_code_snippet(markers):
def print_code_snippet(f, markers, inject_buffer=None):
if markers is None or (markers[0] is None and markers[1] is None):
return

Expand All @@ -48,34 +49,37 @@ def print_code_snippet(markers):
if markers[0].name != markers[1].name:
return

try:
lines = list(open(markers[0].name, 'r'))
except FileNotFoundError:
return
if inject_buffer is None:
try:
lines = list(open(markers[0].name, 'r'))
except FileNotFoundError:
return
else:
lines = list(io.StringIO(inject_buffer))

a, b = markers

print(file=sys.stderr)
print(file=f)
if a.line != b.line:
print_line_span(lines, a.line - 2, a.line - 1)
print_line_span(f, lines, a.line - 2, a.line - 1)
if b.line - a.line >= 8:
print_line_span(lines, a.line, a.line + 2, highlight=True)
print(" ...", file=sys.stderr)
print_line_span(lines, b.line - 2, b.line, highlight=True)
print_line_span(f, lines, a.line, a.line + 2, highlight=True)
print(" ...", file=f)
print_line_span(f, lines, b.line - 2, b.line, highlight=True)
else:
print_line_span(lines, a.line, b.line, highlight=True)
print_line_span(lines, b.line + 1, b.line + 2)
print_line_span(f, lines, a.line, b.line, highlight=True)
print_line_span(f, lines, b.line + 1, b.line + 2)
else:
if a.line > len(lines):
return
print_line_span(lines, a.line - 1, a.line)
print_line_span(f, lines, a.line - 1, a.line)
line = lines[a.line - 1]
print(" " + "".join([" " if (c != "\t") else "\t" for c in line[:a.col-1]])
+ '\033[38;5;126m'
+ ("~" * (max(b.col - a.col, 1)))
+ '\033[0;0m', file=sys.stderr)
print_line_span(lines, a.line + 1, a.line + 1)
print(file=sys.stderr)
+ '\033[0;0m', file=f)
print_line_span(f, lines, a.line + 1, a.line + 1)
print(file=f)


class _BadInputMessageFmt(string.Formatter):
Expand Down
4 changes: 2 additions & 2 deletions fold/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def on_Op(self, expr):
file=sys.stderr)
sys.exit(1)
except ast.BadInput as e:
ast.print_code_snippet(e.markers)
ast.print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)

Expand All @@ -287,6 +287,6 @@ def on_Op(self, expr):
if eval_(lhs) != eval_(rhs):
raise ast.BadInput("failed assertion")
except ast.BadInput as e:
ast.print_code_snippet(e.markers)
ast.print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)
6 changes: 3 additions & 3 deletions fold/logic/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def py_execute(self, f, filename, rawargs, design):
try:
top_ast_nodes = parse_spec_from_buffer(ys.read_istream(f), filename)
except BadInput as e:
print_code_snippet(e.markers)
print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)

Expand All @@ -70,7 +70,7 @@ def filter_nodes(typ):
d.read_constants(filter_nodes("const"))
d.impl_top_body(top_ast_nodes)
except BadInput as e:
print_code_snippet(e.markers)
print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)

Expand All @@ -81,7 +81,7 @@ def filter_nodes(typ):
d.rtl_module.set_top()
d.rtl_module.ym.fixup_ports()
except BadInput as e:
print_code_snippet(e.markers)
print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)
return
Expand Down
4 changes: 2 additions & 2 deletions fold/machinecode/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main():
file=sys.stderr)
sys.exit(1)
except ast.BadInput as e:
ast.print_code_snippet(e.markers)
ast.print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)

Expand All @@ -68,7 +68,7 @@ def filter_nodes(typ):
d.read_constants(filter_nodes("const"))
d.impl_top_body(top_ast_nodes)
except ast.BadInput as e:
ast.print_code_snippet(e.markers)
ast.print_code_snippet(sys.stderr, e.markers)
print(e, file=sys.stderr)
sys.exit(1)

Expand Down

0 comments on commit 431497d

Please sign in to comment.