Skip to content

Commit

Permalink
implement graph storage
Browse files Browse the repository at this point in the history
  • Loading branch information
ILSparkle committed Dec 10, 2024
1 parent b21acc1 commit dea6da3
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:
SEARCH_TARGET: content
REDIS_URI: redis://localhost:6379
ELASTICSEARCH_URI: http://localhost:9200
GRAPH_STORAGE: neo4j
NEO4J_URI: bolt://localhost:7687
VECTORSTORE: chroma
CHROMA_PATH: ./chroma
MILVUS_URI: http://localhost:19530
Expand Down Expand Up @@ -91,6 +93,14 @@ jobs:
python -m pip install pytest
wget https://github.com/milvus-io/milvus/releases/download/v2.4.4/milvus-standalone-docker-compose.yml -O docker-compose.yml
sudo docker compose up -d
docker pull neo4j
docker run \
--name neo4j \
-d \
-p 7474:7474 \
-p 7687:7687 \
-e NEO4J_AUTH=none \
neo4j
- name: Test with pytest
run: |
Expand Down
2 changes: 2 additions & 0 deletions src/cardinal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .retriever import DenseRetriever, HybridRetriever, SparseRetriever
from .splitter import CJKTextSplitter, TextSplitter
from .storage import AutoStorage
from .graph import AutoGraphStorage
from .vectorstore import AutoVectorStore, AutoCondition


Expand All @@ -39,6 +40,7 @@
"CJKTextSplitter",
"TextSplitter",
"AutoStorage",
"AutoGraphStorage",
"AutoVectorStore",
"AutoCondition",
]
Expand Down
4 changes: 4 additions & 0 deletions src/cardinal/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .auto import AutoGraphStorage


# Public API of the graph package: only the auto-dispatching facade is exported.
__all__ = ["AutoGraphStorage"]
56 changes: 56 additions & 0 deletions src/cardinal/graph/auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict, List, Optional, Sequence, Tuple, Type
from .neo4j import Neo4j
from .config import settings
from .schema import GraphStorage, T


class AutoGraphStorage(GraphStorage[T]):
    r"""
    Facade over the graph storage backend returned by `_get_graph_storage`.

    Every method forwards its arguments positionally to the wrapped backend
    and hands back whatever the backend returns.
    """

    def __init__(self, name: str) -> None:
        r"""
        Instantiates the selected backend.

        Args:
            name: the name passed to the backend constructor.
        """
        backend_cls = _get_graph_storage()
        self._graph_storage = backend_cls(name)

    def insert_node(self, keys: Sequence[str], nodes: Sequence[T]) -> None:
        r"""
        Forwards the keys and nodes to the backend's `insert_node`.
        """
        backend = self._graph_storage
        return backend.insert_node(keys, nodes)

    def insert_edge(self, head_keys: Sequence[str], tail_keys: Sequence[str], edges: Sequence[T]) -> None:
        r"""
        Forwards the head keys, tail keys and edges to the backend's `insert_edge`.
        """
        backend = self._graph_storage
        return backend.insert_edge(head_keys, tail_keys, edges)

    def delete_node(self, key: str) -> None:
        r"""
        Forwards the key to the backend's `delete_node`.
        """
        backend = self._graph_storage
        return backend.delete_node(key)

    def delete_edge(self, head_key: str, tail_key: str) -> None:
        r"""
        Forwards the head and tail keys to the backend's `delete_edge`.
        """
        backend = self._graph_storage
        return backend.delete_edge(head_key, tail_key)

    def query_node(self, key: str) -> Optional[T]:
        r"""
        Returns the backend's `query_node` result for the key.
        """
        backend = self._graph_storage
        return backend.query_node(key)

    def query_edge(self, head_key: str, tail_key: str) -> Optional[T]:
        r"""
        Returns the backend's `query_edge` result for the head and tail keys.
        """
        backend = self._graph_storage
        return backend.query_edge(head_key, tail_key)

    def query_node_edges(self, key: str) -> Optional[List[T]]:
        r"""
        Returns the backend's `query_node_edges` result for the key.
        """
        backend = self._graph_storage
        return backend.query_node_edges(key)

    def exists(self) -> bool:
        r"""
        Returns the backend's `exists` result.
        """
        backend = self._graph_storage
        return backend.exists()

    def destroy(self) -> None:
        r"""
        Calls the backend's `destroy`.
        """
        backend = self._graph_storage
        return backend.destroy()


