Skip to content

Commit 769f2e8

Browse files
committed
Fix function calls in variable initializers to return directly to the variable's register
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent e9178bf commit 769f2e8

4 files changed

Lines changed: 185 additions & 14 deletions

File tree

src/ebpf_c_codegen.ml

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2773,8 +2773,35 @@ let generate_c_basic_block ctx ir_block =
27732773
increase_indent ctx
27742774
);
27752775

2776-
(* Emit instructions *)
2777-
List.iter (generate_c_instruction ctx) ir_block.instructions
2776+
(* Emit the block's instructions, fusing "call + declaration" pairs.

   When an [IRCall] that stores its result in a register is immediately
   followed by an [IRDeclareVariable] without initializer for that same
   register, a single C declaration initialised by the call is emitted
   instead of processing the two instructions separately. *)

(* [combined_call_declaration call_instr decl_instr] returns [Some line] with
   the fused C statement when the two instructions form such a pair, and
   [None] otherwise. It only builds the text; emitting is left to the caller
   so that no output is produced while a pattern is still being tested. *)
let combined_call_declaration call_instr decl_instr =
  match call_instr.instr_desc, decl_instr.instr_desc with
  | IRCall (call_target, args, Some result_val), IRDeclareVariable (dest_val, typ, None) ->
    (* The pair is fusable only if the call result register is the declared register *)
    (match result_val.value_desc, dest_val.value_desc with
     | IRRegister call_reg, IRRegister decl_reg when call_reg = decl_reg ->
       let var_name = get_meaningful_var_name ctx decl_reg typ in
       let args_str = String.concat ", " (List.map (generate_c_value ctx) args) in
       let call_str = match call_target with
         | DirectCall name -> sprintf "%s(%s)" name args_str
         | FunctionPointerCall func_ptr -> sprintf "(*%s)(%s)" (generate_c_value ctx func_ptr) args_str
       in
       let decl_str = generate_ebpf_c_declaration typ var_name in
       Some (sprintf "%s = %s;" decl_str call_str)
     | _ -> None)
  | _ -> None
in
let rec emit_instructions_optimized = function
  | [] -> ()
  | [instr] ->
    (* Last instruction: nothing left to fuse it with *)
    generate_c_instruction ctx instr
  | call_instr :: decl_instr :: rest ->
    (match combined_call_declaration call_instr decl_instr with
     | Some combined_line ->
       (* Both instructions are covered by the combined declaration *)
       emit_line ctx combined_line;
       emit_instructions_optimized rest
     | None ->
       (* Regular instruction processing; the second instruction is
          reconsidered as the possible start of another pair *)
       generate_c_instruction ctx call_instr;
       emit_instructions_optimized (decl_instr :: rest))
in
emit_instructions_optimized ir_block.instructions
27782805

27792806
(** Collect mapping from registers to variable names *)
27802807
let collect_register_variable_mapping ir_func =

