From e09d181224da52fd6f82d228a8bcdffba58564d8 Mon Sep 17 00:00:00 2001 From: Tom Lau Date: Fri, 23 Aug 2024 09:54:24 +0800 Subject: [PATCH 1/2] fix: improve function type narrow by checking params' literal identical --- changelog.md | 1 + script/vm/function.lua | 59 ++++++++++++++++++++++++++++++++++-------- script/vm/value.lua | 7 +++-- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/changelog.md b/changelog.md index 2f0ea8f72..2a35ff3d2 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +* `FIX` Improve type narrow by checking exact match on literal type params ## 3.10.5 `2024-8-19` diff --git a/script/vm/function.lua b/script/vm/function.lua index 1e3083172..7a15ac5a7 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -353,6 +353,35 @@ local function isAllParamMatched(uri, args, params) return true end +---@param uri uri +---@param args parser.object[] +---@param func parser.object +---@return number +local function calcFunctionMatchScore(uri, args, func) + if vm.isVarargFunctionWithOverloads(func) + or not isAllParamMatched(uri, args, func.args) + then + return -1 + end + local matchScore = 0 + for i = 1, math.min(#args, #func.args) do + local arg, param = args[i], func.args[i] + local defLiterals, literalsCount = vm.getLiterals(param) + if defLiterals then + for n in vm.compileNode(arg):eachObject() do + -- if param's literals map contains arg's literal, this is narrower than a subtype match + if defLiterals[guide.getLiteral(n)] then + -- the more the literals defined in the param, the less bonus score will be added + -- this favors matching overload param with exact literal value, over alias/enum that has many literal values + matchScore = matchScore + 1/literalsCount + break + end + end + end + end + return matchScore +end + ---@param func parser.object ---@param args? parser.object[] ---@return parser.object[]? @@ -365,21 +394,29 @@ function vm.getExactMatchedFunctions(func, args) return funcs end local uri = guide.getUri(func) - local needRemove + local matchScores = {} for i, n in ipairs(funcs) do - if vm.isVarargFunctionWithOverloads(n) - or not isAllParamMatched(uri, args, n.args) then - if not needRemove then - needRemove = {} - end - needRemove[#needRemove+1] = i - end + matchScores[i] = calcFunctionMatchScore(uri, args, n) + end + + local maxMatchScore = math.max(table.unpack(matchScores)) + if maxMatchScore == -1 then + -- all should be removed + return nil end - if not needRemove then + + local minMatchScore = math.min(table.unpack(matchScores)) + if minMatchScore == maxMatchScore then + -- all should be kept return funcs end - if #needRemove == #funcs then - return nil + + -- remove functions that have matchScore < maxMatchScore + local needRemove = {} + for i, matchScore in ipairs(matchScores) do + if matchScore < maxMatchScore then + needRemove[#needRemove + 1] = i + end end util.tableMultiRemove(funcs, needRemove) return funcs diff --git a/script/vm/value.lua b/script/vm/value.lua index 7eab4a8e5..ce031357d 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -213,11 +213,13 @@ end ---@param v vm.object ---@return table? +---@return integer function vm.getLiterals(v) if not v then - return nil + return nil, 0 end local map + local count = 0 local node = vm.compileNode(v) for n in node:eachObject() do local literal @@ -237,7 +239,8 @@ function vm.getLiterals(v) map = {} end map[literal] = true + count = count + 1 end end - return map + return map, count end From cd5ebb588118efa98dfe90f3ea485c00ea73c175 Mon Sep 17 00:00:00 2001 From: Tom Lau Date: Tue, 27 Aug 2024 14:22:19 +0800 Subject: [PATCH 2/2] test: add tests for improved function type narrow --- test/type_inference/param_match.lua | 34 +++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/type_inference/param_match.lua b/test/type_inference/param_match.lua index 1079e4332..906b93059 100644 --- a/test/type_inference/param_match.lua +++ b/test/type_inference/param_match.lua @@ -138,6 +138,40 @@ local function f(...) end local = f(10) ]] +TEST '1' [[ +---@overload fun(a: string): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local = f('x') +]] + +TEST '2' [[ +---@overload fun(a: string): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local = f('y') +]] + +TEST '1' [[ +---@overload fun(a: string): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local v = 'x' +local = f(v) +]] + +TEST '2' [[ +---@overload fun(a: string): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local v = 'y' +local = f(v) +]] + TEST 'number' [[ ---@overload fun(a: 1, c: fun(x: number)) ---@overload fun(a: 2, c: fun(x: string))