Skip to content

Commit 7184d3f

Browse files
Saloedlehvolk
authored andcommitted
Fix lambda expressions
1 parent 4ea9de3 commit 7184d3f

5 files changed

Lines changed: 208 additions & 22 deletions

File tree

jacodb-api/src/main/kotlin/org/jacodb/api/cfg/JcInst.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,18 @@ data class JcPhiExpr(
670670
* object, but stores a reference to the actual method
671671
*/
672672
data class JcLambdaExpr(
673-
private val methodRef: TypedMethodRef,
674-
override val args: List<JcValue>,
673+
private val bsmRef: TypedMethodRef,
674+
val actualMethod: TypedMethodRef,
675+
val interfaceMethodType: BsmMethodTypeArg,
676+
val dynamicMethodType: BsmMethodTypeArg,
677+
val callSiteMethodName: String,
678+
val callSiteArgTypes: List<JcType>,
679+
val callSiteReturnType: JcType,
680+
val callSiteArgs: List<JcValue>
675681
) : JcCallExpr {
676682

677-
override val method: JcTypedMethod get() = methodRef.method
683+
override val method get() = bsmRef.method
684+
override val args get() = callSiteArgs
678685

679686
override fun <T> accept(visitor: JcExprVisitor<T>): T {
680687
return visitor.visitJcLambdaExpr(this)

jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/JcInstListBuilder.kt

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package org.jacodb.impl.cfg
1919
import org.jacodb.api.*
2020
import org.jacodb.api.cfg.*
2121
import org.jacodb.api.ext.*
22+
import org.jacodb.impl.cfg.util.LAMBDA_METAFACTORY_CLASS
2223
import org.jacodb.impl.cfg.util.UNINIT_THIS
24+
import org.jacodb.impl.cfg.util.typeName
2325

2426
/** This class stores state and is NOT THREAD SAFE. Use it carefully */
2527
class JcInstListBuilder(val method: JcMethod,val instList: JcInstList<JcRawInst>) : JcRawInstVisitor<JcInst?>, JcRawExprVisitor<JcExpr> {
@@ -259,29 +261,53 @@ class JcInstListBuilder(val method: JcMethod,val instList: JcInstList<JcRawInst>
259261
override fun visitJcRawInstanceOfExpr(expr: JcRawInstanceOfExpr): JcExpr =
260262
JcInstanceOfExpr(classpath.boolean, expr.operand.accept(this) as JcValue, expr.targetType.asType())
261263

264+
private val lambdaMetaFactory: TypeName by lazy { LAMBDA_METAFACTORY_CLASS.typeName() }
265+
private val lambdaMetaFactoryMethodName: String = "metafactory"
266+
262267
override fun visitJcRawDynamicCallExpr(expr: JcRawDynamicCallExpr): JcExpr {
263-
val lambdaBases = expr.bsmArgs.filterIsInstance<BsmHandle>()
264-
when (lambdaBases.size) {
265-
1 -> {
266-
val base = lambdaBases.first()
267-
val klass = base.declaringClass.asType() as JcClassType
268-
val ref = TypedMethodRefImpl(klass, base.name, base.argTypes, base.returnType)
269-
270-
return JcLambdaExpr(ref, expr.args.map { it.accept(this) as JcValue })
271-
}
268+
if (expr.bsm.declaringClass == lambdaMetaFactory && expr.bsm.name == lambdaMetaFactoryMethodName) {
269+
val lambdaExpr = tryResolveJcLambdaExpr(expr)
270+
if (lambdaExpr != null) return lambdaExpr
271+
}
272272

273-
else -> {
273+
return JcDynamicCallExpr(
274+
classpath.methodRef(expr),
275+
expr.bsmArgs,
276+
expr.callSiteMethodName,
277+
expr.callSiteArgTypes.map { it.asType() },
278+
expr.callSiteReturnType.asType(),
279+
expr.callSiteArgs.map { it.accept(this) as JcValue }
280+
)
281+
}
274282

275-
return JcDynamicCallExpr(
276-
classpath.methodRef(expr),
277-
expr.bsmArgs,
278-
expr.callSiteMethodName,
279-
expr.callSiteArgTypes.map { it.asType() },
280-
expr.callSiteReturnType.asType(),
281-
expr.args.map { it.accept(this) as JcValue }
282-
)
283-
}
283+
private fun tryResolveJcLambdaExpr(expr: JcRawDynamicCallExpr): JcLambdaExpr? {
284+
if (expr.bsmArgs.size != 3) return null
285+
val (interfaceMethodType, implementation, dynamicMethodType) = expr.bsmArgs
286+
287+
if (interfaceMethodType !is BsmMethodTypeArg) return null
288+
if (dynamicMethodType !is BsmMethodTypeArg) return null
289+
if (implementation !is BsmHandle) return null
290+
291+
// Check implementation signature match (starts with) call site arguments
292+
for ((index, argType) in expr.callSiteArgTypes.withIndex()) {
293+
if (argType != implementation.argTypes.getOrNull(index)) return null
284294
}
295+
296+
val klass = implementation.declaringClass.asType() as JcClassType
297+
val actualMethod = TypedMethodRefImpl(
298+
klass, implementation.name, implementation.argTypes, implementation.returnType
299+
)
300+
301+
return JcLambdaExpr(
302+
classpath.methodRef(expr),
303+
actualMethod,
304+
interfaceMethodType,
305+
dynamicMethodType,
306+
expr.callSiteMethodName,
307+
expr.callSiteArgTypes.map { it.asType() },
308+
expr.callSiteReturnType.asType(),
309+
expr.callSiteArgs.map { it.accept(this) as JcValue }
310+
)
285311
}
286312

287313
override fun visitJcRawVirtualCallExpr(expr: JcRawVirtualCallExpr): JcExpr {

jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/util/types.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ internal const val METHOD_HANDLE_CLASS = "Ljava.lang.invoke.MethodHandle;"
3131
internal const val METHOD_HANDLES_CLASS = "Ljava.lang.invoke.MethodHandles;"
3232
internal const val METHOD_HANDLES_LOOKUP_CLASS = "Ljava.lang.invoke.MethodHandles\$Lookup;"
3333
internal const val METHOD_TYPE_CLASS = "Ljava.lang.invoke.MethodType;"
34+
internal const val LAMBDA_METAFACTORY_CLASS = "Ljava.lang.invoke.LambdaMetafactory;"
3435
internal val TOP = "TOP".typeName()
3536
internal val UNINIT_THIS = "UNINIT_THIS".typeName()
3637

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2022 UnitTestBot contributors (utbot.org)
3+
* <p>
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.jacodb.testing.cfg
18+
19+
import org.jacodb.api.ext.findClass
20+
import org.jacodb.testing.WithDB
21+
import org.junit.jupiter.api.Assertions.assertEquals
22+
import org.junit.jupiter.api.Test
23+
24+
class InvokeDynamicTest : BaseInstructionsTest() {
25+
26+
companion object : WithDB()
27+
28+
@Test
29+
fun `test unary function`() = runStaticMethod<InvokeDynamicExamples>("testUnaryFunction")
30+
31+
@Test
32+
fun `test method ref unary function`() = runStaticMethod<InvokeDynamicExamples>("testMethodRefUnaryFunction")
33+
34+
@Test
35+
fun `test currying function`() = runStaticMethod<InvokeDynamicExamples>("testCurryingFunction")
36+
37+
@Test
38+
fun `test sam function`() = runStaticMethod<InvokeDynamicExamples>("testSamFunction")
39+
40+
@Test
41+
fun `test sam with default function`() = runStaticMethod<InvokeDynamicExamples>("testSamWithDefaultFunction")
42+
43+
@Test
44+
fun `test complex invoke dynamic`() = runStaticMethod<InvokeDynamicExamples>("testComplexInvokeDynamic")
45+
46+
private inline fun <reified T> runStaticMethod(name: String) {
47+
val clazz = cp.findClass<T>()
48+
49+
val javaClazz = testAndLoadClass(clazz)
50+
val method = javaClazz.methods.single { it.name == name }
51+
val res = method.invoke(null)
52+
assertEquals("OK", res)
53+
}
54+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2022 UnitTestBot contributors (utbot.org)
3+
* <p>
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.jacodb.testing.cfg;
18+
19+
import java.util.function.Function;
20+
21+
public class InvokeDynamicExamples {
22+
23+
private static int runUnaryFunction(String data, Function<String, Integer> f) {
24+
if (data.isEmpty()) {
25+
return -1;
26+
}
27+
28+
int result = f.apply(data);
29+
return result + 17;
30+
}
31+
32+
private static int runSamFunction(String data, SamBase f) {
33+
if (data.isEmpty()) {
34+
return -1;
35+
}
36+
int result = f.samFunction(data);
37+
return result + 17;
38+
}
39+
40+
private static int runDefaultFunction(String data, SamBase f) {
41+
if (data.isEmpty()) {
42+
return -1;
43+
}
44+
int result = f.defaultFunction(data);
45+
return result + 17;
46+
}
47+
48+
private static String runComplexStringConcat(String str, int v) {
49+
return str + v + 'x' + str + 17 + str;
50+
}
51+
52+
public interface SamBase {
53+
int samFunction(String data);
54+
55+
default int defaultFunction(String data) {
56+
if (data.isEmpty()) {
57+
return -2;
58+
}
59+
return samFunction(data) + 31;
60+
}
61+
}
62+
63+
private static int add(int a, int b) {
64+
return a + b;
65+
}
66+
67+
public static String testUnaryFunction() {
68+
int res = runUnaryFunction("abc", s -> s.length());
69+
return res == ("abc".length() + 17) ? "OK" : "BAD";
70+
}
71+
72+
public static String testMethodRefUnaryFunction() {
73+
int res = runUnaryFunction("abc", String::length);
74+
return res == ("abc".length() + 17) ? "OK" : "BAD";
75+
}
76+
77+
public static String testCurryingFunction() {
78+
Function<Integer, Integer> add42 = x -> add(x, 42);
79+
int res = runUnaryFunction("abc", s -> add42.apply(s.length()));
80+
return res == ("abc".length() + 17 + 42) ? "OK" : "BAD";
81+
}
82+
83+
public static String testSamFunction() {
84+
int res = runSamFunction("abc", s -> s.length());
85+
return res == ("abc".length() + 17) ? "OK" : "BAD";
86+
}
87+
88+
public static String testSamWithDefaultFunction() {
89+
int res = runDefaultFunction("abc", s -> s.length());
90+
return res == ("abc".length() + 17 + 31) ? "OK" : "BAD";
91+
}
92+
93+
public static String testComplexInvokeDynamic() {
94+
String expected = "abc42xabc17abc";
95+
String actual = runComplexStringConcat("abc", 42);
96+
return expected.equals(actual) ? "OK" : "BAD";
97+
}
98+
}

0 commit comments

Comments
 (0)