src/ir_generator.ml

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,20 +1017,67 @@ and resolve_declaration_type_and_init ctx reg typ_opt expr_opt =
10171017
| Some ast_type, Some expr ->
10181018
(* Use explicitly declared type, but process initialization expression *)
10191019
let target_type = resolve_type_alias ctx reg ast_type in
1020-
let value = lower_expression ctx expr in
1021-
(target_type, Some value)
1020+
(* For function calls, manually handle them to use the target register *)
1021+
(match expr.Ast.expr_desc with
1022+
| Ast.Call (callee_expr, args) ->
1023+
(* Handle function call that should return to the target register *)
1024+
let arg_vals = List.map (lower_expression ctx) args in
1025+
let result_val = make_ir_value (IRRegister reg) target_type expr.Ast.expr_pos in
1026+
let call_target = match callee_expr.Ast.expr_desc with
1027+
| Ast.Identifier name ->
1028+
if Hashtbl.mem ctx.variables name || Hashtbl.mem ctx.function_parameters name then
1029+
let callee_val = lower_expression ctx callee_expr in
1030+
FunctionPointerCall callee_val
1031+
else
1032+
DirectCall name
1033+
| _ ->
1034+
let callee_val = lower_expression ctx callee_expr in
1035+
FunctionPointerCall callee_val
1036+
in
1037+
let instr = make_ir_instruction (IRCall (call_target, arg_vals, Some result_val)) expr.Ast.expr_pos in
1038+
emit_instruction ctx instr;
1039+
(target_type, None)
1040+
| _ ->
1041+
(* Non-function call - use normal processing *)
1042+
let value = lower_expression ctx expr in
1043+
(target_type, Some value))
10221044
| None, Some expr ->
10231045
(* No declared type - use type checker annotation if available, otherwise infer from expression *)
1024-
let value = lower_expression ctx expr in
1025-
let inferred_type = match expr.expr_type with
1026-
| Some ast_type ->
1027-
(* Prioritize type checker annotation as single source of truth *)
1028-
ast_type_to_ir_type_with_context ctx.symbol_table ast_type
1029-
| None ->
1030-
(* Fallback to IR type inference only when type checker didn't provide annotation *)
1031-
value.val_type
1032-
in
1033-
(inferred_type, Some value)
1046+
(match expr.Ast.expr_desc with
1047+
| Ast.Call (callee_expr, args) ->
1048+
(* Handle function call in type inference *)
1049+
let inferred_type = match expr.Ast.expr_type with
1050+
| Some ast_type -> ast_type_to_ir_type_with_context ctx.symbol_table ast_type
1051+
| None -> IRU32 (* Default fallback *)
1052+
in
1053+
let arg_vals = List.map (lower_expression ctx) args in
1054+
let result_val = make_ir_value (IRRegister reg) inferred_type expr.Ast.expr_pos in
1055+
let call_target = match callee_expr.Ast.expr_desc with
1056+
| Ast.Identifier name ->
1057+
if Hashtbl.mem ctx.variables name || Hashtbl.mem ctx.function_parameters name then
1058+
let callee_val = lower_expression ctx callee_expr in
1059+
FunctionPointerCall callee_val
1060+
else
1061+
DirectCall name
1062+
| _ ->
1063+
let callee_val = lower_expression ctx callee_expr in
1064+
FunctionPointerCall callee_val
1065+
in
1066+
let instr = make_ir_instruction (IRCall (call_target, arg_vals, Some result_val)) expr.Ast.expr_pos in
1067+
emit_instruction ctx instr;
1068+
(inferred_type, None)
1069+
| _ ->
1070+
(* Non-function call - use normal processing *)
1071+
let value = lower_expression ctx expr in
1072+
let inferred_type = match expr.Ast.expr_type with
1073+
| Some ast_type ->
1074+
(* Prioritize type checker annotation as single source of truth *)
1075+
ast_type_to_ir_type_with_context ctx.symbol_table ast_type
1076+
| None ->
1077+
(* Fallback to IR type inference only when type checker didn't provide annotation *)
1078+
value.val_type
1079+
in
1080+
(inferred_type, Some value))
10341081
| Some ast_type, None ->
10351082
(* Declared type, no initialization *)
10361083
let target_type = resolve_type_alias ctx reg ast_type in

tests/test_ebpf_c_codegen.ml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,31 @@ let test_map_field_access_pointer_fix () =
925925
check bool "dot notation used for regular struct field access" true (contains_substr regular_result "my_struct.size");
926926
()
927927

928+
(** Checks that a call instruction followed by an uninitialised declaration of
    the same register is rendered as one initialised C declaration. *)
let test_variable_function_call_declaration () =
  let ctx = create_c_context () in
  ctx.indent_level <- 1;

  (* Register 0 is both the call's result and the declared variable *)
  let reg_value = make_ir_value (IRRegister 0) IRU32 test_pos in
  let literal_five = make_ir_value (IRLiteral (IntLit (5, None))) IRU32 test_pos in
  let call =
    make_ir_instruction
      (IRCall (DirectCall "helper_function", [literal_five], Some reg_value))
      test_pos
  in
  let declaration =
    make_ir_instruction (IRDeclareVariable (reg_value, IRU32, None)) test_pos
  in

  (* Run code generation over a block holding exactly the two instructions *)
  generate_c_basic_block ctx (make_ir_basic_block "test" [call; declaration] 0);

  let generated = String.concat "\n" ctx.output_lines in
  let emitted fragment = contains_substr generated fragment in

  (* Expected shape: __u32 val_0 = helper_function(5); *)
  check bool "combined declaration with function call" true
    (emitted "val_0 = helper_function(5)");

  (* A separate, uninitialised declaration must not appear *)
  check bool "no uninitialized declaration" false (emitted "__u32 val_0;")
952+
928953
(** Test suite definition *)
929954
let suite =
930955
[
@@ -962,6 +987,7 @@ let suite =
962987
("String size collection from userspace structs", `Quick, test_string_size_collection_from_userspace_structs);
963988
("Declaration ordering fix", `Quick, test_declaration_ordering_fix);
964989
("BPF printk string literal fix", `Quick, test_bpf_printk_string_literal_fix);
990+
("Variable function call declaration", `Quick, test_variable_function_call_declaration);
965991
]
966992

967993
(** Run all tests *)

tests/test_ir.ml

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,84 @@ let test_userspace_binding_generation () =
208208
| None ->
209209
fail "No C bindings found"
210210

211+
(** Regression test for function calls used as variable initializers.

    Lowers a small program whose [main] declares two variables initialised by
    calls, then checks that every [IRCall] result register is also the
    register of an [IRDeclareVariable], i.e. the call result lands directly in
    the declared variable.

    Failures are reported through the test framework's [check]/[fail]; no
    exception is caught or rewrapped here, so the original failure (including
    any expected/actual detail) reaches the test runner unchanged. *)
let test_variable_function_call_initialization () =
  let input = {|
    @xdp fn test_handler(ctx: *xdp_md) -> xdp_action {
      return 2 // XDP_PASS
    }

    fn main() -> i32 {
      var prog = load(test_handler) // Should assign to same register as 'prog'
      var result = attach(prog, "eth0", 0) // Should use 'prog' register correctly
      return result
    }
  |} in

  let ast = Kernelscript.Parse.parse_string input in
  let symbol_table = Kernelscript.Symbol_table.build_symbol_table ast in
  let (typed_ast, _) =
    Kernelscript.Type_checker.type_check_and_annotate_ast
      ~symbol_table:(Some symbol_table) ast
  in
  let ir_multi_prog = generate_ir typed_ast symbol_table "test_var_func_init" in

  (* Extract the main function from the userspace program *)
  let userspace_program = match ir_multi_prog.userspace_program with
    | Some prog -> prog
    | None -> fail "No userspace program found"
  in
  let main_func =
    match
      List.find_opt (fun func -> func.func_name = "main")
        userspace_program.userspace_functions
    with
    | Some func -> func
    | None -> fail "No main function found in userspace program"
  in

  (* Collect all instructions from all basic blocks, in order *)
  let all_instructions =
    List.flatten (List.map (fun block -> block.instructions) main_func.basic_blocks)
  in

  (* Destination values of variable declarations *)
  let declarations = List.filter_map (fun instr ->
    match instr.instr_desc with
    | IRDeclareVariable (dest_val, _, _) -> Some dest_val
    | _ -> None
  ) all_instructions in

  (* Result values of calls that produce a result *)
  let function_calls = List.filter_map (fun instr ->
    match instr.instr_desc with
    | IRCall (_, _, Some result_val) -> Some result_val
    | _ -> None
  ) all_instructions in

  (* Two declarations and two calls are expected for the program above *)
  check int "Should have variable declarations" 2 (List.length declarations);
  check int "Should have function calls" 2 (List.length function_calls);

  (* Keep only the values that are registers *)
  let register_of value = match value.value_desc with
    | IRRegister reg -> Some reg
    | _ -> None
  in
  let declaration_registers = List.filter_map register_of declarations in
  let call_result_registers = List.filter_map register_of function_calls in

  (* Every call result must target a register that is also declared *)
  check bool "Function call results should use declaration registers" true
    (List.for_all (fun reg -> List.mem reg declaration_registers) call_result_registers);

  (* The two register sets must coincide once ordering is ignored *)
  let sorted_decl_regs = List.sort compare declaration_registers in
  let sorted_call_regs = List.sort compare call_result_registers in
  check (list int) "Declaration and call registers should match" sorted_decl_regs sorted_call_regs
280+
211281
let ir_tests = [
212282
"program_lowering", `Quick, test_program_lowering;
213283
"context_access_lowering", `Quick, test_context_access_lowering;
214284
"map_operation_lowering", `Quick, test_map_operation_lowering;
215285
"bounds_check_insertion", `Quick, test_bounds_check_insertion;
216286
"stack_usage_tracking", `Quick, test_stack_usage_tracking;
217287
"userspace_binding_generation", `Quick, test_userspace_binding_generation;
288+
"variable_function_call_initialization", `Quick, test_variable_function_call_initialization;
218289
]
219290

220291
let () =

0 commit comments

Comments
 (0)