-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .auto import AutoGraphStorage | ||
|
||
|
||
__all__ = ["AutoGraphStorage"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from typing import Dict, List, Optional, Sequence, Tuple, Type | ||
from .neo4j import Neo4j | ||
from .config import settings | ||
from .schema import GraphStorage, T | ||
|
||
|
||
class AutoGraphStorage(GraphStorage[T]): | ||
def __init__(self, name: str) -> None: | ||
self._graph_storage = _get_graph_storage()(name) | ||
|
||
def insert_node(self, keys: Sequence[str], nodes: Sequence[T]) -> None: | ||
return self._graph_storage.insert_node(keys, nodes) | ||
|
||
def insert_edge(self, head_keys: Sequence[str], tail_keys: Sequence[str], edges: Sequence[T]) -> None: | ||
return self._graph_storage.insert_edge(head_keys, tail_keys, edges) | ||
|
||
def delete_node(self, key: str) -> None: | ||
return self._graph_storage.delete_node(key) | ||
|
||
def delete_edge(self, head_key: str, tail_key: str) -> None: | ||
return self._graph_storage.delete_edge(head_key, tail_key) | ||
|
||
def query_node(self, key: str) -> Optional[T]: | ||
return self._graph_storage.query_node(key) | ||
|
||
def query_edge(self, head_key: str, tail_key: str) -> Optional[T]: | ||
return self._graph_storage.query_edge(head_key, tail_key) | ||
|
||
def query_node_edges(self, key: str) -> Optional[List[T]]: | ||
return self._graph_storage.query_node_edges(key) | ||
|
||
def exists(self) -> bool: | ||
return self._graph_storage.exists() | ||
|
||
def destroy(self) -> None: | ||
return self._graph_storage.destroy() | ||
|
||
|
||
_graph_storages: Dict[str, Type["GraphStorage"]] = {} | ||
|
||
|
||
def _add_graph_storage(name: str, storage: Type["GraphStorage"]) -> None: | ||
_graph_storages[name] = storage | ||
|
||
|
||
def _list_storages() -> List[str]: | ||
return list(map(str, _graph_storages.keys())) | ||
|
||
|
||
def _get_graph_storage() -> Type["GraphStorage"]: | ||
if settings.graph_storage not in _graph_storages: | ||
raise ValueError("Graph Storage not found, should be one of {}.".format(_list_storages())) | ||
return _graph_storages[settings.graph_storage] | ||
|
||
|
||
_add_graph_storage("neo4j", Neo4j) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
|
||
@dataclass | ||
class Config: | ||
graph_storage: str | ||
neo4j_uri: Optional[str] | ||
|
||
|
||
settings = Config( | ||
graph_storage=os.environ.get("GRAPH_STORAGE"), | ||
neo4j_uri=os.environ.get("NEO4J_URI") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import json | ||
from typing import Generic, Optional, Sequence, TypeVar | ||
from pydantic import BaseModel | ||
from neo4j import GraphDatabase | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
class Neo4j(Generic[T]): | ||
def __init__(self, name: str) -> None: | ||
self.name = name | ||
self.driver = GraphDatabase.driver("bolt://localhost:7687") | ||
|
||
def insert_node(self, key: Sequence[str], node: Sequence[T]) -> None: | ||
with self.driver.session() as session: | ||
for k, n in zip(key, node): | ||
session.run( | ||
""" | ||
MERGE (n:Node {key: $key}) | ||
SET n += $properties | ||
""", | ||
key=k, | ||
properties=n.model_dump(), | ||
) | ||
|
||
def insert_edge(self, head_key: Sequence[str], tail_key: Sequence[str], edge: Sequence[T]) -> None: | ||
with self.driver.session() as session: | ||
for h, t, e in zip(head_key, tail_key, edge): | ||
session.run( | ||
""" | ||
MATCH (h:Node {key: $head_key}) | ||
MATCH (t:Node {key: $tail_key}) | ||
MERGE (h)-[r:EDGE {properties: $properties}]->(t) | ||
""", | ||
head_key=h, | ||
tail_key=t, | ||
properties=e, | ||
) | ||
|
||
def query_node(self, key: str) -> Optional[T]: | ||
with self.driver.session() as session: | ||
result = session.run( | ||
""" | ||
MATCH (n:Node {key: $key}) | ||
RETURN n | ||
""", | ||
key=key, | ||
) | ||
record = result.single() | ||
if record: | ||
return record["n"] | ||
return None | ||
|
||
def query_edge(self, head_key: str, tail_key: str) -> Optional[T]: | ||
with self.driver.session() as session: | ||
result = session.run( | ||
""" | ||
MATCH (h:Node {key: $head_key})-[r:EDGE]->(t:Node {key: $tail_key}) | ||
RETURN r | ||
""", | ||
head_key=head_key, | ||
tail_key=tail_key, | ||
) | ||
record = result.single() | ||
if record: | ||
return record["r"] | ||
return None | ||
|
||
def query_node_edges(self, key: str) -> Optional[T]: | ||
with self.driver.session() as session: | ||
result = session.run( | ||
""" | ||
MATCH (n:Node {key: $key})-[r:EDGE]->() | ||
RETURN r | ||
""", | ||
key=key, | ||
) | ||
edges = [record["r"] for record in result] | ||
return edges if edges else None | ||
|
||
def exists(self) -> bool: | ||
try: | ||
with self.driver.session() as session: | ||
result = session.run("RETURN 1") | ||
return result.single() is not None | ||
except Exception: | ||
return False | ||
|
||
def destroy(self) -> None: | ||
with self.driver.session() as session: | ||
session.run("MATCH (n) DETACH DELETE n") | ||
self.driver.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Generic, Optional, Sequence, TypeVar | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
class GraphStorage(Generic[T], ABC): | ||
name = None | ||
|
||
@abstractmethod | ||
def __init__(self, name: str) -> None: | ||
r""" | ||
Initializes a graph storage. | ||
Args: | ||
name: the name of the database. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def insert_node(self, key: Sequence[str], node: Sequence[T]) -> None: | ||
r""" | ||
Inserts the node along with the key. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def insert_edge(self, head_key: Sequence[str], tail_key: Sequence[str], edge: Sequence[T]) -> None: | ||
r""" | ||
Inserts the edge along with the key. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def query_node(self, key: str) -> Optional[T]: | ||
r""" | ||
Gets the node associated with the given key. | ||
Args: | ||
key: the key to queried node. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def query_edge(self, head_key: str, tail_key: str) -> Optional[T]: | ||
r""" | ||
Gets the edge associated with the given key. | ||
Args: | ||
head_key: the key to the head of queried edge. | ||
tail_key: the key to the tail of queried edge. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def query_node_edges(self, key: str) -> Optional[T]: | ||
r""" | ||
Gets all edges of the node associated with the given key. | ||
Args: | ||
key: the key to the head of queried node. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def exists(self) -> bool: | ||
r""" | ||
Checks if the graph storage exists. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def destroy(self) -> None: | ||
r""" | ||
Destroys the graph storage. | ||
""" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from cardinal import AutoGraphStorage | ||
from RagPanel.utils.graph_utils import Entity, Relation | ||
|
||
def test_neo4j(): | ||
storage = AutoGraphStorage("test_db") | ||
storage.destroy() | ||
|
||
# 测试节点插入 | ||
entity_keys = ["entity1", "entity2"] | ||
entities = [ | ||
Entity(name="Entity 1", type="Type A", desc="This is entity 1."), | ||
Entity(name="Entity 2", type="Type B", desc="This is entity 2."), | ||
] | ||
storage.insert_node(entity_keys, entities) | ||
|
||
# 查询节点 | ||
entity1 = storage.query_node("entity1") | ||
assert entity1 is not None, "Entity 1 should exist" | ||
# 手动从 Node 对象中提取属性 | ||
assert entity1["name"] == "Entity 1" and entity1["type"] == "Type A" and entity1["desc"] == "This is entity 1." | ||
|
||
entity2 = storage.query_node("entity2") | ||
assert entity2 is not None, "Entity 2 should exist" | ||
# 手动从 Node 对象中提取属性 | ||
assert entity2["name"] == "Entity 2" and entity2["type"] == "Type B" and entity2["desc"] == "This is entity 2." | ||
|
||
# 测试边插入 | ||
relation_keys_head = ["entity1"] | ||
relation_keys_tail = ["entity2"] | ||
relations = [ | ||
Relation(head="entity1", tail="entity2", desc="Connected", strength=1).model_dump_json() | ||
] | ||
storage.insert_edge(relation_keys_head, relation_keys_tail, relations) | ||
|
||
# 查询边 | ||
relation = storage.query_edge("entity1", "entity2") | ||
assert relation is not None, "Relation from entity1 to entity2 should exist" | ||
import json | ||
relation = json.loads(relation["properties"]) | ||
assert relation["desc"] == "Connected" and relation["strength"] == 1 | ||
|
||
# 查询节点的所有出边 | ||
entity1_relations = storage.query_node_edges("entity1") | ||
assert entity1_relations is not None and len(entity1_relations) == 1, "Entity 1 should have one outgoing relation" | ||
|
||
# 测试 exists 功能 | ||
assert storage.exists(), "Graph storage should exist" | ||
|
||
# 测试 destroy 功能 | ||
storage.destroy() | ||
|
||
# 验证清理是否成功 | ||
entity1_after_destroy = storage.query_node("entity1") | ||
assert entity1_after_destroy is None, "Entity 1 should no longer exist after destroy" | ||
|
||
entity2_after_destroy = storage.query_node("entity2") | ||
assert entity2_after_destroy is None, "Entity 2 should no longer exist after destroy" | ||
|
||
print("All tests passed!") |