Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide base class for dataset loaders #59

Merged
merged 2 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions chemicalx/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from class_resolver import Resolver

from .contextfeatureset import ContextFeatureSet
from .datasetloader import DatasetLoader, DrugbankDDI, DrugComb, DrugCombDB, TwoSides
from .datasetloader import (
DatasetLoader,
DrugbankDDI,
DrugComb,
DrugCombDB,
RemoteDatasetLoader,
TwoSides,
)
from .drugfeatureset import DrugFeatureSet
from .drugpairbatch import DrugPairBatch
from .labeledtriples import LabeledTriples
Expand All @@ -22,4 +29,4 @@
"DrugCombDB",
]

dataset_resolver = Resolver.from_subclasses(base=DatasetLoader)
dataset_resolver = Resolver.from_subclasses(base=DatasetLoader, skip={RemoteDatasetLoader})
150 changes: 91 additions & 59 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import json
import urllib.request
from abc import ABC, abstractmethod
from functools import lru_cache
from textwrap import dedent
from typing import Dict, Optional, Tuple, cast
Expand All @@ -18,6 +19,7 @@

__all__ = [
"DatasetLoader",
"RemoteDatasetLoader",
# Actual datasets
"DrugCombDB",
"DrugComb",
Expand All @@ -26,18 +28,8 @@
]


class DatasetLoader:
"""General dataset loader for the integrated drug pair scoring datasets."""

def __init__(self, dataset_name: str):
"""Instantiate the dataset loader.
Args:
dataset_name (str): The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb", "twosides", "drugbankddi"]
class DatasetLoader(ABC):
"""A generic dataset."""

def get_generators(
self,
Expand Down Expand Up @@ -95,6 +87,86 @@ def get_generator(
labeled_triples=self.get_labeled_triples() if labeled_triples is None else labeled_triples,
)

@abstractmethod
def get_context_features(self) -> ContextFeatureSet:
"""
Get the context feature set.
Returns:
: The ContextFeatureSet of the dataset of interest.
"""

@property
def num_contexts(self) -> int:
"""Get the number of contexts."""
return len(self.get_context_features())

@property
def context_channels(self) -> int:
"""Get the number of features for each context."""
return next(iter(self.get_context_features().values())).shape[1]

@abstractmethod
def get_drug_features(self):
"""
Get the drug feature set.
Returns:
: The DrugFeatureSet of the dataset of interest.
"""

@property
def num_drugs(self) -> int:
"""Get the number of drugs."""
return len(self.get_drug_features())

@property
def drug_channels(self) -> int:
"""Get the number of features for each drug."""
return next(iter(self.get_drug_features().values()))["features"].shape[1]

def get_labeled_triples(self) -> LabeledTriples:
"""
Get the labeled triples file from the storage.
Returns:
: The labeled triples in the dataset.
"""

@property
def num_labeled_triples(self) -> int:
"""Get the number of labeled triples."""
return len(self.get_labeled_triples())

def summarize(self) -> None:
"""Summarize the dataset."""
print(
dedent(
f"""\
Name: {self.__class__.__name__}
Contexts: {self.num_contexts}
Context Feature Size: {self.context_channels}
Drugs: {self.num_drugs}
Drug Feature Size: {self.drug_channels}
Triples: {self.num_labeled_triples}
"""
)
)


class RemoteDatasetLoader(DatasetLoader):
"""General dataset loader for the integrated drug pair scoring datasets."""

def __init__(self, dataset_name: str):
"""Instantiate the dataset loader.
Args:
dataset_name (str): The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb", "twosides", "drugbankddi"]

def generate_path(self, file_name: str) -> str:
"""
Generate a complete url for a dataset file.
Expand Down Expand Up @@ -140,30 +212,20 @@ def get_context_features(self):
Get the context feature set.
Returns:
context_feature_set (ContextFeatureSet): The ContextFeatureSet of the dataset of interest.
: The ContextFeatureSet of the dataset of interest.
"""
path = self.generate_path("context_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: torch.FloatTensor(np.array(v).reshape(1, -1)) for k, v in raw_data.items()}
return ContextFeatureSet(raw_data)

@property
def num_contexts(self) -> int:
"""Get the number of contexts."""
return len(self.get_context_features())

@property
def context_channels(self) -> int:
"""Get the number of features for each context."""
return next(iter(self.get_context_features().values())).shape[1]

@lru_cache(maxsize=1)
def get_drug_features(self):
"""
Get the drug feature set.
Returns:
drug_feature_set (DrugFeatureSet): The DrugFeatureSet of the dataset of interest.
: The DrugFeatureSet of the dataset of interest.
"""
path = self.generate_path("drug_set.json")
raw_data = self.load_raw_json_data(path)
Expand All @@ -173,74 +235,44 @@ def get_drug_features(self):
}
return DrugFeatureSet.from_dict(raw_data)

