39 refactor mljflux interface #92

Merged
merged 35 commits into from
Jul 16, 2024

Conversation

Member

@pasq-cat pasq-cat commented Jun 8, 2024

I have tried to use the new @mlj_model macro.
There are some choices that I am not sure are right.
First, the @mlj_model macro allows constraints to be defined directly in the struct, so there is no need for a clean! function, but if you prefer the older method I can go back.

Second, the MLJModelInterface quick guide says that the structs should contain only hyperparameters, so I removed the builder field, and the chain now has to be passed directly to fit!. This also seems logical, since a user may want to experiment with different Flux models without having to redefine the hyperparameters tied to the Laplace wrapper. However, I am not sure that adding an argument to the fit! function is the correct choice, since none of the example models do it.

Third, I removed the shape and build functions, since they are at most a generic utility for the user and not necessarily connected to the LaplaceRedux package.

Fourth, I have some doubts about the output of the predict function in the Laplace classification case. In the regression case I took the mean and variance provided by Laplace and used them to output a Gaussian distribution with Distributions.jl. In the classification case the MLJ interface says the output has to be a UnivariateFinite element, but the example provided points to a broken link, so I left the class pseudo-probabilities as the output.
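
For reference, a minimal sketch of what a UnivariateFinite prediction for a single observation could look like, assuming MLJBase is installed. The labels and probabilities are made-up illustration values, not output from LaplaceRedux, and this is only one way to build such an object.

```julia
using MLJBase  # assumption: MLJBase is available in the active environment

# Hypothetical class labels and pseudo-probabilities for one observation.
labels = ["no", "yes"]
probs = [0.1, 0.9]

# `pool=missing` asks the constructor to build a categorical pool
# from the raw (non-categorical) labels.
d = MLJBase.UnivariateFinite(labels, probs; pool=missing)

# Probability mass assigned to the class "yes".
p_yes = MLJBase.pdf(d, "yes")
```

In a real predict method the labels would come from the training target rather than being typed in by hand.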

It works (at least on my PC), but I am not sure whether it respects what MLJ wants, or why the automatic checks complain so much. Is it because I didn't add the Project.toml and Manifest.toml files?

@pasq-cat pasq-cat linked an issue Jun 8, 2024 that may be closed by this pull request
Contributor

@github-actions github-actions bot left a comment
@JuliaTrustworthyAI JuliaTrustworthyAI deleted a comment from github-actions bot Jun 9, 2024
@JuliaTrustworthyAI JuliaTrustworthyAI deleted a comment from github-actions bot Jun 9, 2024
@JuliaTrustworthyAI JuliaTrustworthyAI deleted a comment from github-actions bot Jun 9, 2024
Member

@pat-alt pat-alt left a comment

Let's try the following:

  • Use glm_predictive_distribution to return actual predictive posterior.
  • Let glm_predictive_distribution always return normal distribution.
  • Add Distributions.jl as dep.
  • In the predict(la::Laplace) function, to get the first two moments, simply call mean and var on the returned Distribution.
glm_predictive_distribution(la, X) |>
        dist -> (mean(dist), var(dist))
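
A small sketch of the last point, using only Distributions.jl. Here toy_predictive is a hypothetical stand-in for glm_predictive_distribution(la, X), which is assumed to return a Normal; the numbers are illustrative.

```julia
using Distributions

# Hypothetical stand-in for `glm_predictive_distribution(la, X)`:
# wraps a given mean and variance in a Normal distribution.
# Note that Normal takes the standard deviation, not the variance.
toy_predictive(μ::Real, σ²::Real) = Normal(μ, sqrt(σ²))

dist = toy_predictive(1.5, 4.0)

# First two moments, recovered from the distribution object.
moments = dist |> d -> (mean(d), var(d))
```

Calling std(dist) instead of var(dist) would return the standard deviation (2.0 here), so the two should not be mixed up downstream.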

Contributor

@github-actions github-actions bot left a comment
Member Author

pasq-cat commented Jun 12, 2024

@pat-alt I will not work on this anymore without some kind of guidance, because I am confused and tired. I tried to change it back to how it was before (adding the required fields as described here: https://github.com/FluxML/MLJFlux.jl/tree/01ad08ebc664f16d9509171866685c14d7bd6e99), but it doesn't work.

Member

pat-alt commented Jun 12, 2024

@pat-alt I will not work on this anymore without some kind of guidance, because I am confused and tired. I tried to change it back to how it was before (adding the required fields as described here: https://github.com/FluxML/MLJFlux.jl/tree/01ad08ebc664f16d9509171866685c14d7bd6e99), but it doesn't work.

OK! It looks like the package fails to precompile, so it's hard to tell if anything in the code itself is wrong. I would suggest the following:

  • Test locally with ] test in the Pkg REPL.
  • Fix the error (currently related to AbstractResource).

Then tackle the tasks here, as discussed.

Also, try to remember to sometimes apply the linter:

using JuliaFormatter
JuliaFormatter.format(".")

… solved one of the issue. the remaining two are still there
Member

pat-alt commented Jun 24, 2024

@MojiFarmanbar to see

src/mlj_flux.jl Outdated
@@ -426,7 +422,7 @@ function MLJFlux.fit!(

#return cache, report

- return (fitresult, report, cache)
+ return (fitresult, cache, report)
Member Author

@pat-alt @MojiFarmanbar I have changed the order as we discussed. This has a negative effect during testing: I now get (training_losses = (),) when I use the same order in mlj_flux_interfacing (fitresult, cache, _report = MLJBase.fit(model, 0, X, y)). The order seems correct (see https://github.com/FluxML/MLJFlux.jl/blob/07e01b7f41bc3dd870dd08c7c97114cc45f91f5e/src/mlj_model_interface.jl#L149).

Member

@MojiFarmanbar MojiFarmanbar Jun 25, 2024

@Rockdeldiablo, so the issue was the order of the returned values, right? Did you check whether the test cases pass? It seems to fail because of _report.

Member Author

@pasq-cat pasq-cat Jun 25, 2024

@MojiFarmanbar the contents do not correspond. cache is not empty even though I didn't use it at all. In third position I pass cache as (), but I get the training loss. In second position I pass the training loss through report, but I get a list of numbers that looks like weights. This messes with the update function, which does not work as it should.
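
The mix-up described above can be reproduced in plain Julia. This is a toy sketch, not MLJFlux code: toy_fit is a hypothetical function that only mimics the (fitresult, cache, report) return order, and it shows how destructuring in a different order silently swaps the contents.

```julia
# Hypothetical toy fit that follows the (fitresult, cache, report) ordering.
function toy_fit(X, y)
    fitresult = sum(y) / length(y)            # stand-in for a trained model
    cache = nothing                           # nothing cached in this sketch
    report = (training_losses = [0.5, 0.3],)  # made-up loss history
    return (fitresult, cache, report)
end

X = zeros(4, 2)
y = [1.0, 2.0, 3.0, 4.0]

# Destructuring in the same order as the return statement.
fitresult, cache, report = toy_fit(X, y)

# Destructuring in the old (fitresult, report, cache) order: no error is
# raised, but the names now point at the wrong objects.
_, wrong_report, wrong_cache = toy_fit(X, y)
```

Because tuples are destructured by position, the only safeguard is keeping the caller and the return statement in the same order.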

Member

@Rockdeldiablo, did you check that your environment is set up correctly? I suppose your package is in dev mode, so that you can see your changes immediately?

Member Author

@Rockdeldiablo, did you check that your environment is set up correctly? I suppose your package is in dev mode, so that you can see your changes immediately?

Yes. I have updated the env and moved everything to the new API, and now the issue with report is solved. There is, however, another issue in the classification task: the loss history grows instead of going down.

Member Author

@MojiFarmanbar I think it's better if I close this pull request, since it's a mess, and perhaps open a different one.

Member

@Rockdeldiablo, maybe we can keep this PR open and you can create a new PR. I understand that you may be doing a lot of experimentation, but make sure you don't lose track of your experiments; try to isolate them from each other, otherwise they might collide and get confusing.
What I would take away from the above:

  • We need to make sure the environment is updated with the changes in the code, otherwise we get inconsistent results.
  • Isolate the experiments.

Do you agree with me?

Member Author

@pasq-cat pasq-cat Jun 29, 2024

@Rockdeldiablo, maybe we can keep this PR open and you can create a new PR. I understand that you may be doing a lot of experimentation, but make sure you don't lose track of your experiments; try to isolate them from each other, otherwise they might collide and get confusing. What I would take away from the above:

  • We need to make sure the environment is updated with the changes in the code, otherwise we get inconsistent results.
  • Isolate the experiments.

Do you agree with me?

Yes, but I do experiments because it's not clear how I should proceed with MLJFlux.update. It needs the old_history to work, and the only way to pass it is through cache, but Patrick told me not to use it. Maybe some code reviews with some tips could help me find the errors quickly and move ahead. Besides, the original project was about adding conformalized Bayes support to ConformalPrediction.jl, not about adding MLJBase support to LaplaceRedux.

Member

I'll take a look at this tomorrow morning

Member

Besides, the original project was about adding conformalized Bayes support to ConformalPrediction.jl, not about adding MLJBase support to LaplaceRedux.

You're free to use any other library for Bayesian models, it doesn't need to be LaplaceRedux. But ConformalPrediction interfaces MLJ, that's a conscious design choice, so support for adding conformalized Bayes was always going to involve working with MLJ.

Member

pat-alt commented Jul 1, 2024

I think this is actually pretty close now 👍🏽 collecting a few observations below as I'm working on the code:

  • Issues seem to be mostly related to logging.

Edit: more specifically, it appears that calling update after increasing epochs by 3 reruns the model for epochs + 3 epochs, as opposed to just 3 more epochs. This is because MLJModelInterface.is_same_except(...) returns false (I'm investigating why exactly).

  • Loss sometimes skyrockets (not obvious why); investigate also for classification.
  • The chain object starts out as just the chain and then at some point gets turned into (chain, builder) (will double-check if this is expected behavior of MLJFlux).
julia> MLJBase.update(model, 2, chain, cache, X, y)
[ Info: Loss is 12.27
[ Info: Loss is 0.4794
[ Info: Loss is 0.4555
[ Info: Loss is 0.4437
[ Info: Loss is 2.295
[ Info: Loss is 0.455
[ Info: Loss is 0.4554
[ Info: Loss is 0.4562
[ Info: Loss is 170900.0
[ Info: Loss is 457.6
[ Info: Loss is 43.38
[ Info: Loss is 1.24
[ Info: Loss is 0.9924
[ Info: From train
Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1))
[ Info: From fitresult
Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1))
[ Info: From fitresult
(Chain(Dense(4 => 16, relu), Dense(16 => 8, relu), Dense(8 => 1)), LaplaceRegression(builder = MLP(hidden = (16, 8), ), ))
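
The warm-restart behaviour under discussion can be sketched in plain Julia. Everything here is hypothetical: ToyModel, same_except and epochs_to_run are illustration-only names that imitate the idea behind MLJModelInterface.is_same_except, not the actual MLJFlux logic.

```julia
# Hypothetical model struct with two hyperparameters.
struct ToyModel
    epochs::Int
    lambda::Float64
end

# Toy stand-in for the idea behind `is_same_except`: two models count as
# "the same" if every field not listed in `skip` compares equal.
same_except(m1, m2, skip::Symbol...) = all(
    getfield(m1, f) == getfield(m2, f)
    for f in fieldnames(typeof(m1)) if !(f in skip)
)

# Number of epochs an update would have to run in this sketch:
# only the extra epochs on a warm restart, all of them on a cold restart.
function epochs_to_run(old::ToyModel, new::ToyModel)
    warm = same_except(old, new, :epochs) && new.epochs >= old.epochs
    return warm ? new.epochs - old.epochs : new.epochs
end

old = ToyModel(10, 0.1)
warm_epochs = epochs_to_run(old, ToyModel(13, 0.1))  # only epochs changed
cold_epochs = epochs_to_run(old, ToyModel(13, 0.2))  # lambda changed too
```

In this sketch, any field other than epochs comparing unequal is enough to force the cold path, which mirrors the symptom of the model rerunning for all epochs.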

Member

pat-alt commented Jul 1, 2024

@Rockdeldiablo let me take over here for a moment, think it's just a few more minor issues that I can hopefully fix. Then that gives you an opportunity to finish the remaining tasks on #97 😃

P₀=model.P₀,
)
verbose_laplace = false
if !isa(chain, AbstractLaplace)
Member Author

@pat-alt why this check? When does chain turn into a Laplace object, so that this check is required?

Member

@pat-alt pat-alt Jul 1, 2024

Hey @Rockdeldiablo, please see my comment below, that should clarify this.

Edit: There might be a cleaner way to go about this that avoids the conditional, but I think I have convinced myself that returning the whole Laplace object is the way to go, because it contains the chain. The downside is that this may lead to unexpected behavior in the future if MLJFlux adds downstream functionality foo(chain) for the chain returned by train that the Laplace object does not support. But since it contains the chain, we can always address such cases swiftly via foo(la) = foo(la.model) (where la.model is the chain).

Member

pat-alt commented Jul 1, 2024

@Rockdeldiablo I've added the following changes now to make things work without having to overload the update method or changing it in MLJFlux:

  • The train method now returns la, optimiser_state, history where la is the Laplace object. This way, the object does not need to be stored as a field of the struct and the problem with update is avoided.
  • To facilitate this change, calling a Laplace object on an array, (la::AbstractLaplace)(X::AbstractArray) now simply calls the underlying neural network on data. In other words, it returns the generic predictions, not LA predictions.
  • The fitresult method was adjusted also for the classification case.
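
A minimal sketch of the wrapper pattern described above, in plain Julia (1.3 or later, since it adds a call method to an abstract type). ToyLaplace, AbstractToyLaplace and describe are hypothetical names, and the "network" is just a function standing in for a Flux chain.

```julia
# Hypothetical stand-ins for AbstractLaplace and a concrete Laplace type.
abstract type AbstractToyLaplace end

struct ToyLaplace{M} <: AbstractToyLaplace
    model::M   # the wrapped network (here: any callable)
end

# Calling the wrapper on an array forwards to the underlying network,
# i.e. it returns the plain network output, not a Laplace prediction.
(la::AbstractToyLaplace)(X::AbstractArray) = la.model(X)

# Hypothetical downstream function that only knows about plain networks...
describe(net::Function) = "plain network"
# ...extended to the wrapper by forwarding to the wrapped model.
describe(la::AbstractToyLaplace) = describe(la.model)

net = x -> 2 .* x        # toy "network"
la = ToyLaplace(net)

out = la([1.0, 2.0])     # same as net([1.0, 2.0])
```

The forwarding method is the foo(la) = foo(la.model) idea mentioned earlier in the thread: one line per downstream function that the wrapper needs to support.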

Now that tests are passing, there are a few more things to do (possibly in a new issue + PR) if you like.

  • Add a (short) tutorial to the documentation.
  • Double-check if the code in src/mlj_flux.jl can be streamlined further (e.g. do we actually still need to overload MLJFlux.build?).

For now, feel free to focus on the other PR, just ping me and @MojiFarmanbar when you come back to this one. I need to move on to other things for now.

Member

pat-alt commented Jul 10, 2024

Let's just move the pending tasks above into new issues and then merge this one.

Member Author

pasq-cat commented Jul 16, 2024

Let's just move the pending tasks above into new issues and then merge this one.

Ahh, I just saw this message. How do I move the pending tasks into new issues?

@pasq-cat pasq-cat mentioned this pull request Jul 16, 2024
2 tasks
@pasq-cat pasq-cat marked this pull request as ready for review July 16, 2024 07:11
@pat-alt pat-alt merged commit 3e1531d into main Jul 16, 2024
10 of 11 checks passed
@pasq-cat pasq-cat deleted the 39-refactor-mljflux-interface branch July 16, 2024 19:51
Development

Successfully merging this pull request may close these issues.

Refactor MLJFlux interface
3 participants