-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
spring cleaning: removing bugs and code smells #261
Changes from 16 commits
9ac8b34
178592c
159de62
f96cf3d
47a7cef
8e6fa1a
645fa01
5652090
a455dd9
27daf4e
cd73acd
73f7b9f
448598b
69d6aab
6436e20
fc62d4a
4aaa8cd
bd6dbdc
85b5ea9
19212bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,45 +66,47 @@ def initializeBold(self): | |
self.boldInitialized = True | ||
# logging.info(f"{self.name}: BOLD model initialized.") | ||
|
||
def simulateBold(self, t, variables, append=False): | ||
def get_bold_variable(self, variables): | ||
default_index = self.state_vars.index(self.default_output) | ||
return variables[default_index] | ||
|
||
def simulateBold(self, bold_variable, append=False): | ||
"""Gets the default output of the model and simulates the BOLD model. | ||
Adds the simulated BOLD signal to outputs. | ||
""" | ||
if self.boldInitialized: | ||
# first we loop through all state variables | ||
for svn, sv in zip(self.state_vars, variables): | ||
# the default output is used as the input for the bold model | ||
if svn == self.default_output: | ||
bold_input = sv[:, self.startindt :] | ||
# logging.debug(f"BOLD input `{svn}` of shape {bold_input.shape}") | ||
if bold_input.shape[1] >= self.boldModel.samplingRate_NDt: | ||
# only if the length of the output has a zero mod to the sampling rate, | ||
# the downsampled output from the boldModel can correctly appended to previous data | ||
# so: we are lazy here and simply disable appending in that case ... | ||
if not bold_input.shape[1] % self.boldModel.samplingRate_NDt == 0: | ||
append = False | ||
logging.warn( | ||
f"Output size {bold_input.shape[1]} is not a multiple of BOLD sampling length { self.boldModel.samplingRate_NDt}, will not append data." | ||
) | ||
logging.debug(f"Simulating BOLD: boldModel.run(append={append})") | ||
|
||
# transform bold input according to self.boldInputTransform | ||
if self.boldInputTransform: | ||
bold_input = self.boldInputTransform(bold_input) | ||
|
||
# simulate bold model | ||
self.boldModel.run(bold_input, append=append) | ||
|
||
t_BOLD = self.boldModel.t_BOLD | ||
BOLD = self.boldModel.BOLD | ||
self.setOutput("BOLD.t_BOLD", t_BOLD) | ||
self.setOutput("BOLD.BOLD", BOLD) | ||
else: | ||
logging.warn( | ||
f"Will not simulate BOLD if output {bold_input.shape[1]*self.params['dt']} not at least of duration {self.boldModel.samplingRate_NDt*self.params['dt']}" | ||
) | ||
else: | ||
if not self.boldInitialized: | ||
logging.warn("BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`") | ||
return | ||
|
||
bold_input = bold_variable[:, self.startindt :] | ||
# logging.debug(f"BOLD input `{svn}` of shape {bold_input.shape}") | ||
if not bold_input.shape[1] >= self.boldModel.samplingRate_NDt: | ||
logging.warn( | ||
f"Will not simulate BOLD if output {bold_input.shape[1]*self.params['dt']} not at least of duration {self.boldModel.samplingRate_NDt*self.params['dt']}" | ||
) | ||
return | ||
|
||
# only if the length of the output has a zero mod to the sampling rate, | ||
# the downsampled output from the boldModel can correctly appended to previous data | ||
# so: we are lazy here and simply disable appending in that case ... | ||
if append and not bold_input.shape[1] % self.boldModel.samplingRate_NDt == 0: | ||
append = False | ||
logging.warn( | ||
f"Output size {bold_input.shape[1]} is not a multiple of BOLD sampling length { self.boldModel.samplingRate_NDt}, will not append data." | ||
) | ||
logging.debug(f"Simulating BOLD: boldModel.run()") | ||
|
||
# transform bold input according to self.boldInputTransform | ||
if self.boldInputTransform: | ||
bold_input = self.boldInputTransform(bold_input) | ||
|
||
# simulate bold model | ||
self.boldModel.run(bold_input) | ||
|
||
t_BOLD = self.boldModel.t_BOLD | ||
BOLD = self.boldModel.BOLD | ||
self.setOutput("BOLD.t_BOLD", t_BOLD, append=append) | ||
self.setOutput("BOLD.BOLD", BOLD, append=append) | ||
|
||
def checkChunkwise(self, chunksize): | ||
"""Checks if the model fulfills requirements for chunkwise simulation. | ||
|
@@ -172,21 +174,16 @@ def initializeRun(self, initializeBold=False): | |
# check dt / sampling_dt | ||
self.setSamplingDt() | ||
|
||
# force bold if params['bold'] == True | ||
if self.params.get("bold"): | ||
initializeBold = True | ||
# set up the bold model, if it didn't happen yet | ||
if initializeBold and not self.boldInitialized: | ||
self.initializeBold() | ||
|
||
def run( | ||
self, | ||
inputs=None, | ||
chunkwise=False, | ||
chunksize=None, | ||
bold=False, | ||
append=False, | ||
append_outputs=None, | ||
append_outputs=False, | ||
continue_run=False, | ||
): | ||
""" | ||
|
@@ -205,28 +202,24 @@ def run( | |
:type chunksize: int, optional | ||
:param bold: simulate BOLD signal (only for chunkwise integration), defaults to False | ||
:type bold: bool, optional | ||
:param append: append the chunkwise outputs to the outputs attribute, defaults to False, defaults to False | ||
:type append: bool, optional | ||
:param append_outputs: append new and chunkwise outputs to the outputs attribute, defaults to False. Note: BOLD outputs are always appended | ||
:type append_outputs: bool, optional | ||
:param continue_run: continue a simulation by using the initial values from a previous simulation | ||
:type continue_run: bool | ||
""" | ||
# TODO: legacy argument support | ||
if append_outputs is not None: | ||
append = append_outputs | ||
self.initializeRun(initializeBold=bold) | ||
|
||
# if a previous run is not to be continued clear the model's state | ||
if continue_run is False: | ||
if continue_run: | ||
self.setInitialValuesToLastState() | ||
else: | ||
self.clearModelState() | ||
|
||
self.initializeRun(initializeBold=bold) | ||
|
||
# enable chunkwise if chunksize is set | ||
chunkwise = chunkwise if chunksize is None else True | ||
|
||
if chunkwise is False: | ||
self.integrate(append_outputs=append, simulate_bold=bold) | ||
if continue_run: | ||
self.setInitialValuesToLastState() | ||
self.integrate(append_outputs=append_outputs, simulate_bold=bold) | ||
|
||
else: | ||
if chunksize is None: | ||
|
@@ -235,10 +228,8 @@ def run( | |
# check if model is safe for chunkwise integration | ||
# and whether sampling_dt is compatible with duration and chunksize | ||
self.checkChunkwise(chunksize) | ||
if bold and not self.boldInitialized: | ||
logging.warn(f"{self.name}: BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`") | ||
bold = False | ||
self.integrateChunkwise(chunksize=chunksize, bold=bold, append_outputs=append) | ||
|
||
self.integrateChunkwise(chunksize=chunksize, bold=bold, append_outputs=append_outputs) | ||
|
||
# check if there was a problem with the simulated data | ||
self.checkOutputs() | ||
|
@@ -260,20 +251,17 @@ def checkOutputs(self): | |
def integrate(self, append_outputs=False, simulate_bold=False): | ||
"""Calls each models `integration` function and saves the state and the outputs of the model. | ||
|
||
:param append: append the chunkwise outputs to the outputs attribute, defaults to False, defaults to False | ||
:param append: append the chunkwise outputs to the outputs attribute, defaults to False | ||
:type append: bool, optional | ||
""" | ||
# run integration | ||
t, *variables = self.integration(self.params) | ||
self.storeOutputsAndStates(t, variables, append=append_outputs) | ||
|
||
# force bold if params['bold'] == True | ||
if self.params.get("bold"): | ||
simulate_bold = True | ||
|
||
# bold simulation after integration | ||
if simulate_bold and self.boldInitialized: | ||
self.simulateBold(t, variables, append=True) | ||
bold_variable = self.get_bold_variable(variables) | ||
self.simulateBold(bold_variable, append=True) | ||
|
||
def integrateChunkwise(self, chunksize, bold=False, append_outputs=False): | ||
"""Repeatedly calls the chunkwise integration for the whole duration of the simulation. | ||
|
@@ -311,7 +299,7 @@ def clearModelState(self): | |
self.state = dotdict({}) | ||
self.outputs = dotdict({}) | ||
# reinitialize bold model | ||
if self.params.get("bold"): | ||
if self.boldInitialized: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this do the same thing as before? Reads like: if it is initialized, initialize it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference is that BOLD is now only initialized at the beginning of the first run. One only needs to clear the bold state with re-initialization if it has been initialized before. |
||
self.initializeBold() | ||
|
||
def storeOutputsAndStates(self, t, variables, append=False): | ||
|
@@ -335,6 +323,8 @@ def storeOutputsAndStates(self, t, variables, append=False): | |
|
||
def setInitialValuesToLastState(self): | ||
"""Reads the last state of the model and sets the initial conditions to that state for continuing a simulation.""" | ||
if not hasattr(self, "t"): | ||
raise ValueError("You tried using continue_run=True on the first run.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have to error here? The user could user There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed this in the latest commit |
||
for iv, sv in zip(self.init_vars, self.state_vars): | ||
# if state variables are one-dimensional (in space only) | ||
if (self.state[sv].ndim == 0) or (self.state[sv].ndim == 1): | ||
|
@@ -474,25 +464,28 @@ def setOutput(self, name, data, append=False, removeICs=False): | |
raise ValueError(f"Don't know how to truncate data of shape {data.shape}.") | ||
|
||
# subsample to sampling dt | ||
if data.ndim == 1: | ||
data = data[:: self.sample_every] | ||
elif data.ndim == 2: | ||
data = data[:, :: self.sample_every] | ||
else: | ||
raise ValueError(f"Don't know how to subsample data of shape {data.shape}.") | ||
if data.shape[-1] >= self.params["duration"] - self.startindt: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This catches unintended subsampling of BOLD. |
||
if data.ndim == 1: | ||
data = data[:: self.sample_every] | ||
elif data.ndim == 2: | ||
data = data[:, :: self.sample_every] | ||
else: | ||
raise ValueError(f"Don't know how to subsample data of shape {data.shape}.") | ||
|
||
def save_leaf(node, name, data, append): | ||
if name in node: | ||
if data.ndim == 1 and name == "t": | ||
# special treatment for time data: | ||
# increment the time by the last recorded duration | ||
data += node[name][-1] | ||
if append and data.shape[-1] != 0: | ||
data = np.hstack((node[name], data)) | ||
node[name] = data | ||
return node | ||
|
||
# if the output is a single name (not dot.separated) | ||
if "." not in name: | ||
# append data | ||
if append and name in self.outputs: | ||
# special treatment for time data: | ||
# increment the time by the last recorded duration | ||
if name == "t": | ||
data += self.outputs[name][-1] | ||
self.outputs[name] = np.hstack((self.outputs[name], data)) | ||
else: | ||
# save all data into output dict | ||
self.outputs[name] = data | ||
save_leaf(self.outputs, name, data, append) | ||
# set output as an attribute | ||
setattr(self, name, self.outputs[name]) | ||
else: | ||
|
@@ -503,18 +496,10 @@ def setOutput(self, name, data, append=False, removeICs=False): | |
for i, k in enumerate(keys): | ||
# if it's the last iteration, store data | ||
if i == len(keys) - 1: | ||
# TODO: this needs to be append-aware like above | ||
# if append: | ||
# if k == "t": | ||
# data += level[k][-1] | ||
# level[k] = np.hstack((level[k], data)) | ||
# else: | ||
# level[k] = data | ||
level[k] = data | ||
level = save_leaf(level, k, data, append) | ||
# if key is in outputs, then go deeper | ||
elif k in level: | ||
level = level[k] | ||
setattr(self, k, level) | ||
# if it's a new key, create new nested dictionary, set attribute, then go deeper | ||
else: | ||
level[k] = dotdict({}) | ||
|
@@ -604,11 +589,9 @@ def xr(self, group=""): | |
assert len(timeDictKey) > 0, f"No time array found (starting with t) in output group {group}." | ||
t = outputDict[timeDictKey].copy() | ||
del outputDict[timeDictKey] | ||
outputs = [] | ||
outputNames = [] | ||
for key, value in outputDict.items(): | ||
outputNames.append(key) | ||
outputs.append(value) | ||
|
||
outputNames, outputs = zip(*outputDict.items()) | ||
outputNames = list(outputNames) | ||
|
||
nNodes = outputs[0].shape[0] | ||
nodes = list(range(nNodes)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You seem to have the
append=True
case, why?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The functionality for saving outputs is already in
models/model.py
.