# Registry mapping a backend name to its storage class; filled through `_add_graph_storage`.
_graph_storages: Dict[str, Type["GraphStorage"]] = {}


def _add_graph_storage(name: str, storage: Type["GraphStorage"]) -> None:
    r"""
    Registers a storage class under the given name, replacing any previous entry.

    Args:
        name: the registry key.
        storage: the storage class stored under that key.
    """
    _graph_storages.update({name: storage})


def _list_storages() -> List[str]:
    r"""
    Returns the registered backend names as strings, in registration order.
    """
    return [str(registered_name) for registered_name in _graph_storages]


def _get_graph_storage() -> Type["GraphStorage"]:
    r"""
    Looks up the storage class registered under `settings.graph_storage`.

    Raises:
        ValueError: if nothing is registered under that name.
    """
    selected = settings.graph_storage
    if selected in _graph_storages:
        return _graph_storages[selected]

    raise ValueError("Graph Storage not found, should be one of {}.".format(_list_storages()))


# Register the Neo4j backend under the name "neo4j" at import time.
_add_graph_storage("neo4j", Neo4j)
15 changes: 15 additions & 0 deletions src/cardinal/graph/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    r"""
    Settings of the graph storage package.
    """

    # name of the graph storage backend; None when it is not configured
    graph_storage: Optional[str]
    # connection URI of the Neo4j server; None when it is not configured
    neo4j_uri: Optional[str]


# Module-level settings read from the environment at import time;
# a field is None when its variable is unset.
settings = Config(
    graph_storage=os.getenv("GRAPH_STORAGE"),
    neo4j_uri=os.getenv("NEO4J_URI"),
)
91 changes: 91 additions & 0 deletions src/cardinal/graph/neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import json
from typing import Generic, Optional, Sequence, TypeVar
from pydantic import BaseModel
from neo4j import GraphDatabase

T = TypeVar("T", bound=BaseModel)

