Skip to content

Commit 6372c0c

Browse files
committed
Add void function validation in expression context and enhance type checking.
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 1a4e243 commit 6372c0c

3 files changed

Lines changed: 444 additions & 11 deletions

File tree

src/type_checker.ml

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
open Ast
2020
open Printf
2121

22+
(** Expression context for void function validation *)
23+
type expr_context = Statement | Expression
24+
2225
(** Type checking exceptions *)
2326
exception Type_error of string * position
2427
exception Unification_error of bpf_type * bpf_type * position
@@ -40,6 +43,7 @@ type context = {
4043
mutable current_function: string option;
4144
mutable current_program_type: program_type option;
4245
mutable multi_program_analysis: Multi_program_analyzer.multi_program_analysis option;
46+
mutable expr_context: expr_context; (* Track whether we're in statement or expression context *)
4347
in_tail_call_context: bool; (* Flag to indicate we're processing a potential tail call *)
4448
in_match_return_context: bool; (* Flag to indicate we're inside a match expression in return position *)
4549
ast_context: Ast.declaration list; (* Store original AST for struct_ops attribute checking *)
@@ -184,6 +188,7 @@ let create_context symbol_table ast =
184188
current_function = None;
185189
current_program_type = None;
186190
multi_program_analysis = None;
191+
expr_context = Expression; (* Default to expression context for safety *)
187192
in_tail_call_context = false;
188193
in_match_return_context = false;
189194
attributed_function_map = Hashtbl.create 16;
@@ -196,6 +201,20 @@ let loop_depth = ref 0
196201
(** Helper to create type error *)
197202
let type_error msg pos = raise (Type_error (msg, pos))
198203

204+
(** Validate void function usage in expression context *)
205+
let validate_void_in_expression expr_type func_name context pos =
206+
match expr_type, context with
207+
| Void, Expression ->
208+
type_error ("Void function '" ^ func_name ^ "' cannot be used in an expression") pos
209+
| _ -> ()
210+
211+
(** Check if a type represents an enum (either Enum _ or built-in enum-like types) *)
212+
let is_enum_like_type = function
213+
| Enum _ -> true
214+
| Xdp_action -> true (* Built-in enum-like type *)
215+
(* Add other built-in enum-like types here as needed *)
216+
| _ -> false
217+
199218
(** Resolve user types to built-in types and type aliases *)
200219
let rec resolve_user_type ctx = function
201220
| UserType "xdp_md" -> Xdp_md
@@ -325,14 +344,15 @@ let rec unify_types t1 t2 =
325344
| ProgramRef pt1, ProgramRef pt2 when pt1 = pt2 -> Some (ProgramRef pt1)
326345

327346
(* Enum-integer compatibility: enums are represented as u32 *)
328-
| Enum _, U32 | U32, Enum _ -> Some U32
347+
| Enum _, (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) | (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64), Enum _ -> Some U32
329348
| Enum enum_name, Enum other_name when enum_name = other_name -> Some (Enum enum_name)
330349

331-
(* Special built-in type compatibility for specific enums *)
332-
| Enum "xdp_action", Xdp_action | Xdp_action, Enum "xdp_action" -> Some Xdp_action
333-
334350

335351

352+
(* All enum-like types (both Enum _ and built-in enum types) are compatible with integers *)
353+
| t1, (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) when is_enum_like_type t1 -> Some t1
354+
| (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64), t2 when is_enum_like_type t2 -> Some t2
355+
336356
(* No unification possible *)
337357
| _ -> None
338358

@@ -1237,14 +1257,20 @@ and type_check_expression ctx expr =
12371257

12381258
(* Try builtin -> user function -> function pointer variable *)
12391259
(match type_check_builtin_call ctx name typed_args arg_types expr.expr_pos with
1240-
| Some result -> result
1260+
| Some result ->
1261+
validate_void_in_expression result.texpr_type name ctx.expr_context expr.expr_pos;
1262+
result
12411263
| None ->
12421264
(match type_check_user_function_call ctx name typed_args arg_types expr.expr_pos with
1243-
| Some result -> result
1265+
| Some result ->
1266+
validate_void_in_expression result.texpr_type name ctx.expr_context expr.expr_pos;
1267+
result
12441268
| None ->
12451269
(match type_check_function_pointer_variable ctx name typed_args arg_types expr.expr_pos with
1246-
| Some result -> result
1247-
| None -> type_error ("Undefined function: " ^ name) expr.expr_pos)))
1270+
| Some result ->
1271+
validate_void_in_expression result.texpr_type name ctx.expr_context expr.expr_pos;
1272+
result
1273+
| None -> type_error ("Undefined function: " ^ name) expr.expr_pos)))
12481274

12491275
| FieldAccess ({expr_desc = Identifier var_name; _}, method_name)
12501276
when Hashtbl.mem ctx.variables var_name ->
@@ -1522,7 +1548,10 @@ and type_check_expression ctx expr =
15221548
and type_check_statement ctx stmt =
15231549
match stmt.stmt_desc with
15241550
| ExprStmt expr ->
1551+
let old_context = ctx.expr_context in
1552+
ctx.expr_context <- Statement; (* Allow void functions in statement context *)
15251553
let typed_expr = type_check_expression ctx expr in
1554+
ctx.expr_context <- old_context; (* Restore previous context *)
15261555
{ tstmt_desc = TExprStmt typed_expr; tstmt_pos = stmt.stmt_pos }
15271556

15281557
| Assignment (name, expr) ->
@@ -1965,6 +1994,32 @@ and type_check_statement ctx stmt =
19651994
None
19661995
| None -> None)
19671996
in
1997+
1998+
(* Elegant return validation: check compatibility with current function *)
1999+
(match ctx.current_function with
2000+
| Some func_name ->
2001+
(try
2002+
let (_, return_type) = Hashtbl.find ctx.functions func_name in
2003+
let resolved_return_type = resolve_user_type ctx return_type in
2004+
(match typed_expr_opt, resolved_return_type with
2005+
| Some _, Void ->
2006+
type_error ("Void function '" ^ func_name ^ "' cannot return a value") stmt.stmt_pos
2007+
| None, t when t <> Void ->
2008+
type_error ("Non-void function '" ^ func_name ^ "' must return a value of type " ^
2009+
string_of_bpf_type t) stmt.stmt_pos
2010+
| Some typed_expr, _ ->
2011+
(* Check return type compatibility *)
2012+
let resolved_expr_type = resolve_user_type ctx typed_expr.texpr_type in
2013+
(match unify_types resolved_expr_type resolved_return_type with
2014+
| Some _ -> () (* Types can be unified *)
2015+
| None ->
2016+
type_error ("Function '" ^ func_name ^ "' expects return type " ^
2017+
string_of_bpf_type resolved_return_type ^ " but got " ^
2018+
string_of_bpf_type resolved_expr_type) stmt.stmt_pos)
2019+
| _ -> () (* Valid cases *))
2020+
with Not_found -> () (* Function not in context *))
2021+
| None -> () (* Not in function context *));
2022+
19682023
{ tstmt_desc = TReturn typed_expr_opt; tstmt_pos = stmt.stmt_pos }
19692024

