From 0bb7a3559d604b878c298d52bc5019478fe95ce7 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 23 Dec 2023 09:39:38 -0500 Subject: [PATCH 1/8] Add initial optimisers.jl scheduler --- Project.toml | 4 +- src/ParameterSchedulers.jl | 83 ++------------------------------------ src/scheduler.jl | 56 +++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 81 deletions(-) create mode 100644 src/scheduler.jl diff --git a/Project.toml b/Project.toml index 19f3119..c6addb2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,13 +4,13 @@ authors = ["Kyle Daruwalla"] version = "0.3.7" [deps] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" [compat] -Flux = "0.11.2, 0.12, 0.13, 0.14" InfiniteArrays = "0.10.4, 0.11, 0.12, 0.13" julia = "1.6" +Optimisers = "0.3.1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/ParameterSchedulers.jl b/src/ParameterSchedulers.jl index 9a15822..9f2094d 100644 --- a/src/ParameterSchedulers.jl +++ b/src/ParameterSchedulers.jl @@ -1,8 +1,9 @@ module ParameterSchedulers using Base.Iterators -using Flux using InfiniteArrays: OneToInf +using Optimisers: AbstractRule +import Optimisers include("interface.jl") @@ -19,83 +20,7 @@ export Sequence, Loop, Interpolator, Shifted, ComposedSchedule include("utils.jl") -# TODO -# Remove this once Optimisers.jl has support -# for schedules + optimizers -""" - Scheduler{T, O, F}(schedule::AbstractSchedule, opt, update_func) - Scheduler(schedule, opt; update_func = (o, s) -> (o.eta = s)) - -Wrap a `schedule` and `opt` together with a `Scheduler`. -The `schedule` is iterated on every call to -[`Flux.apply!`](https://github.com/FluxML/Flux.jl/blob/master/src/optimise/optimisers.jl). -The `Scheduler` can be used anywhere a Flux optimizer is used. - -By default, the learning rate (i.e. `opt.eta`) is scheduled. -Set `update_func = (opt, schedule_val) -> ...` to schedule an alternate field. -If `opt` does not have a field `eta`, then there is no default behavior -(you must manually set `update_func`). 
-
-# Arguments
-- `schedule`: the schedule to use
-- `opt`: a Flux optimizer
-- `update_func`: a mutating function of with inputs `(optim, param)`
-    that mutates `optim`'s fields based on the current `param` value
-
-# Examples
-```julia
-# cosine annealing schedule for Descent
-julia> s = CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10);
-
-julia> opt = Scheduler(s, Descent())
-Scheduler(CosAnneal{Float64,Int64}(0.1, 0.8, 10), Descent(0.1))
-
-# schedule the momentum term of Momentum
-julia> opt = Scheduler(s, Momentum(); update_func = (o, s) -> o.rho = s)
-Scheduler(CosAnneal{Float64,Int64}(0.1, 0.8, 10), Momentum(0.01, 0.9, IdDict{Any,Any}()))
-```
-"""
-mutable struct Scheduler{T, O, F} <: Flux.Optimise.AbstractOptimiser
-    state::IdDict{Any, Int}
-    schedule::T
-    optim::O
-    update_func::F
-
-    function Scheduler(state::IdDict{Any, Int},
-                       schedule::T,
-                       optim::O,
-                       update_func::F) where {T, O, F}
-        Base.depwarn("""`Scheduler` will transition to explicit Optimisers.jl style
-                        optimizers in the next release""", :Scheduler)
-
-        return new{T, O, F}(state, schedule, optim, update_func)
-    end
-end
-Scheduler(schedule, opt, update_func) =
-    Scheduler(IdDict{Any, Int}(), schedule, opt, update_func)
-
-Base.show(io::IO, s::Scheduler) =
-    print(io, "Scheduler(", s.schedule, ", ", s.optim, ")")
-
-function Flux.Optimise.apply!(opt::Scheduler, x, Δ)
-    # get iteration
-    t = get!(opt.state, x, 1)
-    opt.state[x] = t + 1
-
-    # set param
-    opt.update_func(opt.optim, opt.schedule(t))
-
-    # do normal apply
-    return Flux.Optimise.apply!(opt.optim, x, Δ)
-end
-
-for Opt in (Descent, Momentum, Nesterov, RMSProp,
-            Adam, RAdam, AdaMax, OAdam, AdaGrad,
-            AdaDelta, AMSGrad, NAdam, AdaBelief)
-    @eval begin
-        Scheduler(schedule, opt::$Opt; update_func = (o, s) -> (o.eta = s)) =
-            Scheduler(schedule, opt, update_func)
-    end
-end
+include("scheduler.jl")
+export Scheduler
 
 end
\ No newline at end of file
diff --git a/src/scheduler.jl b/src/scheduler.jl
new file mode 100644
index 0000000..fc6a6a1
--- /dev/null
+++ b/src/scheduler.jl
@@ -0,0 +1,56 @@
+"""
+    Scheduler{T, F} <: Optimisers.AbstractRule
+    Scheduler(constructor, schedules::AbstractSchedule...)
+    Scheduler(constructor; field_a = schedule_a, field_b = schedule_b, ...)
+
+Wrap one or more schedules and optimizer together with a `Scheduler`.
+On each call to [`Optimisers.apply!`](@ref Optimisers.apply!), the schedules
+are iterated and `constructor` is used to invoke an optimization rule with
+updated parameters.
+The `Scheduler` can be used anywhere an Optimisers.jl optimizer is used.
+
+If passed a single schedule and optimizer rule, the scheduler updates the
+learning rate, `opt.eta`.
+To adjust multiple hyperparameters, pass in multiple schedules as arguments or
+keywords. These will be iterated in order and passed on to `constructor`
+(i.e. `constructor` should accept the appropriate number of arguments/keywords).
+
+# Arguments
+- `constructor`: a constructor that creates an optimization rule given some
+  parameters (e.g. 
`Optimisers.AdamW`; note the lack of `()`) +- `schedules`: the list of optimization rule hyperparameters to schedule as + multiple (named) arguments + +# Examples +```julia +# cosine annealing schedule for Descent +julia> opt = Scheduler(Descent, CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10)); + +# schedule learning rate and momentum of Momentum +julia> opt = Scheduler(Momentum, CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10), Exp(0.999, 0.8)); + +# schedule the weight decay term of AdamW +julia> opt = Scheduler(AdamW, decay = Exp(1e-3, 0.7)); +``` +""" +struct Scheduler{T<:Union{<:Tuple, <:NamedTuple}, F} <: AbstractRule + constructor::F + schedules::T +end +Scheduler(constructor, schedules...) = Scheduler(constructor, schedules) +Scheduler(constructor; schedules...) = Scheduler(constructor, schedules) + +_get_opt(scheduler::Scheduler{<:Tuple}, t) = + scheduler.constructor((s(t) for s in schedules)...) +_get_opt(scheduler::Scheduler{<:NamedTuple}, t) = + scheduler.constructor(NamedTuple{keys(schedules)}(s(t) for s in schedules)...) + +Optimisers.init(o::Scheduler, x::AbstractArray) = + (t = 1, opt = Optimisers.init(_get_opt(o, 1), x)) + +function Optimisers.apply!(o::Scheduler, state, x, dx) + opt = _get_opt(o, state.t) + new_state, new_dx = Optimisers.apply!(opt, state.opt, x, dx) + + return (t = state.t + 1, opt = new_state), new_dx +end From d220801a7d36eac92c509bdd3b54f3ff8e917c21 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 23 Dec 2023 18:19:32 -0500 Subject: [PATCH 2/8] Add tests and remove dep warnings --- Project.toml | 5 +++-- src/cyclic.jl | 50 ++++++------------------------------------------ src/scheduler.jl | 11 +++++++---- test/runtests.jl | 38 +++++++++++++++++++++++++----------- 4 files changed, 43 insertions(+), 61 deletions(-) diff --git a/Project.toml b/Project.toml index c6addb2..248f408 100644 --- a/Project.toml +++ b/Project.toml @@ -9,10 +9,11 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" [compat] InfiniteArrays = "0.10.4, 0.11, 0.12, 0.13" -julia = "1.6" Optimisers = "0.3.1" +julia = "1.6" [extras] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [publish] @@ -21,4 +22,4 @@ theme = "_flux-theme" title = "ParameterSchedulers.jl" [targets] -test = ["Test"] +test = ["Flux", "Test"] diff --git a/src/cyclic.jl b/src/cyclic.jl index 44217ee..704a983 100644 --- a/src/cyclic.jl +++ b/src/cyclic.jl @@ -23,13 +23,8 @@ struct Triangle{T, S<:Integer} <: AbstractSchedule{false} offset::T period::S end -function Triangle(range::T, offset::T, period::S) where {T, S} - @warn """Triangle(range0, range1, period) is now Triangle(range, offset, period). - To specify by endpoints, use the keyword argument form. - This message will be removed in the next version.""" _id=(:tri) maxlog=1 - +Triangle(range::T, offset::T, period::S) where {T, S} = Triangle{T, S}(range, offset, period) -end Triangle(;λ0, λ1, period) = Triangle(abs(λ0 - λ1), min(λ0, λ1), period) Base.eltype(::Type{<:Triangle{T}}) where T = T @@ -54,13 +49,7 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)) - `range1`/`λ1`: the second range endpoint - `period::Integer`: the period """ -function TriangleDecay2(range, offset, period) - @warn """TriangleDecay2(range0, range1, period) is now TriangleDecay2(range, offset, period). - To specify by endpoints, use the keyword argument form. 
- This message will be removed in the next version.""" _id=(:tri) maxlog=1 - - return _tridecay2(range, offset, period) -end +TriangleDecay2(range, offset, period) = _tridecay2(range, offset, period) TriangleDecay2(;λ0, λ1, period) = _tridecay2(abs(λ0 - λ1), min(λ0, λ1), period) function _tridecay2(range::T, offset, period) where T @@ -89,13 +78,7 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)) - `period::Integer`: the period - `decay`/`γ`: the decay rate """ -function TriangleExp(range, offset, period, γ) - @warn """TriangleExp(range0, range1, period, γ) is now TriangleExp(range, offset, period, γ). - To specify by endpoints, use the keyword argument form. - This message will be removed in the next version.""" _id=(:tri) maxlog=1 - - return _triexp(range, offset, period, γ) -end +TriangleExp(range, offset, period, γ) = _triexp(range, offset, period, γ) TriangleExp(;λ0, λ1, period, γ) = _triexp(abs(λ0 - λ1), min(λ0, λ1), period, γ) _triexp(range, offset, period, γ) = @@ -121,13 +104,7 @@ struct Sin{T, S<:Integer} <: AbstractSchedule{false} offset::T period::S end -function Sin(range::T, offset::T, period::S) where {T, S} - @warn """Sin(range0, range1, period) is now Sin(range, offset, period). - To specify by endpoints, use the keyword argument form. - This message will be removed in the next version.""" _id=(:sine) maxlog=1 - - Sin{T, S}(range, offset, period) -end +Sin(range::T, offset::T, period::S) where {T, S} = Sin{T, S}(range, offset, period) Sin(;λ0, λ1, period) = Sin(abs(λ0 - λ1), min(λ0, λ1), period) Base.eltype(::Type{<:Sin{T}}) where T = T @@ -150,13 +127,7 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)). - `offset == min(λ0, λ1)`: the offset / minimum value - `period::Integer`: the period """ -function SinDecay2(range, offset, period) - @warn """SinDecay2(range0, range1, period) is now SinDecay2(range, offset, period). - To specify by endpoints, use the keyword argument form. - This message will be removed in the next version.""" _id=(:sine) maxlog=1 - - return _sindecay2(range, offset, period) -end +SinDecay2(range, offset, period) = _sindecay2(range, offset, period) SinDecay2(;λ0, λ1, period) = _sindecay2(abs(λ0 - λ1), min(λ0, λ1), period) function _sindecay2(range::T, offset, period) where T @@ -182,13 +153,7 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)). - `period::Integer`: the period - `γ`: the decay rate """ -function SinExp(range, offset, period, γ) - @warn """SinExp(range0, range1, period, γ) is now SinExp(range, offset, period, γ). - To specify by endpoints, use the keyword argument form. - This message will be removed in the next version.""" _id=(:sine) maxlog=1 - - return _sinexp(range, offset, period, γ) -end +SinExp(range, offset, period, γ) = _sinexp(range, offset, period, γ) SinExp(;λ0, λ1, period, γ) = _sinexp(abs(λ0 - λ1), min(λ0, λ1), period, γ) _sinexp(range, offset, period, γ) = @@ -231,6 +196,3 @@ function (schedule::CosAnneal)(t) return schedule.range * (1 + cos(π * t̂ / schedule.period)) / 2 + schedule.offset end - -Base.@deprecate Cos(range0, range1, period) CosAnneal(λ0 = range0, λ1 = range1, period = period) -Base.@deprecate Cos(;λ0, λ1, period) CosAnneal(λ0 = λ0, λ1 = λ1, period = period) diff --git a/src/scheduler.jl b/src/scheduler.jl index fc6a6a1..3493819 100644 --- a/src/scheduler.jl +++ b/src/scheduler.jl @@ -38,12 +38,15 @@ struct Scheduler{T<:Union{<:Tuple, <:NamedTuple}, F} <: AbstractRule schedules::T end Scheduler(constructor, schedules...) 
= Scheduler(constructor, schedules) -Scheduler(constructor; schedules...) = Scheduler(constructor, schedules) +Scheduler(constructor; schedules...) = Scheduler(constructor, (; schedules...)) _get_opt(scheduler::Scheduler{<:Tuple}, t) = - scheduler.constructor((s(t) for s in schedules)...) -_get_opt(scheduler::Scheduler{<:NamedTuple}, t) = - scheduler.constructor(NamedTuple{keys(schedules)}(s(t) for s in schedules)...) + scheduler.constructor((s(t) for s in scheduler.schedules)...) +function _get_opt(scheduler::Scheduler{<:NamedTuple}, t) + kwargs = NamedTuple{keys(scheduler.schedules)}(s(t) for s in scheduler.schedules) + + return scheduler.constructor(kwargs...) +end Optimisers.init(o::Scheduler, x::AbstractArray) = (t = 1, opt = Optimisers.init(_get_opt(o, 1), x)) diff --git a/test/runtests.jl b/test/runtests.jl index c8ccd9d..4c7fe0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using ParameterSchedulers using Flux +using Optimisers using Test using InfiniteArrays: OneToInf @@ -15,17 +16,32 @@ end include("complex.jl") end @testset "Scheduler" begin - using ParameterSchedulers: Scheduler - m = Chain(Dense(10, 5), Dense(5, 2)) - ps = Flux.params(m) - s = Exp(0.1, 0.5) - o = Scheduler(s, Momentum()) - for t in 1:10 - g = Flux.gradient(() -> sum(m(rand(Float32, 10, 2))), ps) - Flux.update!(o, ps, g) - @test o.optim.eta == s(t) - for p in ps - @test o.state[p] == t + 1 + @testset "Basic usage" begin + m = (W = ones(Float32, 4, 3), b = ones(Float32, 4)) + s = Exp(0.1, 0.5) + o = Flux.setup(Scheduler(Optimisers.Descent, s), m) + x = ones(Float32, 3) + for t in 1:10 + g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] + o, m′ = Optimisers.update(o, m, g) + @test m′.W ≈ m.W - g.W * s(t) + @test m′.b ≈ m.b - g.b * s(t) + m = m′ + end + end + @testset "Advanced usage" begin + m = (W = ones(Float32, 4, 3), b = ones(Float32, 4)) + seta = Exp(0.1, 0.5) + srho = Exp(0.9, 0.9) + o = Flux.setup(Scheduler(Optimisers.Momentum, eta = seta, rho = srho), m) + x = ones(Float32, 3) + for t in 1:10 + g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] + o′, m′ = Optimisers.update(o, m, g) + @test m′.W ≈ m.W - (srho(t) * o.W.state.opt + g.W * seta(t)) + @test m′.b ≈ m.b - (srho(t) * o.b.state.opt + g.b * seta(t)) + m = m′ + o = o′ end end end From decfc6d8e71cfd57d6ef46732322c633ac36a00f Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 23 Dec 2023 18:20:38 -0500 Subject: [PATCH 3/8] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 248f408..0dd984a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterSchedulers" uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" authors = ["Kyle Daruwalla"] -version = "0.3.7" +version = "0.4.0" [deps] InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c" From 07ed093e73233183579aac8a5712f5018c77b139 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 23 Dec 2023 18:38:37 -0500 Subject: [PATCH 4/8] Use Optimisers.setup instead --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4c7fe0b..ba67fde 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ end @testset "Basic usage" begin m = (W = ones(Float32, 4, 3), b = ones(Float32, 4)) s = Exp(0.1, 0.5) - o = Flux.setup(Scheduler(Optimisers.Descent, s), m) + o = Optimisers.setup(Scheduler(Optimisers.Descent, s), m) x = ones(Float32, 3) for t in 1:10 g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] @@ 
-33,7 +33,7 @@ end m = (W = ones(Float32, 4, 3), b = ones(Float32, 4)) seta = Exp(0.1, 0.5) srho = Exp(0.9, 0.9) - o = Flux.setup(Scheduler(Optimisers.Momentum, eta = seta, rho = srho), m) + o = Optimisers.setup(Scheduler(Optimisers.Momentum, eta = seta, rho = srho), m) x = ones(Float32, 3) for t in 1:10 g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] From f1f258cd324fe6bf143ef74d703fdb829d90945f Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 23 Dec 2023 18:45:01 -0500 Subject: [PATCH 5/8] Swap Flux for Zygote in tests --- Project.toml | 4 ++-- test/runtests.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 0dd984a..c0528f3 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,8 @@ Optimisers = "0.3.1" julia = "1.6" [extras] -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [publish] ignore = ["^(gh-pages|juliamnt|julia.dmg)$"] @@ -22,4 +22,4 @@ theme = "_flux-theme" title = "ParameterSchedulers.jl" [targets] -test = ["Flux", "Test"] +test = ["Test", "Zygote"] diff --git a/test/runtests.jl b/test/runtests.jl index ba67fde..e62bcea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using ParameterSchedulers -using Flux +using Zygote using Optimisers using Test @@ -22,7 +22,7 @@ end o = Optimisers.setup(Scheduler(Optimisers.Descent, s), m) x = ones(Float32, 3) for t in 1:10 - g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] + g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1] o, m′ = Optimisers.update(o, m, g) @test m′.W ≈ m.W - g.W * s(t) @test m′.b ≈ m.b - g.b * s(t) @@ -36,7 +36,7 @@ end o = Optimisers.setup(Scheduler(Optimisers.Momentum, eta = seta, rho = srho), m) x = ones(Float32, 3) for t in 1:10 - g = Flux.gradient(m -> sum(m.W * x + m.b), m)[1] + g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1] o′, m′ = Optimisers.update(o, m, g) @test m′.W ≈ m.W - (srho(t) * o.W.state.opt + g.W * seta(t)) @test m′.b ≈ m.b - (srho(t) * o.b.state.opt + g.b * seta(t)) From 9f98bce2db317f8eeb0be058d58db29226b506ea Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 18:47:09 -0500 Subject: [PATCH 6/8] Add docs for new scheduler --- README.md | 2 +- docs/Project.toml | 1 + docs/src/tutorials/complex-schedules.md | 17 ++++--- docs/src/tutorials/optimizers.md | 61 +++++++++++++------------ src/scheduler.jl | 6 +-- 5 files changed, 48 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 02cd0b5..bd0744a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ParameterSchedulers.jl provides common machine learning (ML) schedulers for hype using Flux, ParameterSchedulers using ParameterSchedulers: Scheduler -opt = Scheduler(Exp(λ = 1e-2, γ = 0.8), Momentum()) +opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8)) ``` ## Available Schedules diff --git a/docs/Project.toml b/docs/Project.toml index c7e134b..3a3cc1e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" diff --git a/docs/src/tutorials/complex-schedules.md b/docs/src/tutorials/complex-schedules.md index 163718b..ac35e35 100644 --- a/docs/src/tutorials/complex-schedules.md +++ 
b/docs/src/tutorials/complex-schedules.md @@ -78,19 +78,22 @@ Notice that our schedule changes around 1 second (half way through the simulatio For the second example, we'll look at a machine learning use-case. We want to write our schedule in terms of epochs, but our training loop iterates the scheduler every mini-batch. ```@example complex-schedules using Flux +using Optimisers using ParameterSchedulers: Scheduler nepochs = 3 -data = [(rand(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3] +data = [(Flux.rand32(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3] m = Chain(Dense(4, 4, tanh), Dense(4, 1, tanh)) -p = Flux.params(m) -s = Interpolator(Sequence(1e-2 => 1, Exp(1e-2, 2.0) => 2), length(data)) -opt = Scheduler(s, Descent()) +s = Interpolator(Sequence(1f-2 => 1, Exp(1f-2, 2f0) => 2), length(data)) +opt = Scheduler(Optimisers.Descent, s) +opt_st = Flux.setup(opt, m) for epoch in 1:nepochs for (i, (x, y)) in enumerate(data) - g = Flux.gradient(() -> Flux.mse(m(x), y), p) - Flux.update!(opt, p, g) - println("epoch: $epoch, batch: $i, η: $(opt.optim.eta)") + global opt_st, m + step = opt_st.layers[1].weight.state.t + println("epoch: $epoch, batch: $i, sched step = $step") + g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1] + opt_st, m = Flux.update!(opt_st, m, g) end end ``` diff --git a/docs/src/tutorials/optimizers.md b/docs/src/tutorials/optimizers.md index f5e51c5..ecac15a 100644 --- a/docs/src/tutorials/optimizers.md +++ b/docs/src/tutorials/optimizers.md @@ -7,18 +7,20 @@ A schedule by itself is not helpful; we need to use the schedules to adjust para Since every schedule is a standard iterator, we can insert it into a training loop by simply zipping up with another iterator. For example, the following code adjusts the learning rate of the optimizer before each batch of training. ```@example optimizers using Flux, ParameterSchedulers +using Optimisers: Descent, adjust! data = [(Flux.rand32(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3] m = Chain(Dense(4, 4, tanh), Dense(4, 1, tanh)) -p = Flux.params(m) opt = Descent() +opt_st = Flux.setup(opt, m) s = Exp(λ = 1e-1, γ = 0.2) -for (η, (x, y)) in zip(s, data) - opt.eta = η - g = Flux.gradient(() -> Flux.mse(m(x), y), p) - Flux.update!(opt, p, g) - println("η: ", opt.eta) +for (eta, (x, y)) in zip(s, data) + global opt_st, m + adjust!(opt_st, eta) + g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1] + opt_st, m = Flux.update!(opt_st, m, g) + println("opt state: ", opt_st.layers[1].weight.rule) end ``` @@ -26,12 +28,14 @@ We can also adjust the learning on an epoch basis instead. 
All that is required ```@example optimizers nepochs = 6 s = Step(λ = 1e-1, γ = 0.2, step_sizes = [3, 2, 1]) -for (η, epoch) in zip(s, 1:nepochs) - opt.eta = η +for (eta, epoch) in zip(s, 1:nepochs) + global opt_st + adjust!(opt_st, eta) for (i, (x, y)) in enumerate(data) - g = Flux.gradient(() -> Flux.mse(m(x), y), p) - Flux.update!(opt, p, g) - println("epoch: $epoch, batch: $i, η: $(opt.eta)") + global m + g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1] + opt_st, m = Flux.update!(opt_st, m, g) + println("epoch: $epoch, batch: $i, opt state: $(opt_st.layers[1].weight.rule)") end end ``` @@ -45,44 +49,45 @@ nepochs = 3 s = ParameterSchedulers.Stateful(Inv(λ = 1e-1, γ = 0.2, p = 2)) for epoch in 1:nepochs for (i, (x, y)) in enumerate(data) - opt.eta = ParameterSchedulers.next!(s) - g = Flux.gradient(() -> Flux.mse(m(x), y), p) - Flux.update!(opt, p, g) - println("epoch: $epoch, batch: $i, η: $(opt.eta)") + global opt_st, m + adjust!(opt_st, ParameterSchedulers.next!(s)) + g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1] + opt_st, m = Flux.update!(opt_st, m, g) + println("epoch: $epoch, batch: $i, opt state: $(opt_st.layers[1].weight.rule)") end end ``` ## Working with Flux optimizers -!!! warning - Currently, we are porting `Scheduler` to Flux.jl. - It may be renamed once it is ported out of this package. - The API will also undergo minor changes. - While the approaches above can be helpful when dealing with fine-grained training loops, it is usually simpler to just use a [`ParameterSchedulers.Scheduler`](@ref). ```@example optimizers using ParameterSchedulers: Scheduler nepochs = 3 s = Inv(λ = 1e-1, p = 2, γ = 0.2) -opt = Scheduler(s, Descent()) +opt = Scheduler(Descent, s) +opt_st = Flux.setup(opt, m) for epoch in 1:nepochs for (i, (x, y)) in enumerate(data) - g = Flux.gradient(() -> Flux.mse(m(x), y), p) - Flux.update!(opt, p, g) - println("epoch: $epoch, batch: $i, η: $(opt.optim.eta)") + global opt_st, m + sched_step = opt_st.layers[1].weight.state.t + println("epoch: $epoch, batch: $i, sched state: $sched_step") + g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1] + opt_st, m = Flux.update!(opt_st, m, g) end end ``` The scheduler, `opt`, can be used anywhere a Flux optimizer can. For example, it can be passed to `Flux.train!`: ```@example optimizers s = Inv(λ = 1e-1, p = 2, γ = 0.2) -opt = Scheduler(s, Descent()) -loss(x, y, m) = Flux.mse(m(x), y) -cb = () -> @show(opt.optim.eta) +opt = Scheduler(Descent, s) +opt_st = Flux.setup(opt, m) +loss(m, x, y) = Flux.mse(m(x), y) for epoch in 1:nepochs - Flux.train!((x, y) -> loss(x, y, m), Flux.params(m), data, opt, cb = cb) + sched_step = opt_st.layers[1].weight.state.t + println("epoch: $epoch, sched state: $sched_step") + Flux.train!(loss, m, data, opt_st) end ``` diff --git a/src/scheduler.jl b/src/scheduler.jl index 3493819..2375b31 100644 --- a/src/scheduler.jl +++ b/src/scheduler.jl @@ -4,9 +4,9 @@ Scheduler(constructor; field_a = schedule_a, field_b = schedule_b, ...) Wrap one or more schedules and optimizer together with a `Scheduler`. -On each call to [`Optimisers.apply!`](@ref Optimisers.apply!), the schedules -are iterated and `constructor` is used to invoke an optimization rule with -updated parameters. +On each call to [`Optimisers.apply!`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.apply!), +the schedules are iterated and `constructor` is used to invoke an +optimization rule with updated parameters. The `Scheduler` can be used anywhere an Optimisers.jl optimizer is used. 
If passed a single schedule and optimizer rule, the scheduler updates the From 57318b25ddfa2d11dd7a16c151756e30c7db43c3 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 18:50:44 -0500 Subject: [PATCH 7/8] Remove Revise from docs/make.jl --- docs/make.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index cd4b215..fbb767b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,3 @@ -using Revise using Documenter, ParameterSchedulers using Markdown From 112f577e5e5744bf1657fbad16f22ae8e90cc703 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 2 Feb 2024 18:57:00 -0500 Subject: [PATCH 8/8] Use newer Julia for docs --- .github/workflows/ci.yml | 2 +- docs/make.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 656e1ec..dce7a61 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,7 +58,7 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: '1.6' + version: '1' - run: | julia --project=docs -e ' using Pkg diff --git a/docs/make.jl b/docs/make.jl index fbb767b..cd4b215 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,4 @@ +using Revise using Documenter, ParameterSchedulers using Markdown
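
Taken together, this series replaces the mutable Flux-style `Scheduler` with a stateless `Optimisers.AbstractRule` that rebuilds the wrapped rule from its schedules on every optimization step. The sketch below is not part of the diff; it is a minimal usage example of the new API under the assumptions pinned above (ParameterSchedulers 0.4 with this series applied, Optimisers 0.3, Zygote for gradients), using the toy NamedTuple "model" from the new tests.

```julia
# Minimal sketch of the Optimisers.jl-style Scheduler introduced by this series.
# Assumes ParameterSchedulers 0.4 (this PR), Optimisers 0.3, and Zygote.
using ParameterSchedulers, Optimisers, Zygote

# Toy "model" mirroring the new tests: a NamedTuple of arrays.
m = (W = ones(Float32, 4, 3), b = ones(Float32, 4))
x = ones(Float32, 3)

# Schedule a single hyperparameter: the learning rate of Descent.
# On step t the scheduler rebuilds the rule as Descent(s(t)).
s = Exp(0.1, 0.5)
opt_st = Optimisers.setup(Scheduler(Optimisers.Descent, s), m)

for t in 1:5
    global opt_st, m
    g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]   # toy loss
    opt_st, m = Optimisers.update(opt_st, m, g)
end

# Schedule several hyperparameters by keyword: eta and rho of Momentum.
opt_st = Optimisers.setup(
    Scheduler(Optimisers.Momentum, eta = Exp(0.1, 0.5), rho = Exp(0.9, 0.9)), m)
```

Because `_get_opt` reconstructs the rule from `constructor` at each step, any positional or keyword hyperparameter accepted by the rule's constructor can be scheduled this way.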