Skip to content

Commit

Permalink
Add prediction type to return the mean, variance, and (if implemented…
Browse files Browse the repository at this point in the history
…) mode

Additionally, this change avoids unnecessary sampling if the prediction type doesn't need it.
  • Loading branch information
neverfox committed Jun 28, 2023
1 parent a99ae90 commit 0bf9342
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions lightgbmlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def predict_dist(self,
- "quantile" calculates the quantiles from the predicted distribution.
- "parameters" returns the predicted distributional parameters.
- "expectiles" returns the predicted expectiles.
- "properties" returns the mean, variance, and (if implemented) mode.
n_samples : int
Number of samples to draw from the predicted distribution.
quantiles : List[float]
Expand Down Expand Up @@ -402,18 +403,35 @@ def predict_dist(self,
dist_params_predt = pd.DataFrame(dist_params_predt)
dist_params_predt.columns = self.param_dict.keys()

# Draw samples from predicted response distribution
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
n_samples=n_samples,
seed=seed)

if pred_type == "parameters":
return dist_params_predt

elif pred_type == "expectiles":
return dist_params_predt

elif pred_type == "properties":
if self.tau is None:
pred_params = torch.tensor(dist_params_predt.values)
dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)}
dist_pred = self.distribution(**dist_kwargs)
pred_props = pd.DataFrame({"mean": dist_pred.mean.detach().numpy(),
"variance": dist_pred.variance.detach().numpy()})
try:
dist_pred.mode
except NotImplementedError:
pass
else:
pred_props["mode"] = dist_pred.mode.detach().numpy()
return pred_props
else:
raise ValueError("Invalid prediction type.")

# Draw samples from predicted response distribution
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
n_samples=n_samples,
seed=seed)

elif pred_type == "samples":
if pred_type == "samples":
return pred_samples_df

elif pred_type == "quantiles":
Expand Down

0 comments on commit 0bf9342

Please sign in to comment.