Skip to content

Commit 8e42252

Browse files
committed
feat(analyzer): support simple type inference
1 parent adfc914 commit 8e42252

6 files changed

Lines changed: 102 additions & 72 deletions

File tree

src/commonMain/kotlin/io/github/dingyi222666/luaparser/lexer/LuaLexer.kt

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,22 @@ class LuaLexer(
148148

149149
ch == '.' -> {
150150
val next = chatAtOrNull() ?: return LuaTokenTypes.DOT
151-
if (isPrimeDigit(next)) {
152-
scanPrimeDigit()
153-
LuaTokenTypes.NUMBER
154-
} else LuaTokenTypes.DOT
151+
152+
when {
153+
isPrimeDigit(next) -> {
154+
scanPrimeDigit()
155+
LuaTokenTypes.NUMBER
156+
}
157+
158+
next == '.' -> {
159+
tokenLength++
160+
LuaTokenTypes.CONCAT
161+
}
162+
163+
else -> LuaTokenTypes.DOT
164+
}
165+
166+
155167
}
156168

157169
ch == '"' || ch == '\'' -> scanString(ch)

src/commonMain/kotlin/io/github/dingyi222666/luaparser/semantic/SemanticAnalyzer.kt

Lines changed: 54 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
1717
private val globalScope = GlobalScope(range = Range.EMPTY)
1818

1919
private fun createLocalScope(node: BaseASTNode): LocalScope {
20-
val parent = scopeStack.first()
20+
val parent = currentScope
2121
val localScope = LocalScope(parent, node.range)
2222
scopeStack.addFirst(localScope)
2323
globalScope.addScope(localScope)
2424
return localScope
2525
}
2626

2727
private fun createFunctionScope(node: BaseASTNode): FunctionScope {
28-
val parent = scopeStack.first()
28+
val parent = currentScope
2929
val localScope = FunctionScope(parent, node.range)
3030
scopeStack.addFirst(localScope)
3131
globalScope.addScope(localScope)
@@ -34,7 +34,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
3434

3535

3636
private fun createLoopScope(node: BaseASTNode): LoopScope {
37-
val parent = scopeStack.first()
37+
val parent = currentScope
3838
val loopScope = LoopScope(parent, node.range)
3939
scopeStack.addFirst(loopScope)
4040
globalScope.addScope(loopScope)
@@ -99,21 +99,21 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
9999
is TableCallExpression -> return visitTableCallExpression(node, value)
100100
}
101101

102-
val currentScope = scopeStack.first()
102+
val type = getCallExpressionType(node, currentScope)
103+
103104
when (value) {
104105
is Identifier -> setIdentifierType(
105-
value, getCallExpressionType(node, currentScope), currentScope
106+
value, type, currentScope
106107
)
107108

108109
is MemberExpression -> setMemberExpressionType(
109-
value, getCallExpressionType(node, currentScope), currentScope
110+
value, type, currentScope
110111
)
111112
}
112113
}
113114

114115

115116
override fun visitConstantNode(node: ConstantNode, value: BaseASTNode) {
116-
val currentScope = scopeStack.first()
117117
when (value) {
118118
is Identifier -> setIdentifierType(value, node.asType(), currentScope)
119119

@@ -130,21 +130,18 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
130130
}
131131

132132
override fun visitBreakStatement(node: BreakStatement, value: BaseASTNode) {
133-
val currentScope = scopeStack.first()
134133
if (currentScope !is LoopScope) {
135134
error(node) { "no loop to break" }
136135
}
137136
}
138137

139138
override fun visitContinueStatement(node: ContinueStatement, value: BaseASTNode) {
140-
val currentScope = scopeStack.first()
141139
if (currentScope !is LoopScope) {
142140
error(node) { "no loop to continue" }
143141
}
144142
}
145143

146144
override fun visitBinaryExpression(node: BinaryExpression, value: BaseASTNode) {
147-
val currentScope = scopeStack.first()
148145
when (value) {
149146
is Identifier -> setIdentifierType(
150147
value,
@@ -168,14 +165,12 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
168165
val symbol = createStatementSymbol(
169166
"return", node, tupleType
170167
)
171-
val currentScope = scopeStack.first()
168+
172169
currentScope.addSymbol(symbol)
173170
super.visitReturnStatement(node, node)
174171
}
175172

176173
override fun visitMemberExpression(node: MemberExpression, value: BaseASTNode) {
177-
val currentScope = scopeStack.first()
178-
179174
val currentType = resolveMemberExpressionType(node, currentScope)
180175
when (value) {
181176
is Identifier -> setIdentifierType(value, currentType, currentScope)
@@ -187,14 +182,12 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
187182
}
188183

189184
override fun visitIdentifier(node: Identifier, value: BaseASTNode) {
190-
val currentScope = scopeStack.first()
185+
191186
when (value) {
192187
// params: function(a)
193188
is FunctionDeclaration -> {
194189
val parameterSymbol = createParamsVariable(node, currentScope)
195-
if (node.name == "self" && value.params.indexOf(node) == 0 &&
196-
value.identifier is MemberExpression
197-
) {
190+
if (node.name == "self" && value.params.indexOf(node) == 0 && value.identifier is MemberExpression) {
198191
setSelfType(value, parameterSymbol, currentScope)
199192
}
200193
}
@@ -218,17 +211,12 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
218211

219212
// return: return a
220213
is ReturnStatement -> setReturnStatementType(
221-
value,
222-
node,
223-
resolveExpressionNodeType(node) ?: Type.ANY,
224-
currentScope
214+
value, node, resolveExpressionNodeType(node) ?: Type.ANY, currentScope
225215
)
226216
}
227217
}
228218

229219
override fun visitTableConstructorExpression(node: TableConstructorExpression, value: BaseASTNode) {
230-
val currentScope = scopeStack.first()
231-
232220
// make parent as table
233221
super.visitTableConstructorExpression(node, node)
234222

@@ -244,18 +232,13 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
244232
)
245233

246234
is ReturnStatement -> setReturnStatementType(
247-
value,
248-
node,
249-
getTableConstructorExpressionType(node),
250-
currentScope
235+
value, node, getTableConstructorExpressionType(node), currentScope
251236
)
252237
}
253238
}
254239

255240

256241
override fun visitAssignmentStatement(node: AssignmentStatement, value: BaseASTNode) {
257-
val currentScope = scopeStack.first()
258-
259242
val initSymbols = node.init.map { initNode ->
260243
when (initNode) {
261244
is Identifier -> {
@@ -341,14 +324,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
341324
node.params.forEach {
342325
funcScope.resolveSymbol(it.name, it.range.start)?.let { paramSymbol ->
343326
val paramType = paramSymbol.type
344-
if (paramType.typeVariableName == "self" && !node.isLocal && node.identifier is MemberExpression && paramType is ParameterType) {
345-
val list = transformMemberExpressionToList(node.identifier as MemberExpression)
346-
paramType.isSelf = true
347-
// set to self
348-
paramType.realType =
349-
funcScope.resolveSymbol(list.first().name, list.first().range.start)?.type ?: paramType.realType
350-
}
351-
funcType.addParamType(paramSymbol.type)
327+
funcType.addParamType(paramType)
352328
}
353329
}
354330

@@ -389,7 +365,6 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
389365
private fun visitGlobalFunctionDeclaration(
390366
node: FunctionDeclaration, value: BaseASTNode, functionType: FunctionType
391367
) {
392-
val currentScope = scopeStack.first()
393368
val identifier = node.identifier
394369
val variable: String
395370

@@ -433,7 +408,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
433408
private fun visitLocalFunctionDeclaration(
434409
node: FunctionDeclaration, functionType: FunctionType
435410
) {
436-
val currentScope = scopeStack.first()
411+
437412

438413
// local function
439414
if (node.identifier is Identifier && node.isLocal) {
@@ -463,6 +438,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
463438
private fun getTableConstructorExpressionType(
464439
node: TableConstructorExpression, name: String = "anonymous"
465440
): Type {
441+
val scope = currentScope
466442
val rootType = TableType(TypeKind.Table, name)
467443
val tableConstructorStack = ArrayDeque<Pair<TableConstructorExpression, TableType>>()
468444
var currentType: TableType
@@ -497,10 +473,8 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
497473
tableConstructorStack.addLast(value to valueType)
498474
currentType = valueType
499475
} else {
500-
val valueType = resolveExpressionNodeType(value)
501-
if (valueType is Type) {
502-
currentType.setMember(keyValue.toString(), keyType ?: Type.ANY, valueType)
503-
}
476+
val valueType = resolveExpressionNodeType(value, scope)
477+
currentType.setMember(keyValue.toString(), keyType ?: Type.ANY, valueType)
504478
}
505479

506480
}
@@ -524,38 +498,52 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
524498
base.indexer == ":"
525499
} else false
526500

527-
if (baseType is FunctionType) {
528-
val params = baseType.parameterTypes
501+
502+
if (baseType is FunctionType || ((baseType is UnionType) && baseType.types.any { it is FunctionType })) {
503+
val funcType =
504+
if (baseType is FunctionType) baseType else (baseType as UnionType).types.filterIsInstance<FunctionType>()
505+
.getOrNull(0) ?: return Type.ANY
506+
val params = funcType.parameterTypes
529507
val paramsSize = params.size
530508

531-
if (!callSelf && baseType.isSelf && args.isEmpty()) {
509+
510+
if (!callSelf && funcType.isSelf && args.isEmpty()) {
532511
println("need add : to call with self")
533512
}
534513

535514
for (i in 0..<paramsSize) {
536515
val paramType = params[i]
537516
val argNode = args.getOrNull(i) ?: break
538-
val argType = resolveExpressionNodeType(argNode) ?: Type.ANY
517+
val argType = resolveExpressionNodeType(argNode)
539518

540519
if (paramType is ParameterType) {
541-
paramType.realType = argType
520+
paramType.realType = paramType.realType.union(argType)
521+
} else {
522+
funcType.setParamType(i, paramType.union(argType))
542523
}
543524
}
544525

545-
return createTupleType(baseType.returnTypes)
526+
return createTupleType(funcType.returnTypes)
546527
}
547528

548-
// what is it
549-
if (baseType is ParameterType) {
550-
val argsSize = args.size
551-
if (argsSize != 2) {
552-
// ???
553-
}
554-
val argType = resolveExpressionNodeType(args[0])
555-
if (argType is Type) {
556-
baseType.realType = argType
529+
if (base is Identifier && (baseType is UnDefinedType || baseType is AnyType) && args.isNotEmpty()) {
530+
val symbol = currentScope.resolveSymbol(base.name, base.range.start) ?: return Type.ANY
531+
532+
val functionType = LikeFunctionType()
533+
534+
functionType.addReturnType(Type.ANY)
535+
536+
for (argIndex in 0..<args.size) {
537+
val argType = resolveExpressionNodeType(args[argIndex])
538+
val rawParamType = functionType.getParamTypeOrNull(argIndex)
539+
val paramType = (rawParamType ?: argType).union(argType)
540+
if (rawParamType == null) functionType.addParamType(paramType)
541+
else functionType.setParamType(argIndex, paramType)
557542
}
558543

544+
symbol.type = symbol.type.union(Type.ANY).union(functionType)
545+
546+
return createTupleType(functionType.returnTypes)
559547
}
560548

561549
return Type.ANY
@@ -583,8 +571,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
583571
}
584572

585573
private fun resolveMemberExpressionType(
586-
node: MemberExpression,
587-
currentScope: Scope
574+
node: MemberExpression, currentScope: Scope
588575
): Type {
589576
val list = transformMemberExpressionToList(node)
590577
var last = list.removeFirst()
@@ -596,8 +583,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
596583

597584
val lastType = currentType
598585
currentType = when (currentType) {
599-
is TableType ->
600-
currentType.searchMember(last.name)
586+
is TableType -> currentType.searchMember(last.name)
601587

602588
else -> null
603589
}
@@ -629,14 +615,14 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
629615
return unpackType(currentType ?: Type.ANY)
630616
}
631617

632-
private fun resolveExpressionNodeType(node: ExpressionNode, scope: Scope = globalScope): Type? {
618+
private fun resolveExpressionNodeType(node: ExpressionNode, scope: Scope = currentScope): Type {
633619
val resultType = when (node) {
634620
is ConstantNode -> node.asType()
635621
is TableConstructorExpression -> getTableConstructorExpressionType(node)
636622

637623
is Identifier -> {
638624
val symbol = scope.resolveSymbol(node.name, node.range.start)
639-
// 因为也不知道是什么,那直接创建一个新的全局变量
625+
// unknown symbol, create a global symbol
640626
symbol?.type ?: createGlobalVariable(node).type
641627
}
642628

@@ -834,7 +820,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
834820
}
835821

836822
private fun createLocalVariable(identifier: Identifier): VariableSymbol {
837-
val currentScope = scopeStack.first()
823+
838824

839825
val symbol = createVariableSymbol(identifier, currentScope)
840826
currentScope.addSymbol(symbol)
@@ -867,4 +853,7 @@ class SemanticAnalyzer : ASTVisitor<BaseASTNode> {
867853
currentScope.addSymbol(symbol)
868854
return symbol
869855
}
856+
857+
private val currentScope
858+
get() = scopeStack.first()
870859
}

0 commit comments

Comments
 (0)