Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

63-somnfeature-request---cdxml-file-as-user-input #67

Merged
merged 8 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/cfg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ kubernetes_jobs:

# Config for running SOMN job
somn:
image: "ianrinehart/somn:1.0"
image: "ianrinehart/somn:1.1"
command: "cp ${JOB_INPUT_DIR}/example_request.csv ${SOMN_PROJECT_DIR}/scratch/test_request.csv && micromamba run -n base somn predict last latest asdf && cp -r ${SOMN_PROJECT_DIR}/outputs/asdf/*/* ${JOB_OUTPUT_DIR}"
projectDirectory: '/tmp/somn_root/somn_scratch/IID-Models-2024'
imagePullPolicy: "Always"
Expand Down
73 changes: 55 additions & 18 deletions app/routers/somn.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,67 @@
from datetime import datetime
import re
from typing import List, Literal
from traceback import format_exc

from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, status
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from openbabel import pybel as pb

from services.minio_service import MinIOService
from services.email_service import EmailService
from services.somn_service import SomnService, SomnException
from config import get_logger
from services.somn_service import SOMN_ERROR_TYPES, SomnService, SomnException

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from models.somnRequestBody import SomnRequestBody
from models.sqlmodel.db import get_session
from models.sqlmodel.models import Job, JobType
from models.sqlmodel.models import JobType

from services.shared import draw_chemical_svg

router = APIRouter()
router = APIRouter()

log = get_logger(__name__)

# Request body accepted by the POST all-reaction-sites endpoint.
class CheckReactionSiteRequest(BaseModel):
    # Raw molecule text supplied by the user, written in the format named by `input_type`.
    input: str
    # Role of the molecule: 'el' = electrophile, 'nuc' = nucleophile.
    role: Literal['el', 'nuc']
    # Format identifier; it is passed as the format argument to pybel.readstring.
    input_type: Literal['smi', 'cml', 'cdxml']

# Successful (HTTP 200) response of the all-reaction-sites endpoint.
class CheckReactionSiteResponse(BaseModel):
    # Atom indices of the reaction sites found in the molecule.
    reaction_site_idxes: List[int]
    # SMILES for the molecule; for 'cml'/'cdxml' input this is the SMILES written by pybel.
    smiles: str
    # True when the molecule was found to carry stereo data.
    has_chiral: bool
    # Number of non-hydrogen atoms in the molecule.
    num_heavy_atoms: int
    # SVG drawing of the molecule produced by draw_chemical_svg.
    svg: str

# Error (HTTP 400) response of the all-reaction-sites endpoint; mirrors the
# attributes of a SomnException.
class CheckReactionSiteResponseInvalid(BaseModel):
    # Machine-readable error code, one of SOMN_ERROR_TYPES.
    type: SOMN_ERROR_TYPES
    # Human-readable description of the problem.
    message: str

@router.get(f"/{JobType.SOMN}/all-reaction-sites", tags=['Somn'], response_model=CheckReactionSiteResponse)
async def check_reaction_sites(smiles: str, role: Literal['el', 'nuc']):
@router.post(f"/{JobType.SOMN}/all-reaction-sites", tags=['Somn'], responses={
200: { "model": CheckReactionSiteResponse },
400: { "model": CheckReactionSiteResponseInvalid }
})
async def check_reaction_sites(
request: CheckReactionSiteRequest
):
input = request.input.replace('\"', '"')
role = request.role
input_type = request.input_type

input = input.strip()
if input_type == 'cml':
input = re.sub(r'> +<', '><', input)
elif input_type == 'cdxml':
input = re.sub(r' +', ' ', input)

try:
reactionSiteIdxes = SomnService.check_user_input_substrates(smiles, role)
reactionSiteIdxes, _, has_chiral, num_heavy_atoms = SomnService.check_user_input_substrates(input, input_type, role)
except SomnException as e:
raise HTTPException(status_code=400, detail=str(e))
return JSONResponse(
status_code=400,
content=e.__dict__
)
except Exception as e:
raise HTTPException(status_code=400, detail=str('Invalid user input'))
stack_trace = format_exc()
log.error(f"Error when checking reaction sites for {input}: {stack_trace}")
raise HTTPException(status_code=500, detail=str('Internal server error'))

