Skip to content

Commit

Permalink
Merge pull request #2822 from tomlau10/fix/type_narrow
Browse files Browse the repository at this point in the history
fix: improve function type narrow by checking params' literal identical
  • Loading branch information
sumneko authored Sep 5, 2024
2 parents 08dd0ca + 30deedc commit c636fdd
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 13 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
* `FIX` Improve type narrow by checking exact match on literal type params

## 3.10.5
`2024-8-19`
Expand Down
59 changes: 48 additions & 11 deletions script/vm/function.lua
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,35 @@ local function isAllParamMatched(uri, args, params)
return true
end

---@param uri uri
---@param args parser.object[]
---@param func parser.object
---@return number
local function calcFunctionMatchScore(uri, args, func)
if vm.isVarargFunctionWithOverloads(func)
or not isAllParamMatched(uri, args, func.args)
then
return -1
end
local matchScore = 0
for i = 1, math.min(#args, #func.args) do
local arg, param = args[i], func.args[i]
local defLiterals, literalsCount = vm.getLiterals(param)
if defLiterals then
for n in vm.compileNode(arg):eachObject() do
-- if param's literals map contains arg's literal, this is narrower than a subtype match
if defLiterals[guide.getLiteral(n)] then
-- the more the literals defined in the param, the less bonus score will be added
-- this favors matching overload param with exact literal value, over alias/enum that has many literal values
matchScore = matchScore + 1/literalsCount
break
end
end
end
end
return matchScore
end

---@param func parser.object
---@param args? parser.object[]
---@return parser.object[]?
Expand All @@ -365,21 +394,29 @@ function vm.getExactMatchedFunctions(func, args)
return funcs
end
local uri = guide.getUri(func)
local needRemove
local matchScores = {}
for i, n in ipairs(funcs) do
if vm.isVarargFunctionWithOverloads(n)
or not isAllParamMatched(uri, args, n.args) then
if not needRemove then
needRemove = {}
end
needRemove[#needRemove+1] = i
end
matchScores[i] = calcFunctionMatchScore(uri, args, n)
end

local maxMatchScore = math.max(table.unpack(matchScores))
if maxMatchScore == -1 then
-- all should be removed
return nil
end
if not needRemove then

local minMatchScore = math.min(table.unpack(matchScores))
if minMatchScore == maxMatchScore then
-- all should be kept
return funcs
end
if #needRemove == #funcs then
return nil

-- remove functions that have matchScore < maxMatchScore
local needRemove = {}
for i, matchScore in ipairs(matchScores) do
if matchScore < maxMatchScore then
needRemove[#needRemove + 1] = i
end
end
util.tableMultiRemove(funcs, needRemove)
return funcs
Expand Down
7 changes: 5 additions & 2 deletions script/vm/value.lua
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,13 @@ end

---@param v vm.object
---@return table<any, boolean>?
---@return integer
function vm.getLiterals(v)
if not v then
return nil
return nil, 0
end
local map
local count = 0
local node = vm.compileNode(v)
for n in node:eachObject() do
local literal
Expand All @@ -237,7 +239,8 @@ function vm.getLiterals(v)
map = {}
end
map[literal] = true
count = count + 1
end
end
return map
return map, count
end
34 changes: 34 additions & 0 deletions test/type_inference/param_match.lua
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,40 @@ local function f(...) end
local <?r?> = f(10)
]]

TEST '1' [[
---@overload fun(a: string): 1
---@overload fun(a: 'y'): 2
local function f(...) end
local <?r?> = f('x')
]]

TEST '2' [[
---@overload fun(a: string): 1
---@overload fun(a: 'y'): 2
local function f(...) end
local <?r?> = f('y')
]]

TEST '1' [[
---@overload fun(a: string): 1
---@overload fun(a: 'y'): 2
local function f(...) end
local v = 'x'
local <?r?> = f(v)
]]

TEST '2' [[
---@overload fun(a: string): 1
---@overload fun(a: 'y'): 2
local function f(...) end
local v = 'y'
local <?r?> = f(v)
]]

TEST 'number' [[
---@overload fun(a: 1, c: fun(x: number))
---@overload fun(a: 2, c: fun(x: string))
Expand Down

0 comments on commit c636fdd

Please sign in to comment.