class Neo4j(Generic[T]):
    r"""
    Graph storage backed by a Neo4j server.

    Nodes are written with the label `Node` and matched by their `key` property;
    edges are written as `EDGE` relationships between two such nodes.
    """

    def __init__(self, name: str) -> None:
        r"""
        Opens a driver to the Neo4j server.

        Args:
            name: the name of the database. NOTE(review): it is only stored on the
                instance, none of the queries below is scoped by it - confirm intent.
        """
        # imported locally so that this class does not depend on a module-level import
        from .config import settings

        self.name = name
        # honour the configured URI, falling back to the address that used to be hard-coded
        self.driver = GraphDatabase.driver(settings.neo4j_uri or "bolt://localhost:7687")

    def insert_node(self, key: Sequence[str], node: Sequence[T]) -> None:
        r"""
        Creates or updates one node per (key, node) pair.

        Args:
            key: the keys identifying the nodes.
            node: the models whose dumped fields are merged into the node properties.
        """
        with self.driver.session() as session:
            for k, n in zip(key, node):
                session.run(
                    """
                    MERGE (n:Node {key: $key})
                    SET n += $properties
                    """,
                    key=k,
                    properties=n.model_dump(),
                )

    def insert_edge(self, head_key: Sequence[str], tail_key: Sequence[str], edge: Sequence[T]) -> None:
        r"""
        Creates one edge per (head, tail, edge) triple between existing nodes.

        Nothing is written for a triple whose head or tail node does not exist,
        because the query matches both nodes first.

        Args:
            head_key: the keys of the head nodes.
            tail_key: the keys of the tail nodes.
            edge: the edge payloads; models are serialized to JSON text, values that
                are not models (e.g. already serialized strings) are stored as given.
        """
        with self.driver.session() as session:
            for h, t, e in zip(head_key, tail_key, edge):
                # a pydantic model cannot be sent as a query parameter, so store its JSON form
                properties = e.model_dump_json() if isinstance(e, BaseModel) else e
                session.run(
                    """
                    MATCH (h:Node {key: $head_key})
                    MATCH (t:Node {key: $tail_key})
                    MERGE (h)-[r:EDGE {properties: $properties}]->(t)
                    """,
                    head_key=h,
                    tail_key=t,
                    properties=properties,
                )

    def delete_node(self, key: str) -> None:
        r"""
        Deletes the node with the given key together with its edges.

        Args:
            key: the key of the node to delete.
        """
        with self.driver.session() as session:
            session.run(
                """
                MATCH (n:Node {key: $key})
                DETACH DELETE n
                """,
                key=key,
            )

    def delete_edge(self, head_key: str, tail_key: str) -> None:
        r"""
        Deletes the edges pointing from the head node to the tail node.

        Args:
            head_key: the key of the head node.
            tail_key: the key of the tail node.
        """
        with self.driver.session() as session:
            session.run(
                """
                MATCH (h:Node {key: $head_key})-[r:EDGE]->(t:Node {key: $tail_key})
                DELETE r
                """,
                head_key=head_key,
                tail_key=tail_key,
            )

    def query_node(self, key: str) -> Optional[T]:
        r"""
        Gets the node associated with the given key.

        Args:
            key: the key of the queried node.

        Returns:
            The `n` value of the matching record as returned by the driver
            (not converted to a model), or None if there is no match.
        """
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (n:Node {key: $key})
                RETURN n
                """,
                key=key,
            )
            record = result.single()
            if record:
                return record["n"]
            return None

    def query_edge(self, head_key: str, tail_key: str) -> Optional[T]:
        r"""
        Gets the edge pointing from the head node to the tail node.

        Args:
            head_key: the key of the head node.
            tail_key: the key of the tail node.

        Returns:
            The `r` value of the matching record as returned by the driver
            (not converted to a model), or None if there is no match.
        """
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (h:Node {key: $head_key})-[r:EDGE]->(t:Node {key: $tail_key})
                RETURN r
                """,
                head_key=head_key,
                tail_key=tail_key,
            )
            record = result.single()
            if record:
                return record["r"]
            return None

    def query_node_edges(self, key: str) -> Optional[Sequence[T]]:
        r"""
        Gets all outgoing edges of the node associated with the given key.

        Args:
            key: the key of the head node.

        Returns:
            A list with the `r` value of every matching record, or None if there is none.
        """
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (n:Node {key: $key})-[r:EDGE]->()
                RETURN r
                """,
                key=key,
            )
            edges = [record["r"] for record in result]
            return edges if edges else None

    def exists(self) -> bool:
        r"""
        Checks whether the server answers a trivial query.

        Any exception raised while querying is reported as False (best effort).
        """
        try:
            with self.driver.session() as session:
                result = session.run("RETURN 1")
                return result.single() is not None
        except Exception:
            return False

    def destroy(self) -> None:
        r"""
        Deletes every node and edge on the server, then closes the driver.

        NOTE(review): the delete query is not restricted to the `Node` label or to
        `self.name`, and the instance keeps a closed driver afterwards - confirm
        that callers do not reuse the storage after calling this.
        """
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
        self.driver.close()
80 changes: 80 additions & 0 deletions src/cardinal/graph/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from typing import Generic, Optional, Sequence, TypeVar

from pydantic import BaseModel


T = TypeVar("T", bound=BaseModel)


class GraphStorage(Generic[T], ABC):
    r"""
    Abstract interface of a graph storage holding keyed nodes and directed edges.
    """

    name = None

    @abstractmethod
    def __init__(self, name: str) -> None:
        r"""
        Initializes a graph storage.

        Args:
            name: the name of the database.
        """
        ...

    @abstractmethod
    def insert_node(self, key: Sequence[str], node: Sequence[T]) -> None:
        r"""
        Inserts the nodes along with their keys.

        Args:
            key: the keys of the nodes.
            node: the nodes to insert, aligned with the keys.
        """
        ...

    @abstractmethod
    def insert_edge(self, head_key: Sequence[str], tail_key: Sequence[str], edge: Sequence[T]) -> None:
        r"""
        Inserts the edges along with the keys of their end points.

        Args:
            head_key: the keys of the head nodes.
            tail_key: the keys of the tail nodes.
            edge: the edges to insert, aligned with the keys.
        """
        ...

    def delete_node(self, key: str) -> None:
        r"""
        Deletes the node associated with the given key.

        Not abstract, so that subclasses written before this method existed can
        still be instantiated; backends are expected to override it.

        Args:
            key: the key of the node to delete.

        Raises:
            NotImplementedError: if the backend does not override this method.
        """
        raise NotImplementedError

    def delete_edge(self, head_key: str, tail_key: str) -> None:
        r"""
        Deletes the edge associated with the given keys.

        Not abstract, so that subclasses written before this method existed can
        still be instantiated; backends are expected to override it.

        Args:
            head_key: the key of the head of the edge to delete.
            tail_key: the key of the tail of the edge to delete.

        Raises:
            NotImplementedError: if the backend does not override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def query_node(self, key: str) -> Optional[T]:
        r"""
        Gets the node associated with the given key.

        Args:
            key: the key of the queried node.
        """
        ...

    @abstractmethod
    def query_edge(self, head_key: str, tail_key: str) -> Optional[T]:
        r"""
        Gets the edge associated with the given keys.

        Args:
            head_key: the key of the head of the queried edge.
            tail_key: the key of the tail of the queried edge.
        """
        ...

    @abstractmethod
    def query_node_edges(self, key: str) -> Optional[Sequence[T]]:
        r"""
        Gets all edges of the node associated with the given key.

        Args:
            key: the key of the queried node.
        """
        ...

    @abstractmethod
    def exists(self) -> bool:
        r"""
        Checks if the graph storage exists.
        """
        ...

    @abstractmethod
    def destroy(self) -> None:
        r"""
        Destroys the graph storage.
        """
        ...
