Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add infer function param type #2532

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions locale/en-us/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ When checking the type of union type, ignore the `nil` in it.

When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType =
[[
When a parameter type is not annotated, it is inferred from the function's call sites.

When this setting is `false`, the type of the parameter is `any` when it is not annotated.
]]
config.doc.privateName =
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName =
Expand Down
6 changes: 6 additions & 0 deletions locale/pt-br/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ When checking the type of union type, ignore the `nil` in it.

When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType = -- TODO: need translate!
[[
When the parameter type is not annotated, the parameter type is inferred from the function's incoming parameters.

When this setting is `false`, the type of the parameter is `any` when it is not annotated.
]]
config.doc.privateName = -- TODO: need translate!
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName = -- TODO: need translate!
Expand Down
6 changes: 6 additions & 0 deletions locale/zh-cn/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ config.type.weakNilCheck =

此设置为 `false` 时,`numer|nil` 类型无法赋给 `number` 类型;为 `true` 是则可以。
]]
config.type.inferParamType =
[[
未注释参数类型时,参数类型由函数传入参数推断。

如果设置为 "false",则在未注释时,参数类型为 "any"。
]]
config.doc.privateName =
'将特定名称的字段视为私有,例如 `m_*` 意味着 `XXX.m_id` 与 `XXX.m_type` 是私有字段,只能在定义所在的类中访问。'
config.doc.protectedName =
Expand Down
6 changes: 6 additions & 0 deletions locale/zh-tw/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ When checking the type of union type, ignore the `nil` in it.

When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType = -- TODO: need translate!
[[
未注释参数类型时,参数类型由函数传入参数推断。

如果设置为 "false",则在未注释时,参数类型为 "any"。
]]
config.doc.privateName = -- TODO: need translate!
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName = -- TODO: need translate!
Expand Down
4 changes: 2 additions & 2 deletions script/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ local function searchPatchInfo(cfg, rawKey)
}
end

---@param uri uri
---@param uri? uri
---@param cfg table
---@param change config.change
---@return json.patch?
Expand Down Expand Up @@ -330,7 +330,7 @@ local function makeConfigPatch(uri, cfg, change)
return nil
end

