Skip to content

Commit

Permalink
Ensure that any field query uses the table name
Browse files Browse the repository at this point in the history
In order to include the table name for fields in this query, use the
`field_query` method.

Since `AnyFieldQuery` is just an `OrQuery` under the hood, remove it and
construct `OrQuery` explicitly instead.
  • Loading branch information
snejus committed May 10, 2024
1 parent 9ceffb6 commit 070c87f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 102 deletions.
43 changes: 0 additions & 43 deletions beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,49 +514,6 @@ def __hash__(self) -> int:
return reduce(mul, map(hash, self.subqueries), 1)


class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""

def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls

subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)

@property
def field_names(self) -> Set[str]:
return set(self.fields)

def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")

def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(obj):
return True
return False

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
f"{self.query_class.__name__})"
)

def __eq__(self, other) -> bool:
return super().__eq__(other) and self.query_class == other.query_class

def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))


class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
Expand Down
16 changes: 5 additions & 11 deletions beets/dbcore/queryparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,13 @@ def construct_query_part(
query_part, query_classes, prefixes
)

# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)

# Field queries get constructed according to the name of the field
# they are querying.
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(query_class, pattern)
else:
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)

# Apply negation.
Expand Down
4 changes: 2 additions & 2 deletions beets/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def find_duplicates(self, lib):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
dup_query = library.Album.match_all_query(
{key: tmp_album.get(key) for key in keys}
)

Expand Down Expand Up @@ -1019,7 +1019,7 @@ def find_duplicates(self, lib):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Item.all_fields_query(
dup_query = library.Item.match_all_query(
{key: tmp_item.get(key) for key in keys}
)

Expand Down
13 changes: 12 additions & 1 deletion beets/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,18 @@ def field_query(
return query_cls(field, pattern, fast)

@classmethod
def all_fields_query(
def any_field_query(
cls, query_class: Type[dbcore.FieldQuery], pattern: str
) -> dbcore.OrQuery:
return dbcore.OrQuery(
[
cls.field_query(f, pattern, query_class)
for f in cls._search_fields
]
)

@classmethod
def match_all_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.
Expand Down
2 changes: 1 addition & 1 deletion test/test_dbcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def test_two_parts(self):
q = self.qfs(["foo", "bar:baz"])
self.assertIsInstance(q, dbcore.query.AndQuery)
self.assertEqual(len(q.subqueries), 2)
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)

def test_parse_fixed_type_query(self):
Expand Down
53 changes: 9 additions & 44 deletions test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,40 +48,6 @@ def assertNotInResult(self, item, results): # noqa
self.assertNotIn(item.id, result_ids)


class AnyFieldQueryTest(_common.LibTestCase):
def test_no_restriction(self):
q = dbcore.query.AnyFieldQuery(
"title",
beets.library.Item._fields.keys(),
dbcore.query.SubstringQuery,
)
self.assertEqual(self.lib.items(q).get().title, "the title")

def test_restriction_completeness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["title"], dbcore.query.SubstringQuery
)
self.assertEqual(self.lib.items(q).get().title, "the title")

def test_restriction_soundness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["artist"], dbcore.query.SubstringQuery
)
self.assertIsNone(self.lib.items(q).get())

def test_eq(self):
q1 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
q2 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
self.assertEqual(q1, q2)

q2.query_class = None
self.assertNotEqual(q1, q2)


class AssertsMixin:
def assert_items_matched(self, results, titles):
self.assertEqual({i.title for i in results}, set(titles))
Expand Down Expand Up @@ -981,14 +947,6 @@ def test_type_and(self):
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
self.assertNegationProperties(q)

def test_type_anyfield(self):
q = dbcore.query.AnyFieldQuery(
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
)
not_results = self.lib.items(dbcore.query.NotQuery(q))
self.assert_items_matched(not_results, ["baz qux"])
self.assertNegationProperties(q)

def test_type_boolean(self):
q = dbcore.query.BooleanQuery("comp", True)
not_results = self.lib.items(dbcore.query.NotQuery(q))
Expand Down Expand Up @@ -1137,11 +1095,18 @@ def test_get_items_filter_by_album_field(self):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])

def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
def test_filter_albums_by_common_field(self):
# title:Album1 ensures that the items table is joined for the query
q = "title:Album1 catalognum:ABC"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])

def test_filter_items_by_common_field(self):
# artpath:A ensures that the albums table is joined for the query
q = "artpath:A Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])

def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)
Expand Down

0 comments on commit 070c87f

Please sign in to comment.