Skip to content

Commit

Permalink
Rust: make File usable in codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
Paolo Tranquilli committed Dec 2, 2024
1 parent 7e0e5a3 commit b57a374
Show file tree
Hide file tree
Showing 40 changed files with 363 additions and 141 deletions.
3 changes: 2 additions & 1 deletion misc/codegen/generators/dbschemegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a

def get_declarations(data: schema.Schema):
add_or_none_except = data.root_class.name if data.null else None
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)]
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
data.classes, add_or_none_except)]
if data.null:
property_classes = {
prop.type for cls in data.classes.values() for prop in cls.properties
Expand Down
35 changes: 26 additions & 9 deletions misc/codegen/generators/qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
return f"{prop_name} of this {class_name}"


def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.Class],
def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> bool:
if t in lookup:
match lookup[t]:
case schema.Class() as cls:
return "ql_hideable" in cls.pragmas
return False


def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
prev_child: str = "") -> ql.Property:

args = dict(
type=prop.type if not prop.is_predicate else "predicate",
qltest_skip="qltest_skip" in prop.pragmas,
Expand All @@ -115,7 +124,8 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
is_unordered=prop.is_unordered,
description=prop.description,
synth=bool(cls.synth) or prop.synth,
type_is_hideable="ql_hideable" in lookup[prop.type].pragmas if prop.type in lookup else False,
type_is_hideable=_type_is_hideable(prop.type, lookup),
type_is_codegen_class=prop.type in lookup and not lookup[prop.type].imported,
internal="ql_internal" in prop.pragmas,
)
ql_name = prop.pragmas.get("ql_name", prop.name)
Expand Down Expand Up @@ -154,7 +164,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
return ql.Property(**args)


def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> ql.Class:
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
if "ql_name" in cls.pragmas:
raise Error("ql_name is not supported yet for classes, only for properties")
prev_child = ""
Expand Down Expand Up @@ -391,14 +401,15 @@ def generate(opts, renderer):

data = schemaloader.load_file(input)

classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
if not classes:
raise NoClasses
root = next(iter(classes.values()))
if root.has_children:
raise RootElementHasChildren(root)

imports = {}
pre_imports = {n: cls.module for n, cls in data.classes.items() if cls.imported}
imports = dict(pre_imports)
imports_impl = {}
classes_used_by = {}
cfg_classes = []
Expand All @@ -410,7 +421,7 @@ def generate(opts, renderer):
force=opts.force) as renderer:

db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
renderer.render(ql.DbClasses(db_classes), out / "Raw.qll")
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")

classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
for c in classes_by_dir_and_name:
Expand Down Expand Up @@ -439,6 +450,8 @@ def generate(opts, renderer):
renderer.render(cfg_classes_val, cfg_qll)

for c in data.classes.values():
if c.imported:
continue
path = _get_path(c)
path_impl = _get_path_impl(c)
stub_file = stub_out / path_impl
Expand All @@ -457,20 +470,23 @@ def generate(opts, renderer):
renderer.render(class_public, class_public_file)

# for example path/to/elements -> path/to/elements.qll
renderer.render(ql.ImportList([i for name, i in imports.items() if not classes[name].internal]),
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
include_file)

elements_module = get_import(include_file, opts.root_dir)

renderer.render(
ql.GetParentImplementation(
classes=list(classes.values()),
imports=[elements_module] + [i for name, i in imports.items() if classes[name].internal],
imports=[elements_module] + [i for name,
i in imports.items() if name in classes and classes[name].internal],
),
out / 'ParentChild.qll')

if test_out:
for c in data.classes.values():
if c.imported:
continue
if should_skip_qltest(c, data.classes):
continue
test_with_name = c.pragmas.get("qltest_test_with")
Expand Down Expand Up @@ -500,7 +516,8 @@ def generate(opts, renderer):
constructor_imports = []
synth_constructor_imports = []
stubs = {}
for cls in sorted(data.classes.values(), key=lambda cls: (cls.group, cls.name)):
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
key=lambda cls: (cls.group, cls.name)):
synth_type = get_ql_synth_class(cls)
if synth_type.is_final:
final_synth_types.append(synth_type)
Expand Down
16 changes: 10 additions & 6 deletions misc/codegen/generators/rustgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:


def _get_properties(
cls: schema.Class, lookup: dict[str, schema.Class],
cls: schema.Class, lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
for b in cls.bases:
yield from _get_properties(lookup[b], lookup)
Expand All @@ -58,20 +58,22 @@ def _get_properties(


def _get_ancestors(
cls: schema.Class, lookup: dict[str, schema.Class]
cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
for b in cls.bases:
base = lookup[b]
yield base
yield from _get_ancestors(base, lookup)
if not base.imported:
base = typing.cast(schema.Class, base)
yield base
yield from _get_ancestors(base, lookup)


class Processor:
def __init__(self, data: schema.Schema):
self._classmap = data.classes

def _get_class(self, name: str) -> rust.Class:
cls = self._classmap[name]
cls = typing.cast(schema.Class, self._classmap[name])
properties = [
(c, p)
for c, p in _get_properties(cls, self._classmap)
Expand Down Expand Up @@ -101,8 +103,10 @@ def _get_class(self, name: str) -> rust.Class:
def get_classes(self):
ret = {"": []}
for k, cls in self._classmap.items():
if not cls.synth:
if not cls.imported and not cls.synth:
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
elif cls.imported:
ret[""].append(rust.Class(name=cls.name))
return ret


Expand Down
2 changes: 2 additions & 0 deletions misc/codegen/generators/rusttestgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def generate(opts, renderer):
registry=opts.ql_test_output / ".generated_tests.list",
force=opts.force) as renderer:
for cls in schema.classes.values():
if cls.imported:
continue
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_doc_test" in cls.pragmas):
continue
Expand Down
6 changes: 2 additions & 4 deletions misc/codegen/lib/ql.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Property:
doc_plural: Optional[str] = None
synth: bool = False
type_is_hideable: bool = False
type_is_codegen_class: bool = False
internal: bool = False
cfg: bool = False

Expand All @@ -66,10 +67,6 @@ def indefinite_getter(self):
article = "An" if self.singular[0] in "AEIO" else "A"
return f"get{article}{self.singular}"

@property
def type_is_class(self):
return bool(self.type) and self.type[0].isupper()

@property
def is_repeated(self):
return bool(self.plural)
Expand Down Expand Up @@ -191,6 +188,7 @@ class DbClasses:
template: ClassVar = 'ql_db'

classes: List[Class] = field(default_factory=list)
imports: List[str] = field(default_factory=list)


@dataclass
Expand Down
28 changes: 22 additions & 6 deletions misc/codegen/lib/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from typing import List, Set, Union, Dict, Optional, FrozenSet
from enum import Enum, auto
import functools

Expand Down Expand Up @@ -87,8 +87,22 @@ class SynthInfo:


@dataclass
class Class:
class ClassBase:
imported: typing.ClassVar[bool]
name: str


@dataclass
class ImportedClass(ClassBase):
imported: typing.ClassVar[bool] = True

module: str


@dataclass
class Class(ClassBase):
imported: typing.ClassVar[bool] = False

bases: List[str] = field(default_factory=list)
derived: Set[str] = field(default_factory=set)
properties: List[Property] = field(default_factory=list)
Expand Down Expand Up @@ -133,7 +147,7 @@ def group(self) -> str:

@dataclass
class Schema:
classes: Dict[str, Class] = field(default_factory=dict)
classes: Dict[str, ClassBase] = field(default_factory=dict)
includes: List[str] = field(default_factory=list)
null: Optional[str] = None

Expand All @@ -155,7 +169,7 @@ def iter_properties(self, cls: str) -> Iterable[Property]:

predicate_marker = object()

TypeRef = Union[type, str]
TypeRef = type | str | ImportedClass


def get_type_name(arg: TypeRef) -> str:
Expand All @@ -164,6 +178,8 @@ def get_type_name(arg: TypeRef) -> str:
return arg.__name__
case str():
return arg
case ImportedClass():
return arg.name
case _:
raise Error(f"Not a schema type or string ({arg})")

Expand All @@ -172,9 +188,9 @@ def _make_property(arg: object) -> Property:
match arg:
case _ if arg is predicate_marker:
return PredicateProperty()
case str() | type():
case (str() | type() | ImportedClass()) as arg:
return SingleProperty(type=get_type_name(arg))
case Property():
case Property() as arg:
return arg
case _:
raise Error(f"Illegal property specifier {arg}")
Expand Down
7 changes: 4 additions & 3 deletions misc/codegen/lib/schemadefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import inspect as _inspect
from dataclasses import dataclass as _dataclass

from misc.codegen.lib.schema import Property

_set = set


Expand Down Expand Up @@ -69,6 +67,9 @@ def include(source: str):
_inspect.currentframe().f_back.f_locals.setdefault("includes", []).append(source)


imported = _schema.ImportedClass


@_dataclass
class _Namespace:
""" simple namespacing mechanism """
Expand Down Expand Up @@ -264,7 +265,7 @@ class _PropertyModifierList(_schema.PropertyModifier):
def __or__(self, other: _schema.PropertyModifier):
return _PropertyModifierList(self._mods + (other,))

def modify(self, prop: Property):
def modify(self, prop: _schema.Property):
for m in self._mods:
m.modify(prop)

Expand Down
6 changes: 5 additions & 1 deletion misc/codegen/loaders/schemaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
def load(m: types.ModuleType) -> schema.Schema:
includes = set()
classes = {}
imported_classes = {}
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import misc.codegen.lib.schemadefs as defs
Expand All @@ -146,6 +147,9 @@ def load(m: types.ModuleType) -> schema.Schema:
continue
if isinstance(data, types.ModuleType):
continue
if isinstance(data, schema.ImportedClass):
imported_classes[name] = data
continue
cls = _get_class(data)
if classes and not cls.bases:
raise schema.Error(
Expand All @@ -162,7 +166,7 @@ def load(m: types.ModuleType) -> schema.Schema:
_fill_hideable_information(classes)
_check_test_with(classes)

return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)


def load_file(path: pathlib.Path) -> schema.Schema:
Expand Down
2 changes: 1 addition & 1 deletion misc/codegen/templates/ql_class.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ module Generated {
*/
{{type}} {{getter}}({{#is_indexed}}int index{{/is_indexed}}) {
{{^synth}}
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_class}}Synth::convert{{type}}FromRaw({{/type_is_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_class}}){{/type_is_class}}
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_codegen_class}}Synth::convert{{type}}FromRaw({{/type_is_codegen_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_codegen_class}}){{/type_is_codegen_class}}
{{/synth}}
{{#synth}}
none()
Expand Down
4 changes: 4 additions & 0 deletions misc/codegen/templates/ql_db.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
* This module holds thin fully generated class definitions around DB entities.
*/
module Raw {
{{#imports}}
private import {{.}}
{{/imports}}

{{#classes}}
/**
* INTERNAL: Do not use.
Expand Down
15 changes: 0 additions & 15 deletions misc/codegen/test/test_ql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,6 @@ def test_property_has_first_table_param_marked():
assert [p.param for p in prop.tableparams] == tableparams


@pytest.mark.parametrize("type,expected", [
("Foo", True),
("Bar", True),
("foo", False),
("bar", False),
(None, False),
])
def test_property_is_a_class(type, expected):
tableparams = ["a", "result", "b"]
expected_tableparams = ["a", "result" if expected else "result", "b"]
prop = ql.Property("Prop", type, tableparams=tableparams)
assert prop.type_is_class is expected
assert [p.param for p in prop.tableparams] == expected_tableparams


indefinite_getters = [
("Argument", "getAnArgument"),
("Element", "getAnElement"),
Expand Down
4 changes: 3 additions & 1 deletion misc/codegen/test/test_qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def test_single_class_property(generate_classes, is_child, prev_child):
ql.Property(singular="Foo", type="Bar", tablename="my_objects",
tableparams=[
"this", "result"],
prev_child=prev_child, doc="foo of this my object"),
prev_child=prev_child, doc="foo of this my object",
type_is_codegen_class=True),
],
)),
"Bar.qll": (a_ql_class_public(name="Bar"), a_ql_stub(name="Bar"), a_ql_class(name="Bar", final=True, imports=[stub_import_prefix + "Bar"])),
Expand Down Expand Up @@ -1006,6 +1007,7 @@ def test_hideable_property(generate_classes):
final=True, properties=[
ql.Property(singular="X", type="MyObject", tablename="others",
type_is_hideable=True,
type_is_codegen_class=True,
tableparams=["this", "result"], doc="x of this other"),
])),
}
Expand Down
2 changes: 1 addition & 1 deletion rust/extractor/src/generated/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b57a374

Please sign in to comment.