Skip to content

Commit

Permalink
Merge pull request #8 from cloudspannerecosystem/re-architecture
Browse files Browse the repository at this point in the history
New architecture to fetch graph query results async.
  • Loading branch information
cqian23 authored Dec 20, 2024
2 parents a60d220 + f5486db commit c055d61
Show file tree
Hide file tree
Showing 16 changed files with 836 additions and 339 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def package_files(directory):

setup(
name="spanner-graph-notebook",
version="v1.0.5",
version="v1.0.6",
packages=find_packages(),
install_requires=[
"networkx", "numpy", "google-cloud-spanner", "ipython",
"ipywidgets", "notebook"
"ipywidgets", "notebook", "requests", "portpicker"
],
include_package_data=True,
description='Visually query Spanner Graph data in notebooks.',
Expand Down
29 changes: 21 additions & 8 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ def __init__(self, project_id: str, instance_id: str,
self.instance = self.client.instance(instance_id)
self.database = self.instance.database(database_id)

# In order to ensure that the database connection was properly
# created and that customers won't be confused by connectivity
# errors happening mysteriously, let's firstly run a dummy
# query on the database.
sql = 'SELECT table_name FROM information_schema.tables WHERE table_schema = ""'
_ = self.execute_query(sql, None, is_test_query=True)

def __repr__(self) -> str:
return (f"<SpannerDatabase["
f"project:{self.client.project_name},"
Expand Down Expand Up @@ -174,7 +167,7 @@ def __init__(self):
def execute_query(
self,
_: str,
limit: int = None
limit: int = 5
) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]:
"""Mock execution of query"""

Expand All @@ -197,3 +190,23 @@ def execute_query(
data[field.name].append(value)

return data, fields, rows, self.schema_json


# Module-level cache of database handles. get_database_instance() looks up and
# stores entries here under the key f"{project}_{instance}_{database}"; mock
# databases are returned directly by that function and are never inserted.
database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = {
# Example entry shape -- "project_instance_database": SpannerDatabase
}


def get_database_instance(project: str, instance: str, database: str, mock = False):
    """Return a database handle for the given project/instance/database.

    When ``mock`` is truthy a brand-new ``MockSpannerDatabase`` is returned on
    every call and nothing is cached. Otherwise the handle is looked up in the
    module-level ``database_instances`` dict and, if absent, a
    ``SpannerDatabase`` is constructed, stored under the composed key, and
    returned.
    """
    if mock:
        return MockSpannerDatabase()

    cache_key = f"{project}_{instance}_{database}"

    cached = database_instances.get(cache_key, None)
    if cached:
        return cached

    # Cache miss: build the handle once and remember it for later calls.
    created = SpannerDatabase(project, instance, database)
    database_instances[cache_key] = created
    return created
160 changes: 160 additions & 0 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import http.server
import socketserver
import json
import threading
import requests
import portpicker
from networkx.classes import DiGraph

from spanner_graphs.conversion import prepare_data_for_graphing, columns_to_native_numpy
from spanner_graphs.database import get_database_instance


def execute_query(project: str, instance: str, database: str, query: str, mock = False):
    """Run ``query`` and package the result as a JSON-friendly dict.

    Args:
        project: Project identifier passed to ``get_database_instance``.
        instance: Instance identifier passed to ``get_database_instance``.
        database: Database identifier passed to ``get_database_instance``.
        query: Query text handed to the database handle's ``execute_query``.
        mock: When truthy, ``get_database_instance`` returns a mock database.

    Returns:
        On success, ``{"response": {"nodes", "edges", "schema", "rows"}}``
        where ``nodes`` and ``edges`` are the attribute dicts of the graph
        built by ``prepare_data_for_graphing``. On any failure,
        ``{"error": <message>}``; this function does not raise.
    """
    try:
        # Obtaining the handle happens inside the try so that failures while
        # creating the client/database are reported to the caller as an
        # "error" payload, exactly like query failures, instead of escaping.
        db = get_database_instance(project, instance, database, mock)

        query_result, fields, rows, schema_json = db.execute_query(query)
        # The second element (ignored columns) is not used here.
        columns, _ignored_columns = columns_to_native_numpy(query_result, fields)

        graph: DiGraph = prepare_data_for_graphing(
            incoming=columns,
            schema_json=schema_json)

        # Only the attribute dicts are returned; the graph ids are dropped.
        nodes = [attrs for _node_id, attrs in graph.nodes(data=True)]
        edges = [attrs for _from_id, _to_id, attrs in graph.edges(data=True)]

        return {
            "response": {
                "nodes": nodes,
                "edges": edges,
                "schema": schema_json,
                "rows": rows
            }
        }
    except Exception as e:
        # Deliberate catch-all at the request boundary: the caller always
        # receives a dict it can serialize back to the client.
        return {
            "error": getattr(e, "message", str(e))
        }


class GraphServer:
    """Namespace for the local HTTP server that serves graph query results.

    All members are class-level; the class is used through its static
    methods and is not instantiated in this file.
    """

    # The port is picked once, when this class body is executed.
    port = portpicker.pick_unused_port()
    host = 'http://localhost'
    url = f"{host}:{port}"

    # Request paths understood by GraphServerHandler.
    endpoints = {
        "get_ping": "/get_ping",
        "post_ping": "/post_ping",
        "post_query": "/post_query",
    }

    # Seconds the ping helpers wait for the server before requests raises a
    # timeout error. Without a timeout, requests can block indefinitely.
    ping_timeout_seconds = 30

    @staticmethod
    def build_route(endpoint):
        """Return the absolute URL for ``endpoint`` (a path such as "/get_ping")."""
        return f"{GraphServer.url}{endpoint}"

    @staticmethod
    def start_server():
        """Bind the TCP server on ``GraphServer.port`` and serve until stopped.

        This call blocks in ``serve_forever``; ``init`` runs it on a thread.
        """
        # NOTE(review): the empty host string binds every interface, not only
        # localhost -- confirm this exposure is intended.
        with socketserver.TCPServer(("", GraphServer.port), GraphServerHandler) as httpd:
            print("Spanner Graph notebook loaded")
            httpd.serve_forever()

    @staticmethod
    def init():
        """Start the server on a background thread and return that thread."""
        # NOTE(review): the thread is created without daemon=True, so it is a
        # daemon only if the calling thread is one -- confirm whether the
        # server should keep the process alive.
        server_thread = threading.Thread(target=GraphServer.start_server)
        server_thread.start()
        return server_thread

    @staticmethod
    def _json_or_false(response):
        """Return the decoded JSON body for a 200 response, otherwise False."""
        if response.status_code == 200:
            return response.json()

        print(f"Request failed with status code {response.status_code}")
        return False

    @staticmethod
    def get_ping():
        """GET the ping endpoint; return its JSON body, or False on a non-200 status."""
        route = GraphServer.build_route(GraphServer.endpoints["get_ping"])
        response = requests.get(route, timeout=GraphServer.ping_timeout_seconds)
        return GraphServer._json_or_false(response)

    @staticmethod
    def post_ping(data):
        """POST ``data`` as JSON to the ping endpoint.

        Returns the JSON body of the reply, or False on a non-200 status.
        """
        route = GraphServer.build_route(GraphServer.endpoints["post_ping"])
        response = requests.post(
            route, json=data, timeout=GraphServer.ping_timeout_seconds)
        return GraphServer._json_or_false(response)

class GraphServerHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP request handler for the graph server's JSON endpoints.

    GET requests for the ping endpoint are answered here; every other GET is
    delegated to ``SimpleHTTPRequestHandler``. POST requests are routed to the
    ping and query endpoints, and unknown POST paths receive a JSON 404.
    """

    # Fields that must be present in the JSON body of a query request.
    _required_query_fields = ("project", "instance", "database", "query")

    def log_message(self, format, *args):
        # Overridden to do nothing, which silences the base class's
        # per-request logging.
        pass

    def do_json_response(self, data, status=200):
        """Send ``data`` as a JSON body with the given HTTP ``status``.

        The payload is serialized before any header is written, so a
        serialization failure raises before a partial response is sent.
        """
        payload = json.dumps(data).encode()
        self.send_response(status)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header('Content-type', 'application/json')
        self.send_header("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
        self.end_headers()
        self.wfile.write(payload)

    def do_message_response(self, message):
        """Send ``{"message": message}`` as a JSON response."""
        self.do_json_response({'message': message})

    def do_data_response(self, data):
        """Send ``data`` unchanged as a JSON response."""
        self.do_json_response(data)

    def parse_post_data(self):
        """Read ``Content-Length`` bytes from the request and decode them as JSON.

        Raises:
            TypeError: if the Content-Length header is absent.
            ValueError: if the header is not an integer, or the body is not
                valid UTF-8 encoded JSON.
        """
        content_length = int(self.headers['Content-Length'])
        post_data = self.rfile.read(content_length).decode('utf-8')
        return json.loads(post_data)

    def _read_json_body(self):
        """Return ``(True, body)``; on a bad body send a 400 and return ``(False, None)``."""
        try:
            return True, self.parse_post_data()
        except (TypeError, ValueError) as e:
            self.do_json_response(
                {"error": f"Invalid request body: {e}"}, status=400)
            return False, None

    def handle_get_ping(self):
        """Answer the GET ping endpoint with a "pong" message."""
        self.do_message_response('pong')

    def handle_post_ping(self):
        """Echo the posted JSON body back under the "your_request" key."""
        ok, data = self._read_json_body()
        if not ok:
            return
        self.do_data_response({'your_request': data})

    def handle_post_query(self):
        """Run the posted query through ``execute_query`` and send its result.

        The body must be a JSON object with "project", "instance", "database"
        and "query"; "mock" is optional and defaults to False. A malformed
        body is answered with a JSON 400 instead of an unhandled exception.
        """
        ok, data = self._read_json_body()
        if not ok:
            return

        if not isinstance(data, dict):
            self.do_json_response(
                {"error": "Request body must be a JSON object"}, status=400)
            return

        missing = [name for name in self._required_query_fields if name not in data]
        if missing:
            self.do_json_response(
                {"error": f"Missing required fields: {', '.join(missing)}"},
                status=400)
            return

        response = execute_query(
            project=data["project"],
            instance=data["instance"],
            database=data["database"],
            query=data["query"],
            mock=data.get("mock", False)
        )
        self.do_data_response(response)

    def do_GET(self):
        if self.path == GraphServer.endpoints["get_ping"]:
            self.handle_get_ping()
        else:
            # NOTE(review): the base class serves files from the process's
            # working directory -- confirm this fallback is intended.
            super().do_GET()

    def do_POST(self):
        if self.path == GraphServer.endpoints["post_ping"]:
            self.handle_post_ping()
        elif self.path == GraphServer.endpoints["post_query"]:
            self.handle_post_query()
        else:
            # Always answer, so the client is not left without a response.
            self.do_json_response(
                {"error": f"Unknown endpoint: {self.path}"}, status=404)
Loading

0 comments on commit c055d61

Please sign in to comment.