def beforeDraw(d2d):
dopts = d2d.drawOptions()
Expand All @@ -41,10 +71,17 @@ def beforeDraw(d2d):
dopts.highlightRadius = .4
dopts.prepareMolsBeforeDrawing=True
dopts.fillHighlights=False

if input_type == 'cdxml' or input_type == 'cml':
mol = pb.readstring(input_type, input)
input = mol.write('smi').strip()

return {
"reaction_site_idxes": reactionSiteIdxes,
"svg": draw_chemical_svg(smiles,
"has_chiral": has_chiral,
"num_heavy_atoms": num_heavy_atoms,
"smiles": input,
"svg": draw_chemical_svg(input,
width=450,
height=300,
beforeDraw=beforeDraw,
Expand Down
188 changes: 150 additions & 38 deletions app/services/somn_service.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,146 @@
import io
import os
import time
import asyncio
import json

from fastapi import HTTPException
import pandas as pd
from sqlmodel.ext.asyncio.session import AsyncSession

from config import get_logger
from models.sqlmodel.models import Job
from models.enums import JobStatus

from services.minio_service import MinIOService
from services.email_service import EmailService

from typing import List, Literal
from rdkit import Chem
from openbabel import pybel as pb
import traceback

class SomnException(Exception):
pass
log = get_logger(__name__)
from openbabel import openbabel as ob

# Closed set of machine-readable error codes carried by SomnException.
SOMN_ERROR_TYPES = Literal[
    'invalid_input',
    '3d_gen',
    'no_reactive_nitrogens',
    'no_br_or_cl_in_el',
    'no_nitrogens_in_nuc',
    'br_in_nuc',
    'cl_in_nuc',
    'no_reaction_site'
]


class SomnException(Exception):
    """
    Error raised when a user-supplied substrate fails validation.

    Derives from Exception (not BaseException) so that generic
    ``except Exception`` handlers also see it.

    Attributes:
        type: One of SOMN_ERROR_TYPES, identifying the failure category.
        message: Human-readable description of the failure.
    """
    type: SOMN_ERROR_TYPES
    message: str

    def __init__(self, type: str, message: str):
        # Forward the message to Exception so str(e), e.args and logged
        # tracebacks are not empty. `args` is not stored in __dict__, so
        # e.__dict__ still contains only `type` and `message`.
        super().__init__(message)
        self.type = type
        self.message = message

class SomnService:
somn_frontend_baseURL = os.environ.get("SOMN_FRONTEND_URL")

@staticmethod
def gen3d_test(smiles: str):
try:
obmol = pb.readstring("smi", smiles)
obmol.addh()
def has_chiral(mol: pb.Molecule):
    '''
    Tell whether the molecule carries any stereo data, so the front-end can show a warning/info message

    Parameters:
    mol (class 'openbabel.pybel.Molecule') - pybel molecule to inspect

    Returns:
    Bool - True: at least one stereo record reports a truthy type, False: none does
    '''
    stereo_records = mol.OBMol.GetAllData(ob.StereoData)
    # any() stops at the first record whose stereo type is truthy
    return any(
        ob.toStereoBase(record).GetType()
        for record in stereo_records
    )

@staticmethod
def get_num_heavy_atoms(mol: pb.Molecule):
    """
    Return how many heavy (non-Hydrogen) atoms the molecule contains.

    Params:
    mol (class 'openbabel.pybel.Molecule') - pybel molecule to count atoms on

    Returns:
    (int) - number of atoms left once hydrogens are removed
    """
    # With hydrogens stripped, every remaining atom is a heavy atom
    mol.removeh()
    heavy_atom_count = len(mol.atoms)

    # Put hydrogens back so later processing sees them again
    mol.addh()
    return heavy_atom_count

@staticmethod
def ob_test(
    user_input: str,
    input_type: Literal['smi', 'cml', 'cdxml']
):
    """
    Parse the user's molecule with OpenBabel and generate 3D coordinates for it.

    Parameters:
        user_input (str): Molecule text, written in the format named by input_type
        input_type (str): pybel format identifier - 'smi', 'cml' or 'cdxml'

    Returns:
        tuple[str, bool, int]: A tuple containing:
            - The molecule in MOL2 format with 3D coordinates
            - Boolean indicating if the molecule has chiral centers
            - Number of heavy atoms

    Raises:
        SomnException: type "invalid_input" if the input cannot be read,
            type "3d_gen" if 3D coordinate generation fails
    """
    try:
        obmol = pb.readstring(input_type, user_input)
    except Exception as e:
        raise SomnException(type="invalid_input", message=f"Invalid input [{input_type}]: {user_input}") from e

    obmol.addh()

    try:
        # this step may fail, so we know SOMN cannot compute on the input
        obmol.make3D()
    except Exception as e:
        # format_exc() returns the traceback text; print_exc() returns None,
        # which would log the literal string "None"
        log.error(f"Unable to generate 3D coordinates {traceback.format_exc()}")
        raise SomnException(type="3d_gen", message=f"Unable to generate 3D coordinates for {user_input}") from e

    # The mol2 text is written first: get_num_heavy_atoms strips and re-adds
    # hydrogens on obmol, so it must run after the export
    return (
        obmol.write("mol2"),
        SomnService.has_chiral(obmol),
        SomnService.get_num_heavy_atoms(obmol)
    )

@staticmethod
def validate_and_update_config(job_config: dict):
    """
    Validate the electrophile and nucleophile of a job config and update it in place.

    Each substrate is checked with check_user_input_substrates. Its entry is
    replaced by the returned mol2 string when more than one reaction site
    was found or when the substrate was supplied as 'cdxml' or 'cml'.

    Parameters:
        job_config (dict): must contain 'el', 'el_input_type', 'nuc' and 'nuc_input_type'

    Returns:
        dict: the same job_config object, possibly with 'el' / 'nuc' rewritten
    """
    # Electrophile first, then nucleophile - same order as the checks were written
    for substrate_role in ('el', 'nuc'):
        type_key = f'{substrate_role}_input_type'
        site_idxes, mol2_text, _, _ = SomnService.check_user_input_substrates(
            job_config[substrate_role],
            job_config[type_key],
            substrate_role,
        )

        if len(site_idxes) > 1 or job_config[type_key] in ('cdxml', 'cml'):
            job_config[substrate_role] = mol2_text

    return job_config

Expand Down Expand Up @@ -84,11 +174,27 @@ async def resultPostProcess(bucket_name: str, job_id: str, service: MinIOService
return retVal

@staticmethod
def check_user_input_substrates(user_input, role: str):
def check_user_input_substrates(
user_input,
input_type: Literal['smi', 'cml', 'cdxml'],
role: Literal['el', 'nuc']
):
"""
Verifies user input substrate, and if verification is successful, returns reaction sites if multiple.
If everything is "normal", i.e., the user doesn't need to tell us more information, then it returns 0.
If NONE are found for halides or reactive nitrogens, then this returns None.
Verifies user input substrate and returns reaction site indices and chirality information.

Parameters:
user_input (str): SMILES string of the input molecule
role (str): Role of the molecule - either 'el' for electrophile or 'nuc' for nucleophile

Returns:
tuple: A tuple containing:
- list[int]: List of atom indices for reaction sites
- a mol2 string of the molecule with 3D coordinates
- bool: True if molecule has chiral centers, False otherwise

Raises:
SomnException: If no valid reaction sites are found or if invalid molecule type
Exception: For other validation errors
"""

def get_amine_ref_ns(mol, ref_atom_idxes: List) -> List:
Expand All @@ -113,7 +219,7 @@ def get_amine_ref_ns(mol, ref_atom_idxes: List) -> List:
if len(ret_val) >= 1:
return ret_val

raise Exception("No reactive nitrogens detected in nucleophile!")
raise SomnException(type="no_reactive_nitrogens", message="No reactive nitrogens found")

def check_halides_aromatic(rdkmol,halides):
rdkatoms = [atom for atom in rdkmol.GetAtoms() if atom.GetIdx() in halides]
Expand All @@ -136,13 +242,13 @@ def get_atoms_by_symbol(mol, symbol):
retVals.append(idx)
return retVals

# add gen3d_test here to prevent the user from
# add ob_test here to prevent the user from
# submitting a molecule that cannot be processed by SOMN
obmol = SomnService.gen3d_test(user_input)
(mol2, has_chiral, num_heavy_atoms) = SomnService.ob_test(user_input, input_type)

# generate rdkit mol using obmol because rdkit
# generate rdkit mol using mol2 because rdkit
# might have issues generating mol from smiles directly
mol = Chem.MolFromMol2Block(obmol, sanitize=False, removeHs=False)
mol = Chem.MolFromMol2Block(mol2, sanitize=False, removeHs=False)

bromides = get_atoms_by_symbol(mol, symbol="Br")
chlorides = get_atoms_by_symbol(mol, symbol="Cl")
Expand All @@ -153,30 +259,36 @@ def get_atoms_by_symbol(mol, symbol):
if len(bromides) != 0:
aromatic_halides = check_halides_aromatic(mol,bromides)
if any(idx in aromatic_halides for idx in bromides):
raise SomnException("Bromides detected in nucleophile!")
raise SomnException(type="br_in_nuc", message="Bromine found in nucleophile")

if len(chlorides) != 0:
aromatic_halides = check_halides_aromatic(mol,bromides)
if any(idx in aromatic_halides for idx in chlorides):
raise SomnException("Chlorides detected in nucleophile!")
raise SomnException(type="cl_in_nuc", message="Chlorine found in nucleophile")

if len(nitrogens) == 0:
raise SomnException("No nitrogens detected in nucleophile!")
raise SomnException(type="no_nitrogens_in_nuc", message="No nitrogens found in nucleophile")

indices = get_amine_ref_ns(mol,nitrogens)
return indices
if not len(indices):
raise SomnException(type="no_reaction_site", message="No Reaction Site found")

return (indices, mol2, has_chiral, num_heavy_atoms)

elif role.startswith('el'):
if len(bromides) + len(chlorides) == 0:
raise SomnException("No Br or Cl sites detected in electrophile!")
raise SomnException(type="no_br_or_cl_in_el", message="No Br or Cl found in electrophile")

try:
chl_idxes = check_halides_aromatic(mol,chlorides)
brm_idxes = check_halides_aromatic(mol,bromides)

if not len(chl_idxes) and not len(brm_idxes):
raise SomnException(type="no_reaction_site", message="No Reaction Site found")

return chl_idxes + brm_idxes
return (chl_idxes + brm_idxes, mol2, has_chiral, num_heavy_atoms)

except Exception as e:
raise e

return []
raise SomnException(type="invalid_input", message=f"Invalid input role: {role}")
2 changes: 1 addition & 1 deletion chart/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ config:

# Config for running SOMN job
somn:
image: "ianrinehart/somn:1.0"
image: "ianrinehart/somn:1.1"
projectDirectory: '/tmp/somn_root/somn_scratch/IID-Models-2024'
command: "cp ${JOB_INPUT_DIR}/example_request.csv ${SOMN_PROJECT_DIR}/scratch/test_request.csv && micromamba run -n base somn predict last latest asdf && cp -r ${SOMN_PROJECT_DIR}/outputs/asdf/*/* ${JOB_OUTPUT_DIR}"
imagePullPolicy: "Always"
Expand Down
Loading