Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Huizi Yu committed Jan 1, 2025
0 parents commit 446510f
Show file tree
Hide file tree
Showing 75 changed files with 5,920 additions and 0 deletions.
35 changes: 35 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# General
.DS_Store

# Python
__pycache__/
*.py[cod]
*.egg
*.egg-info/
dist/
build/
*.so

# Virtual Environment
venv/
env/
.venv/
.env/

# Jupyter Notebook
.ipynb_checkpoints/

# Logs
*.log
*.out
*.err

# IDEs/Editors
.idea/
.vscode/
*.swp

# OS Generated
.Trash-*
Thumbs.db

249 changes: 249 additions & 0 deletions AIPatient_Analysis/code/Neo4jDatabase/Neo4jDatabase_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from graph_construction.graph_construction_function.entity_creation import *
from neo4j import GraphDatabase
import pandas as pd

class Neo4jDatabase:
    """Wrapper around a Neo4j driver for loading and querying patient data.

    The instance owns a driver created in ``__init__`` and exposes:

    * ``load_*`` methods that walk a pandas DataFrame row by row and hand
      each row to the matching ``create_*`` transaction function,
    * read helpers (``execute_cypher_query``, ``fetch_all_*``,
      ``get_random_patient_admission``),
    * ``generate_schema`` which renders ``self.nodes`` and
      ``self.relationships`` as a text description of the graph schema.
    """

    def __init__(self, uri, user, password):
        """Initialize the Neo4j driver and the schema description.

        Args:
            uri: Connection URI passed to ``GraphDatabase.driver``.
            user: User name used for authentication.
            password: Password used for authentication.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

        # Node labels and their property names, rendered by generate_schema().
        self.nodes = [
            {'labels': 'Patient', 'properties': ['SUBJECT_ID', 'GENDER', 'AGE', 'ETHNICITY', 'RELIGION', 'MARITAL_STATUS']},
            {'labels': 'Admission', 'properties': ['HADM_ID', 'DURATION', 'ADMISSION_TYPE', 'ADMISSION_LOCATION', 'DISCHARGE_LOCATION', 'INSURANCE']},
            {'labels': 'Symptom', 'properties': ['name']},
            {'labels': 'Duration', 'properties': ['name']},
            {'labels': 'Intensity', 'properties': ['name']},
            {'labels': 'Frequency', 'properties': ['name']},
            {'labels': 'History', 'properties': ['name']},
            {'labels': 'Vital', 'properties': ['LABEL', 'VALUE']},
            {'labels': 'Allergy', 'properties': ['name']},
            {'labels': 'SocialHistory', 'properties': ['description']},
            {'labels': 'FamilyMember', 'properties': ['name']},
            {'labels': 'FamilyMedicalHistory', 'properties': ['name']},
        ]

        # Relationship types with their source label and target label(s).
        self.relationships = [
            {'relationship': 'HAS_ADMISSION', 'source': 'Patient', 'target': ['Admission']},
            {'relationship': 'HAS_MEDICAL_HISTORY', 'source': 'Patient', 'target': ['History']},
            {'relationship': 'HAS_FAMILY_MEMBER', 'source': 'Patient', 'target': ['FamilyMember']},
            {'relationship': 'HAS_SYMPTOM', 'source': 'Admission', 'target': ['Symptom']},
            {'relationship': 'HAS_SOCIAL_HISTORY', 'source': 'Admission', 'target': ['SocialHistory']},
            {'relationship': 'HAS_VITAL', 'source': 'Admission', 'target': ['Vital']},
            {'relationship': 'HAS_ALLERGY', 'source': 'Admission', 'target': ['Allergy']},
            {'relationship': 'HAS_NOSYMPTOM', 'source': 'Admission', 'target': ['Symptom']},
            {'relationship': 'HAS_DURATION', 'source': 'Symptom', 'target': ['Duration']},
            {'relationship': 'HAS_INTENSITY', 'source': 'Symptom', 'target': ['Intensity']},
            {'relationship': 'HAS_FREQUENCY', 'source': 'Symptom', 'target': ['Frequency']},
            {'relationship': 'HAS_MEDICAL_HISTORY', 'source': 'FamilyMember', 'target': ['FamilyMedicalHistory']},
        ]

    # ------------------------------------------------------------------
    # Small value helpers shared by the load_* methods
    # ------------------------------------------------------------------

    @staticmethod
    def _id_as_str(value):
        """Return ``str(value)``, or ``''`` when pandas treats it as missing."""
        return str(value) if pd.notna(value) else ''

    @staticmethod
    def _blank_if_missing(value):
        """Return ``value`` unchanged, or ``''`` when pandas treats it as missing."""
        return '' if pd.isna(value) else value

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def close(self):
        """Close the Neo4j driver connection."""
        self.driver.close()

    def clear_database(self):
        """Delete every node (and its relationships) in the database."""
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_patients(self, df_patients):
        """Load patients into Neo4j, one write transaction per row.

        Args:
            df_patients: DataFrame with columns SUBJECT_ID, GENDER, AGE,
                ETHNICITY, RELIGION and MARITAL_STATUS.
        """
        with self.driver.session() as session:
            for _, row in df_patients.iterrows():
                session.execute_write(
                    create_patient,
                    self._id_as_str(row['SUBJECT_ID']),
                    row['GENDER'],
                    row['AGE'],
                    row['ETHNICITY'],
                    row['RELIGION'],
                    row['MARITAL_STATUS'],
                )

    def load_admission(self, df_admission):
        """Load admissions into Neo4j, one write transaction per row.

        Args:
            df_admission: DataFrame with columns SUBJECT_ID, HADM_ID,
                DURATION, ADMISSION_TYPE, ADMISSION_LOCATION,
                DISCHARGE_LOCATION and INSURANCE.
        """
        with self.driver.session() as session:
            for _, row in df_admission.iterrows():
                subject_id = self._id_as_str(row['SUBJECT_ID'])
                hadm_id = self._id_as_str(row['HADM_ID'])
                # create_admission takes the admission id before the subject id.
                session.execute_write(
                    create_admission,
                    hadm_id,
                    subject_id,
                    row['DURATION'],
                    row['ADMISSION_TYPE'],
                    row['ADMISSION_LOCATION'],
                    row['DISCHARGE_LOCATION'],
                    row['INSURANCE'],
                )

    def load_symptoms(self, df_symptoms):
        """Load symptoms into Neo4j, one write transaction per row.

        Missing Symptom/Duration/Frequency/Intensity/Negation values are
        replaced by empty strings before being passed on.

        Args:
            df_symptoms: DataFrame with columns HADM_ID, Symptom, Duration,
                Frequency, Intensity and Negation.
        """
        with self.driver.session() as session:
            for _, row in df_symptoms.iterrows():
                session.execute_write(
                    create_symptom,
                    self._id_as_str(row['HADM_ID']),
                    self._blank_if_missing(row['Symptom']),
                    self._blank_if_missing(row['Duration']),
                    self._blank_if_missing(row['Frequency']),
                    self._blank_if_missing(row['Intensity']),
                    self._blank_if_missing(row['Negation']),
                )

    def load_history(self, df_history):
        """Load medical history into Neo4j.

        Rows whose Medical_History is missing are skipped.

        Args:
            df_history: DataFrame with columns SUBJECT_ID and Medical_History.
        """
        with self.driver.session() as session:
            for _, row in df_history.iterrows():
                if pd.isna(row['Medical_History']):
                    continue
                session.execute_write(
                    create_history,
                    self._id_as_str(row['SUBJECT_ID']),
                    row['Medical_History'],
                )

    def load_allergies(self, df_allergies):
        """Load allergies into Neo4j, one write transaction per row.

        Args:
            df_allergies: DataFrame with columns SUBJECT_ID, HADM_ID and
                Allergies.
        """
        with self.driver.session() as session:
            for _, row in df_allergies.iterrows():
                session.execute_write(
                    create_allergy,
                    self._id_as_str(row['SUBJECT_ID']),
                    self._id_as_str(row['HADM_ID']),
                    row['Allergies'],
                )

    def load_vitals(self, df_vitals):
        """Load vitals into Neo4j, one write transaction per row.

        Args:
            df_vitals: DataFrame with columns SUBJECT_ID, HADM_ID, LABEL and
                VALUE.
        """
        with self.driver.session() as session:
            for _, row in df_vitals.iterrows():
                session.execute_write(
                    create_vital,
                    self._id_as_str(row['SUBJECT_ID']),
                    self._id_as_str(row['HADM_ID']),
                    row['LABEL'],
                    row['VALUE'],
                )

    def load_social_history(self, df_social_history):
        """Load social history into Neo4j.

        Rows whose Social_History is missing are skipped.

        Args:
            df_social_history: DataFrame with columns SUBJECT_ID, HADM_ID and
                Social_History.
        """
        with self.driver.session() as session:
            for _, row in df_social_history.iterrows():
                if pd.isna(row['Social_History']):
                    continue
                session.execute_write(
                    create_social_history,
                    self._id_as_str(row['SUBJECT_ID']),
                    self._id_as_str(row['HADM_ID']),
                    row['Social_History'],
                )

    def load_family_history(self, df_family_history):
        """Load family history into Neo4j.

        Rows missing either Family_Medical_History or Family_Member are
        skipped.

        Args:
            df_family_history: DataFrame with columns SUBJECT_ID,
                Family_Member and Family_Medical_History.
        """
        with self.driver.session() as session:
            for _, row in df_family_history.iterrows():
                if pd.isna(row['Family_Medical_History']) or pd.isna(row['Family_Member']):
                    continue
                session.execute_write(
                    create_family_history,
                    self._id_as_str(row['SUBJECT_ID']),
                    row['Family_Member'],
                    row['Family_Medical_History'],
                )

    def load_all_data(self, df_patients, df_admission, df_symptoms, df_history, df_allergies, df_vitals, df_social_history, df_family_history):
        """Load all data into Neo4j.

        Patients are loaded first and admissions second, followed by the
        remaining tables in the order of the parameters.
        """
        self.load_patients(df_patients)
        self.load_admission(df_admission)
        self.load_symptoms(df_symptoms)
        self.load_history(df_history)
        self.load_allergies(df_allergies)
        self.load_vitals(df_vitals)
        self.load_social_history(df_social_history)
        self.load_family_history(df_family_history)

    def db_creation_orchestrator(self, df_patients, df_admission, df_symptoms, df_history, df_allergies, df_vitals, df_social_history, df_family_history):
        """Orchestrate the data loading process into Neo4j.

        Clears the database, loads every DataFrame, and finally closes the
        driver. Errors raised while clearing or loading are printed and not
        re-raised; the driver is closed whether or not an error occurred.
        """
        try:
            # Step 1: Clear existing data in the database
            print("Clearing the Neo4j database...")
            self.clear_database()

            # Step 2: Load data into Neo4j
            print("Loading data into Neo4j...")
            self.load_all_data(
                df_patients,
                df_admission,
                df_symptoms,
                df_history,
                df_allergies,
                df_vitals,
                df_social_history,
                df_family_history,
            )
        except Exception as e:
            # Best-effort: report the failure instead of propagating it.
            print(f"An error occurred during the data loading process: {e}")
        finally:
            # Step 3: Close the connection to Neo4j
            print("Closing the Neo4j connection...")
            self.close()
            print("Data loading process completed.")

    # ------------------------------------------------------------------
    # Reading
    # ------------------------------------------------------------------

    def get_random_patient_admission(self):
        """Fetch one randomly chosen (patient, admission) pair.

        Returns:
            Whatever ``_fetch_random_patient_admission`` returns: the
            ``single()`` result of the query, exposing ``SubjectID`` and
            ``AdmissionID``.
        """
        with self.driver.session() as session:
            return session.execute_read(self._fetch_random_patient_admission)

    @staticmethod
    def _fetch_random_patient_admission(tx):
        """Transaction function: pick one Patient-Admission pair at random."""
        query = """
        MATCH (p:Patient)-[:HAS_ADMISSION]->(a:Admission)
        WITH p, a, rand() AS random
        ORDER BY random
        LIMIT 1
        RETURN p.SUBJECT_ID AS SubjectID, a.HADM_ID AS AdmissionID
        """
        return tx.run(query).single()

    def execute_cypher_query(self, cypher_query, parameters=None):
        """Execute a read-only Cypher query.

        Args:
            cypher_query: Cypher text to run.
            parameters: Optional dict of query parameters referenced as
                ``$name`` in ``cypher_query``. Prefer this over formatting
                values into the query text.

        Returns:
            A list with one dictionary per returned record.
        """
        with self.driver.session() as session:
            return session.execute_read(self._run_cypher_query, cypher_query, parameters)

    @staticmethod
    def _run_cypher_query(tx, cypher_query, parameters=None):
        """Run a Cypher query and return the results as a list of dictionaries."""
        result = tx.run(cypher_query, parameters)
        return [record.data() for record in result]

    # NOTE(review): the fetch_all_* methods below interpolate ``hadm_id``
    # verbatim (unquoted) into the Cypher text. The rendered text is part of
    # their return value, so it is kept as-is; do not pass untrusted input.

    def fetch_all_symptoms(self, hadm_id):
        """Fetch all symptoms for a given admission ID.

        Args:
            hadm_id: Admission id, inserted as-is into the query text.

        Returns:
            Tuple ``(query, symptoms)``: the Cypher text that was run and the
            ``name`` of each matched Symptom node.
        """
        query = f"""
        MATCH (a:Admission {{HADM_ID: {hadm_id}}})-[r:HAS_SYMPTOM]->(s:Symptom)
        RETURN a, r, s
        """
        rows = self.execute_cypher_query(query)
        symptoms = [row['s']['name'] for row in rows]
        return query, symptoms

    def fetch_all_medicalhistory(self, hadm_id):
        """Fetch the medical history of the patient owning an admission.

        Args:
            hadm_id: Admission id, inserted as-is into the query text.

        Returns:
            Tuple ``(query, history)``: the Cypher text that was run and the
            ``name`` of each matched History node.
        """
        query = f"""
        MATCH (p:Patient)-[r:HAS_MEDICAL_HISTORY]->(h:History)
        WHERE EXISTS((p)-[:HAS_ADMISSION]->(:Admission {{HADM_ID: {hadm_id}}}))
        RETURN p, r, h
        """
        rows = self.execute_cypher_query(query)
        history = [row['h']['name'] for row in rows]
        return query, history

    def fetch_all_allergies(self, hadm_id):
        """Fetch all allergies for a given admission ID.

        Args:
            hadm_id: Admission id, inserted as-is into the query text.

        Returns:
            Tuple ``(query, allergies)``: the Cypher text that was run and the
            ``name`` of each matched Allergy node.
        """
        query = f"""
        MATCH (a:Admission {{HADM_ID: {hadm_id}}})-[r:HAS_ALLERGY]->(al:Allergy)
        MATCH (p:Patient)-[:HAS_ADMISSION]->(a)
        RETURN p, a, r, al
        """
        rows = self.execute_cypher_query(query)
        allergies = [row['al']['name'] for row in rows]
        return query, allergies

    # ------------------------------------------------------------------
    # Schema rendering
    # ------------------------------------------------------------------

    def reformat_schema(self, nodes, relationships):
        """Render node and relationship descriptions as a text schema.

        Args:
            nodes: Iterable of dicts with keys ``labels`` (label name) and
                ``properties`` (list of property names).
            relationships: Iterable of dicts with keys ``relationship``,
                ``source`` and ``target`` (list of target labels).

        Returns:
            A string with a "### Node Properties" section and a
            "### Relationships" section separated by a blank line.
        """
        node_lines = "".join(
            f"- {node['labels']}: {', '.join(node['properties'])}\n"
            for node in nodes
        )
        relationship_lines = "".join(
            f"- {rel['source']} -> {', '.join(rel['target'])}: {rel['relationship']}\n"
            for rel in relationships
        )
        node_properties = "### Node Properties\n" + node_lines
        relationship_info = "### Relationships\n" + relationship_lines
        return node_properties + "\n" + relationship_info

    def generate_schema(self):
        """Return the text schema built from ``self.nodes`` and ``self.relationships``."""
        return self.reformat_schema(self.nodes, self.relationships)
Loading

0 comments on commit 446510f

Please sign in to comment.