Skip to content

Commit

Permalink
Make predict accept NamedTuple and return a DataFrame (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkborregaard authored and nalimilan committed Jan 27, 2019
1 parent a28838f commit c725015
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions src/statsmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ end
"""
drop_intercept(::Type)
Define whether a given model automatically drops the intercept. Return `false` by default.
To specify that a model type `T` drops the intercept, overload this function for the
corresponding type: `drop_intercept(::Type{T}) = true`
Models that drop the intercept will be fitted without one: the intercept term will be
removed even if explicitly provided by the user. Categorical variables will be expanded
in the rank-reduced form (contrasts for `n` levels will only produce `n-1` columns).
"""
function drop_intercept(::Type)
    # Fallback for every type: the intercept is kept unless a model type opts in
    # by adding its own `drop_intercept(::Type{T}) = true` method.
    return false
end
Expand Down Expand Up @@ -89,6 +89,28 @@ StatsBase.adjr2(mm::DataFrameRegressionModel) = adjr2(mm.model)
# Variant-specific R² for a data-frame-backed regression model: delegate to the
# wrapped `mm.model`, passing `variant` through unchanged.
function StatsBase.r2(mm::DataFrameRegressionModel, variant::Symbol)
    return r2(mm.model, variant)
end

# Variant-specific adjusted R²: same delegation to the wrapped `mm.model`.
function StatsBase.adjr2(mm::DataFrameRegressionModel, variant::Symbol)
    return adjr2(mm.model, variant)
end

# Expand a vector of predictions `yp` to a vector of length `len`:
# positions selected by `nonmissings` receive the values of `yp` (in order),
# every other position stays `missing`.
function _return_predictions(yp::AbstractVector, nonmissings, len)
    padded = missings(eltype(yp), len)
    # Equivalent to `padded[nonmissings] = yp`.
    setindex!(padded, yp, nonmissings)
    return padded
end

# Expand a three-column prediction matrix `yp` to `len` rows and return it as a
# DataFrame with columns `prediction`, `lower` and `upper` (taken from columns
# 1, 2 and 3 respectively). Rows not selected by `nonmissings` are `missing`.
function _return_predictions(yp::AbstractMatrix, nonmissings, len)
    padded = missings(eltype(yp), (len, 3))
    padded[nonmissings, :] = yp
    prediction, lower, upper = padded[:, 1], padded[:, 2], padded[:, 3]
    return DataFrame(prediction = prediction, lower = lower, upper = upper)
end

# Expand predictions given as a NamedTuple with fields `:prediction`, `:lower`
# and `:upper` to `len` rows and return them as a DataFrame with those three
# columns. Positions selected by `nonmissings` receive the corresponding values
# of `yp`; every other position is `missing` in all three columns.
#
# All columns share the element type of `yp[:prediction]` (plus `Missing`).
function _return_predictions(yp::NamedTuple, nonmissings, len)
    T = eltype(yp[:prediction])
    # Every column must start out filled with `missing`. `similar` returns an
    # uninitialized array, so using it for `lower`/`upper` would leave the rows
    # outside `nonmissings` with arbitrary (or undefined) contents.
    out = (prediction = missings(T, len),
           lower = missings(T, len),
           upper = missings(T, len))
    for key in (:prediction, :lower, :upper)
        out[key][nonmissings] = yp[key]
    end
    return DataFrame(out)
end

# Predict function that takes data frame as predictor instead of matrix
function StatsBase.predict(mm::DataFrameRegressionModel{T}, df::AbstractDataFrame; kwargs...) where T
# copy terms, removing outcome if present (ModelFrame will complain if a
Expand All @@ -100,9 +122,7 @@ function StatsBase.predict(mm::DataFrameRegressionModel{T}, df::AbstractDataFram
drop_intercept(T) && (mf.terms.intercept = false)
newX = ModelMatrix(mf).m
yp = predict(mm, newX; kwargs...)
out = missings(eltype(yp), size(df, 1))
out[mf.nonmissing] = yp
return(out)
_return_predictions(yp, mf.nonmissing, size(df, 1))
end

StatsBase.coefnames(model::DataFrameModels) = coefnames(model.mf)
Expand Down

0 comments on commit c725015

Please sign in to comment.