---@param uri uri
---@param uri? uri
---@param path string
---@param changes config.change[]
---@return string?
Expand Down
1 change: 1 addition & 0 deletions script/config/template.lua
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ local template = {
['Lua.type.castNumberToInteger'] = Type.Boolean >> true,
['Lua.type.weakUnionCheck'] = Type.Boolean >> false,
['Lua.type.weakNilCheck'] = Type.Boolean >> false,
['Lua.type.inferParamType'] = Type.Boolean >> false,
['Lua.doc.privateName'] = Type.Array(Type.String),
['Lua.doc.protectedName'] = Type.Array(Type.String),
['Lua.doc.packageName'] = Type.Array(Type.String),
Expand Down
5 changes: 4 additions & 1 deletion script/core/command/autoRequire.lua
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ end

---@async
return function (data)
---@type uri
local uri = data.uri
local target = data.target
local name = data.name
Expand All @@ -158,5 +159,7 @@ return function (data)
end

local offset, fmt = findInsertRow(uri)
applyAutoRequire(uri, offset, name, requireName, fmt)
if offset and fmt then
applyAutoRequire(uri, offset, name, requireName, fmt)
end
end
52 changes: 37 additions & 15 deletions script/core/completion/completion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ end

local function findParent(state, position)
local text = state.lua
if not text then
return
end
local offset = guide.positionToOffset(state, position)
for i = offset, 1, -1 do
local char = text:sub(i, i)
Expand Down Expand Up @@ -675,6 +678,7 @@ local function checkGlobal(state, word, startPos, position, parent, oop, results
end

---@async
---@param parent parser.object
local function checkField(state, word, start, position, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
local globals = vm.getGlobalSets(state.uri, 'variable')
Expand Down Expand Up @@ -955,8 +959,7 @@ local function checkFunctionArgByDocParam(state, word, startPos, results)
end
end

local function isAfterLocal(state, startPos)
local text = state.lua
local function isAfterLocal(state, text, startPos)
local offset = guide.positionToOffset(state, startPos)
local pos = lookBackward.skipSpace(text, offset)
local word = lookBackward.findWord(text, pos)
Expand All @@ -965,6 +968,8 @@ end

local function collectRequireNames(mode, myUri, literal, source, smark, position, results)
local collect = {}
local source_start = source and smark and (source.start + #smark) or position
local source_finish = source and smark and (source.finish - #smark) or position
if mode == 'require' then
for uri in files.eachFile(myUri) do
if myUri == uri then
Expand All @@ -978,8 +983,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[info.name] then
collect[info.name] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and info.name or util.viewString(info.name),
},
path = relative,
Expand All @@ -1006,8 +1011,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[open] then
collect[open] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and open or util.viewString(open),
},
path = path,
Expand All @@ -1034,8 +1039,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[path] then
collect[path] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and path or util.viewString(path),
}
}
Expand Down Expand Up @@ -1097,6 +1102,9 @@ end

local function checkLenPlusOne(state, position, results)
local text = state.lua
if not text then
return
end
guide.eachSourceContain(state.ast, position, function (source)
if source.type == 'getindex'
or source.type == 'setindex' then
Expand Down Expand Up @@ -1392,6 +1400,9 @@ end

local function checkEqualEnum(state, position, results)
local text = state.lua
if not text then
return
end
local start = lookBackward.findTargetSymbol(text, guide.positionToOffset(state, position), '=')
if not start then
return
Expand Down Expand Up @@ -1493,6 +1504,9 @@ local function tryWord(state, position, triggerCharacter, results)
return
end
local text = state.lua
if not text then
return
end
local offset = guide.positionToOffset(state, position)
local finish = lookBackward.skipSpace(text, offset)
local word, start = lookBackward.findWord(text, offset)
Expand All @@ -1518,7 +1532,7 @@ local function tryWord(state, position, triggerCharacter, results)
checkProvideLocal(state, word, startPos, results)
checkFunctionArgByDocParam(state, word, startPos, results)
else
local afterLocal = isAfterLocal(state, startPos)
local afterLocal = isAfterLocal(state, text, startPos)
local stop = checkKeyWord(state, startPos, position, word, hasSpace, afterLocal, results)
if stop then
return
Expand All @@ -1530,8 +1544,10 @@ local function tryWord(state, position, triggerCharacter, results)
checkLocal(state, word, startPos, results)
checkTableField(state, word, startPos, results)
local env = guide.getENV(state.ast, startPos)
checkGlobal(state, word, startPos, position, env, false, results)
checkModule(state, word, startPos, results)
if env then
checkGlobal(state, word, startPos, position, env, false, results)
checkModule(state, word, startPos, results)
end
end
end
end
Expand Down Expand Up @@ -1592,6 +1608,9 @@ end

local function checkTableLiteralField(state, position, tbl, fields, results)
local text = state.lua
if not text then
return
end
local mark = {}
for _, field in ipairs(tbl) do
if field.type == 'tablefield'
Expand All @@ -1610,9 +1629,11 @@ local function checkTableLiteralField(state, position, tbl, fields, results)
local left = lookBackward.findWord(text, guide.positionToOffset(state, position))
if not left then
local pos = lookBackward.findAnyOffset(text, guide.positionToOffset(state, position))
local char = text:sub(pos, pos)
if char == '{' or char == ',' or char == ';' then
left = ''
if pos then
local char = text:sub(pos, pos)
if char == '{' or char == ',' or char == ';' then
left = ''
end
end
end
if left then
Expand Down Expand Up @@ -1801,6 +1822,7 @@ local function getluaDocByContain(state, position)
return result
end

---@return parser.state.err?, parser.object?
local function getluaDocByErr(state, start, position)
local targetError
for _, err in ipairs(state.errs) do
Expand Down Expand Up @@ -2008,7 +2030,7 @@ local function tryluaDocByErr(state, position, err, docState, results)
for _, doc in ipairs(vm.getDocSets(state.uri)) do
if doc.type == 'doc.class'
and not used[doc.class[1]]
and doc.class[1] ~= docState.class[1] then
and docState and doc.class[1] ~= docState.class[1] then
used[doc.class[1]] = true
results[#results+1] = {
label = doc.class[1],
Expand Down
13 changes: 1 addition & 12 deletions script/core/diagnostics/undefined-doc-name.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,6 @@ return function (uri, callback)
return
end

local function hasNameOfGeneric(name, source)
if not source.typeGeneric then
return false
end
if not source.typeGeneric[name] then
return false
end
return true
end

guide.eachSource(state.ast.docs, function (source)
if source.type ~= 'doc.extends.name'
and source.type ~= 'doc.type.name' then
Expand All @@ -35,8 +25,7 @@ return function (uri, callback)
if name == '...' or name == '_' or name == 'self' then
return
end
if #vm.getDocSets(uri, name) > 0
or hasNameOfGeneric(name, source) then
if #vm.getDocSets(uri, name) > 0 then
return
end
callback {
Expand Down
4 changes: 2 additions & 2 deletions script/core/highlight.lua
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ local function checkInIf(state, source, text, position)
local endA = endB - #'end' + 1
if position >= source.finish - #'end'
and position <= source.finish
and text:sub(endA, endB) == 'end' then
and text and text:sub(endA, endB) == 'end' then
return true
end
-- 检查每个子模块
Expand All @@ -83,7 +83,7 @@ local function makeIf(state, source, text, callback)
-- end
local endB = guide.positionToOffset(state, source.finish)
local endA = endB - #'end' + 1
if text:sub(endA, endB) == 'end' then
if text and text:sub(endA, endB) == 'end' then
callback(source.finish - #'end', source.finish)
end
-- 每个子模块
Expand Down
6 changes: 6 additions & 0 deletions script/fs-utility.lua
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ function dfs:__div(filename)
return new
end

---@package
function dfs:_open(index)
local paths = split(self.path, '[/\\]')
local current = self.files
Expand All @@ -147,6 +148,7 @@ function dfs:_open(index)
return current
end

---@package
function dfs:_filename()
return self.path:match '[^/\\]+$'
end
Expand Down Expand Up @@ -291,6 +293,7 @@ local function fsIsDirectory(path, option)
if path.type == 'dummy' then
return path:isDirectory()
end
---@cast path -dummyfs
local status = fs.symlink_status(path):type()
return status == 'directory'
end
Expand Down Expand Up @@ -347,6 +350,7 @@ local function fsSave(path, text, option)
return false
end
if path.type == 'dummy' then
---@cast path -fs.path
local dir = path:_open(-2)
if not dir then
option.err[#option.err+1] = '无法打开:' .. path:string()
Expand Down Expand Up @@ -385,6 +389,7 @@ local function fsLoad(path, option)
return nil
end
else
---@cast path -dummyfs
local text, err = m.loadFile(path)
if text then
return text
Expand All @@ -407,6 +412,7 @@ local function fsCopy(source, target, option)
end
return fsSave(target, sourceText, option)
else
---@cast source -dummyfs
if target.type == 'dummy' then
local sourceText, err = m.loadFile(source)
if not sourceText then
Expand Down
3 changes: 2 additions & 1 deletion script/gc.lua
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
local util = require 'utility'

---@class gc
---@field _list table
---@field package _list table
local mt = {}
mt.__index = mt
mt.type = 'gc'
mt._removed = false

---@package
mt._max = 10

local function destroyGCObject(obj)
Expand Down
Loading
Loading