19702025
| If (cond, then_stmts, else_opt) ->

tests/dune

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,11 @@
421421
(modules test_exec)
422422
(libraries kernelscript alcotest unix str))
423423

424+
(executable
425+
(name test_void_functions)
426+
(modules test_void_functions)
427+
(libraries kernelscript alcotest str test_utils))
428+
424429
; Test rules for individual execution
425430
(rule
426431
(alias runtest_lexer)
@@ -666,6 +671,10 @@
666671
(alias runtest_exec)
667672
(action (run ./test_exec.exe)))
668673

674+
(rule
675+
(alias runtest_void_functions)
676+
(action (run ./test_void_functions.exe)))
677+
669678
; Test aliases for organized execution
670679
(rule
671680
(alias map-tests)
@@ -767,7 +776,8 @@
767776
test_match.exe
768777
test_test_attribute.exe
769778
test_named_returns.exe
770-
test_import_system.exe)
779+
test_import_system.exe
780+
test_void_functions.exe)
771781
(action
772782
(progn
773783
(run ./test_btf_binary_parser.exe)
@@ -817,7 +827,8 @@
817827
(run ./test_match.exe)
818828
(run ./test_test_attribute.exe)
819829
(run ./test_named_returns.exe)
820-
(run ./test_import_system.exe))))
830+
(run ./test_import_system.exe)
831+
(run ./test_void_functions.exe))))
821832

822833
(rule
823834
(alias ir-tests)
@@ -874,7 +885,8 @@
874885
./test_tracepoint.exe
875886
./test_probe.exe
876887
./test_tc.exe
877-
./test_exec.exe)
888+
./test_exec.exe
889+
./test_void_functions.exe)
878890
(action
879891
(progn
880892
(run ./test_ast.exe)
@@ -895,4 +907,5 @@
895907
(run ./test_probe.exe)
896908
(run ./test_tc.exe)
897909
(run ./test_exec.exe)
910+
(run ./test_void_functions.exe)
898911
(run ./test_detach_api.exe))))

0 commit comments

Comments
 (0)