-
Notifications
You must be signed in to change notification settings - Fork 259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/dbn #174
base: develop
Are you sure you want to change the base?
Feature/dbn #174
Changes from 5 commits
6acc09f
e1ac8e0
15e23cb
0c3fd19
71b2f93
7d495fc
326f745
85df54f
e34376c
acd9dfe
8fcf355
2143e52
5c8505b
6d0b22b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,4 @@ | |
__all__ = ["BayesianNetwork"] | ||
|
||
from .network import BayesianNetwork | ||
from .network import DynamicBayesianNetwork |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,7 +43,7 @@ | |
from pgmpy.models import BayesianModel | ||
|
||
from causalnex.estimator.em import EMSingleLatentVariable | ||
from causalnex.structure import StructureModel | ||
from causalnex.structure import StructureModel, DynamicStructureModel | ||
from causalnex.utils.pgmpy_utils import pd_to_tabular_cpd | ||
|
||
|
||
|
@@ -736,3 +736,84 @@ def _predict_probability_from_incomplete_data( | |
probability = probability[cols] | ||
probability.columns = cols | ||
return probability | ||
|
||
|
||
class DynamicBayesianNetwork(BayesianNetwork): | ||
""" | ||
Base class for Dynamic Bayesian Network (DBN), a probabilistic weighted DAG where nodes represent variables, | ||
edges represent the causal relationships between variables. | ||
|
||
``DynamicBayesianNetwork`` stores nodes with their possible states, edges and | ||
conditional probability distributions (CPDs) of each node. | ||
|
||
``DynamicBayesianNetwork`` is built on top of the ``StructureModel``, which is an extension of ``networkx.DiGraph`` | ||
(see :func:`causalnex.structure.structuremodel.StructureModel`). | ||
|
||
In order to define the ``DynamicBayesianNetwork``, users should provide a relevant ``StructureModel``. | ||
Once ``DynamicBayesianNetwork`` is initialised, no changes to the ``StructureModel`` can be made | ||
and CPDs can be learned from the data. | ||
|
||
The learned CPDs can be then used for likelihood estimation and predictions. | ||
|
||
Example: | ||
:: | ||
>>> # Create a Dynamic Bayesian Network with a manually defined DAG. | ||
>>> from causalnex.structure import StructureModel | ||
>>> from causalnex.network import DynamicBayesianNetwork | ||
>>> | ||
>>> sm = StructureModel() | ||
>>> sm.add_edges_from([ | ||
>>> ('rush_hour', 'traffic'), | ||
>>> ('weather', 'traffic') | ||
>>> ]) | ||
>>> dbn = DynamicBayesianNetwork(sm) | ||
>>> # A created ``DynamicBayesianNetwork`` stores nodes and edges defined by the ``StructureModel`` | ||
>>> dbn.nodes | ||
['rush_hour', 'traffic', 'weather'] | ||
>>> | ||
>>> dbn.edges | ||
[('rush_hour', 'traffic'), ('weather', 'traffic')] | ||
>>> # A ``DynamicBayesianNetwork`` doesn't store any CPDs yet | ||
>>> dbn.cpds | ||
>>> {} | ||
>>> | ||
>>> # Learn the nodes' states from the data | ||
>>> import pandas as pd | ||
>>> data = pd.DataFrame({ | ||
>>> 'rush_hour': [True, False, False, False, True, False, True], | ||
>>> 'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'], | ||
>>> 'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'] | ||
>>> }) | ||
>>> dbn = dbn.fit_node_states(data) | ||
>>> dbn.node_states | ||
{'rush_hour': {False, True}, 'weather': {'Bad', 'Good', 'Terrible'}, 'traffic': {'heavy', 'light'}} | ||
>>> # Learn the CPDs from the data | ||
>>> dbn = dbn.fit_cpds(data) | ||
>>> # Use the learned CPDs to make predictions on the unseen data | ||
>>> test_data = pd.DataFrame({ | ||
>>> 'rush_hour': [False, False, True, True], | ||
>>> 'weather': ['Good', 'Bad', 'Good', 'Bad'] | ||
>>> }) | ||
>>> dbn.predict(test_data, "traffic").to_dict() | ||
>>> {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} | ||
>>> dbn.predict_probability(test_data, "traffic").to_dict() | ||
{'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} | ||
{'traffic_light': {0: 0.75, 1: 0.25, 2: 0.3333333333333333, 3: 0.3333333333333333}, | ||
'traffic_heavy': {0: 0.25, 1: 0.75, 2: 0.6666666666666666, 3: 0.6666666666666666}} | ||
""" | ||
|
||
def __init__(self, structure: DynamicStructureModel): | ||
""" | ||
Create a ``DynamicBayesianNetwork`` with a DAG defined by ``DynamicStructureModel``. | ||
|
||
Args: | ||
structure: a graph representing a causal relationship between variables. | ||
In the structure | ||
- cycles are not allowed; | ||
- multiple (parallel) edges are not allowed; | ||
- isolated nodes and multiple components are not allowed. | ||
|
||
Raises: | ||
ValueError: If the structure is not a connected DAG. | ||
""" | ||
super().__init__(structure) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wanted to know if you're clear on the changes that will need to come here. If not let's have a PS anytime :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes i think i need a PS here :) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,8 @@ | |
import scipy.linalg as slin | ||
import scipy.optimize as sopt | ||
|
||
from causalnex.structure import StructureModel | ||
from causalnex.structure import DynamicStructureModel | ||
from causalnex.structure import DynamicStructureNode | ||
from causalnex.structure.transformers import DynamicDataTransformer | ||
|
||
|
||
|
@@ -53,7 +54,7 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments | |
tabu_edges: List[Tuple[int, int, int]] = None, | ||
tabu_parent_nodes: List[int] = None, | ||
tabu_child_nodes: List[int] = None, | ||
) -> StructureModel: | ||
) -> DynamicStructureModel: | ||
liam-adams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in | ||
data. The input data is a time series or a list of realisations of a same time series. | ||
|
@@ -122,15 +123,15 @@ def from_pandas_dynamic( # pylint: disable=too-many-arguments | |
tabu_child_nodes, | ||
) | ||
|
||
sm = StructureModel() | ||
sm.add_nodes_from( | ||
[f"{var}_lag{l_val}" for var in col_idx.keys() for l_val in range(p + 1)] | ||
sm = DynamicStructureModel() | ||
sm.add_nodes( | ||
[DynamicStructureNode(var, l_val) for var in col_idx.keys() for l_val in range(p + 1)] | ||
) | ||
sm.add_weighted_edges_from( | ||
[ | ||
( | ||
_format_name_from_pandas(idx_col, u), | ||
_format_name_from_pandas(idx_col, v), | ||
DynamicStructureNode(idx_col[int(u[0])], u[-1]), # _format_name_from_pandas(idx_col, u), idx_col[int(u[0])] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we delete the comment, please? :) |
||
DynamicStructureNode(idx_col[int(v[0])], v[-1]), | ||
w, | ||
) | ||
for u, v, w in g.edges.data("weight") | ||
|
@@ -166,7 +167,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments | |
tabu_edges: List[Tuple[int, int, int]] = None, | ||
tabu_parent_nodes: List[int] = None, | ||
tabu_child_nodes: List[int] = None, | ||
) -> StructureModel: | ||
) -> DynamicStructureModel: | ||
""" | ||
Learn the graph structure of a Dynamic Bayesian Network describing conditional dependencies between variables in | ||
data. The input data is time series data present in numpy arrays X and Xlags. | ||
|
@@ -254,7 +255,7 @@ def from_numpy_dynamic( # pylint: disable=too-many-arguments | |
|
||
def _matrices_to_structure_model( | ||
w_est: np.ndarray, a_est: np.ndarray | ||
) -> StructureModel: | ||
) -> DynamicStructureModel: | ||
""" | ||
Converts the matrices output by dynotears (W and A) into a StructureModel | ||
We use the following convention: | ||
|
@@ -268,13 +269,13 @@ def _matrices_to_structure_model( | |
StructureModel representing the structure learnt | ||
|
||
""" | ||
sm = StructureModel() | ||
sm = DynamicStructureModel() | ||
lag_cols = [ | ||
f"{var}_lag{l_val}" | ||
DynamicStructureNode(var, l_val) | ||
for l_val in range(1 + (a_est.shape[0] // a_est.shape[1])) | ||
for var in range(a_est.shape[1]) | ||
] | ||
sm.add_nodes_from(lag_cols) | ||
sm.add_nodes(lag_cols) | ||
sm.add_edges_from( | ||
[ | ||
(lag_cols[i], lag_cols[j], dict(weight=w_est[i, j])) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good.
I think the text should be a bit different from the one in the BN class though. It's ok to keep the similar points, but I would rather say that a DBN is a BN with the time domain taken into account, and it does X and Y that a normal BN doesn't do
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes i think i need a PS here :)