@property
def num_drugs(self) -> int:
"""Get the number of drugs."""
return len(self.get_drug_features())

@property
def drug_channels(self) -> int:
"""Get the number of features for each drug."""
return next(iter(self.get_drug_features().values()))["features"].shape[1]

@lru_cache(maxsize=1)
def get_labeled_triples(self):
"""
Get the labeled triples file from the storage.
Returns:
labeled_triples (LabeledTriples): The labeled triples in the dataset.
: The labeled triples in the dataset.
"""
path = self.generate_path("labeled_triples.csv")
df = self.load_raw_csv_data(path)
return LabeledTriples(df)

@property
def num_labeled_triples(self) -> int:
"""Get the number of labeled triples."""
return len(self.get_labeled_triples())

def summarize(self) -> None:
"""Summarize the dataset."""
print(
dedent(
f"""\
Name: {self.dataset_name}
Contexts: {self.num_contexts}
Context Feature Size: {self.context_channels}
Drugs: {self.num_drugs}
Drug Feature Size: {self.drug_channels}
Triples: {self.num_labeled_triples}
"""
)
)


class DrugCombDB(DatasetLoader):
class DrugCombDB(RemoteDatasetLoader):
"""A dataset loader for `DrugCombDB <http://drugcombdb.denglab.org>`_."""

def __init__(self):
"""Instantiate the DrugCombDB dataset loader."""
super().__init__("drugcombdb")


class DrugComb(DatasetLoader):
class DrugComb(RemoteDatasetLoader):
"""A dataset loader for `DrugComb <https://drugcomb.fimm.fi/>`_."""

def __init__(self):
"""Instantiate the DrugComb dataset loader."""
super().__init__("drugcomb")


class TwoSides(DatasetLoader):
class TwoSides(RemoteDatasetLoader):
"""A dataset loader for a sample of `TWOSIDES <http://tatonettilab.org/offsides/>`_."""

def __init__(self):
"""Instantiate the TWOSIDES dataset loader."""
super().__init__("twosides")


class DrugbankDDI(DatasetLoader):
class DrugbankDDI(RemoteDatasetLoader):
"""A dataset loader for `Drugbank DDI <https://www.pnas.org/content/115/18/E4304>`_."""

def __init__(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
from typing import ClassVar

from chemicalx.data import DatasetLoader
from chemicalx.data import DatasetLoader, DrugCombDB


class TestGeneratorDrugCombDB(unittest.TestCase):
Expand All @@ -14,7 +14,7 @@ class TestGeneratorDrugCombDB(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the class with a dataset loader."""
cls.loader = DatasetLoader("drugcombdb")
cls.loader = DrugCombDB()

def test_all_true(self):
"""Test sizes of drug features during batch generation."""
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
from typing import ClassVar

from chemicalx.data import DatasetLoader
from chemicalx.data import DatasetLoader, DrugbankDDI, DrugComb, DrugCombDB, TwoSides


class TestDrugComb(unittest.TestCase):
Expand All @@ -14,7 +14,7 @@ class TestDrugComb(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugcomb")
cls.loader = DrugComb()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -40,7 +40,7 @@ class TestDrugCombDB(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugcombdb")
cls.loader = DrugCombDB()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -66,7 +66,7 @@ class TestDeepDDI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugbankddi")
cls.loader = DrugbankDDI()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -92,7 +92,7 @@ class TestTwoSides(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("twosides")
cls.loader = TwoSides()

def test_get_context_features(self):
"""Test the number of context features."""
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import chemicalx.models
from chemicalx import pipeline
from chemicalx.data import DatasetLoader, DrugCombDB
from chemicalx.data import DatasetLoader, DrugComb, DrugCombDB
from chemicalx.models import (
CASTER,
EPGCNDS,
Expand Down Expand Up @@ -111,7 +111,7 @@ class TestModels(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case with a dataset."""
cls.loader = DatasetLoader("drugcomb")
cls.loader = DrugComb()

def setUp(self):
"""Set up the test case."""
Expand Down