59 changes: 59 additions & 0 deletions tests/graph/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from cardinal import AutoGraphStorage
from RagPanel.utils.graph_utils import Entity, Relation

def test_neo4j():
    """Round-trip nodes and edges through the configured graph storage."""
    import json

    graph = AutoGraphStorage("test_db")
    # start from an empty graph
    graph.destroy()

    # insert two nodes
    graph.insert_node(
        ["entity1", "entity2"],
        [
            Entity(name="Entity 1", type="Type A", desc="This is entity 1."),
            Entity(name="Entity 2", type="Type B", desc="This is entity 2."),
        ],
    )

    # query the first node and read its properties straight from the returned object
    first = graph.query_node("entity1")
    assert first is not None, "Entity 1 should exist"
    assert first["name"] == "Entity 1"
    assert first["type"] == "Type A"
    assert first["desc"] == "This is entity 1."

    # query the second node and read its properties straight from the returned object
    second = graph.query_node("entity2")
    assert second is not None, "Entity 2 should exist"
    assert second["name"] == "Entity 2"
    assert second["type"] == "Type B"
    assert second["desc"] == "This is entity 2."

    # insert one edge, stored as JSON text
    serialized_relation = Relation(head="entity1", tail="entity2", desc="Connected", strength=1).model_dump_json()
    graph.insert_edge(["entity1"], ["entity2"], [serialized_relation])

    # query the edge and decode its payload
    stored_edge = graph.query_edge("entity1", "entity2")
    assert stored_edge is not None, "Relation from entity1 to entity2 should exist"
    payload = json.loads(stored_edge["properties"])
    assert payload["desc"] == "Connected"
    assert payload["strength"] == 1

    # query all outgoing edges of the first node
    outgoing = graph.query_node_edges("entity1")
    assert outgoing is not None and len(outgoing) == 1, "Entity 1 should have one outgoing relation"

    # check exists
    assert graph.exists(), "Graph storage should exist"

    # check destroy
    graph.destroy()

    # verify that the cleanup removed both nodes
    first_after_destroy = graph.query_node("entity1")
    assert first_after_destroy is None, "Entity 1 should no longer exist after destroy"

    second_after_destroy = graph.query_node("entity2")
    assert second_after_destroy is None, "Entity 2 should no longer exist after destroy"

    print("All tests passed!")

0 comments on commit dea6da3

Please sign in to comment.