Skip to content

Commit 8d03290

Browse files
committed
Add support for NewWithFlag expression in AST and related components. Update parser, type checker, and IR generation to handle object allocation with GFP flags.
1 parent 0bb0472 commit 8d03290

14 files changed

Lines changed: 206 additions & 8 deletions

examples/simple_gfp_test.ks

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Simple test to verify GFP flag validation
2+
3+
struct TestData {
4+
value: u64,
5+
}
6+
7+
@kfunc
8+
fn valid_kfunc_allocation() -> i32 {
9+
// Basic allocation (valid in kernel context)
10+
var basic_ptr = new TestData(GFP_ATOMIC)
11+
delete basic_ptr
12+
return 0
13+
}
14+
15+
// This should succeed - basic allocation in eBPF context
16+
@xdp
17+
fn valid_ebpf_allocation(ctx: *xdp_md) -> xdp_action {
18+
var ptr = new TestData()
19+
delete ptr
20+
return XDP_PASS
21+
}
22+
23+
// This should succeed - basic allocation in userspace
24+
fn valid_userspace_allocation() -> i32 {
25+
var ptr = new TestData()
26+
delete ptr
27+
return 0
28+
}
29+
30+
fn main() -> i32 {
31+
return valid_userspace_allocation()
32+
}

src/ast.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ and expr_desc =
171171
| StructLiteral of string * (string * expr) list
172172
| Match of expr * match_arm list (* match (expr) { arms } *)
173173
| New of bpf_type (* new Type() - object allocation *)
174+
| NewWithFlag of bpf_type * expr (* new Type(gfp_flag) - object allocation with flag *)
174175

175176
(** Module function call *)
176177
and module_call = {
@@ -684,6 +685,8 @@ let rec string_of_expr expr =
684685
let arms_str = String.concat ",\n " (List.map string_of_match_arm arms) in
685686
Printf.sprintf "match (%s) {\n %s\n}" (string_of_expr expr) arms_str
686687
| New typ -> Printf.sprintf "new %s()" (string_of_bpf_type typ)
688+
| NewWithFlag (typ, flag_expr) ->
689+
Printf.sprintf "new %s(%s)" (string_of_bpf_type typ) (string_of_expr flag_expr)
687690

688691
and string_of_match_pattern = function
689692
| ConstantPattern lit -> string_of_literal lit

src/ebpf_c_codegen.ml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,8 @@ let rec collect_string_sizes_from_instr ir_instr =
477477
(collect_string_sizes_from_value instance_val) @ (collect_string_sizes_from_value struct_ops_val)
478478
| IRObjectNew (dest_val, _) ->
479479
collect_string_sizes_from_value dest_val
480+
| IRObjectNewWithFlag (dest_val, _, flag_val) ->
481+
(collect_string_sizes_from_value dest_val) @ (collect_string_sizes_from_value flag_val)
480482
| IRObjectDelete ptr_val ->
481483
collect_string_sizes_from_value ptr_val
482484

@@ -2646,6 +2648,8 @@ let rec generate_c_instruction ctx ir_instr =
26462648
List.iter collect_in_value args
26472649
| IRObjectNew (dest_val, _) ->
26482650
collect_in_value dest_val
2651+
| IRObjectNewWithFlag (dest_val, _, flag_val) ->
2652+
collect_in_value dest_val; collect_in_value flag_val
26492653
| IRObjectDelete ptr_val ->
26502654
collect_in_value ptr_val
26512655
| IRJump _ | IRComment _ | IRBreak | IRContinue | IRThrow _ -> ()
@@ -2856,6 +2860,11 @@ let rec generate_c_instruction ctx ir_instr =
28562860
(* Use proper kernel pattern: ptr = bpf_obj_new(type) *)
28572861
emit_line ctx (sprintf "%s = bpf_obj_new(%s);" dest_str type_str)
28582862

2863+
| IRObjectNewWithFlag _ ->
2864+
(* GFP flags should never reach eBPF code generation - this is an internal error *)
2865+
failwith ("Internal error: GFP allocation flags are not supported in eBPF context. " ^
2866+
"This should have been caught by the type checker.")
2867+
28592868
| IRObjectDelete ptr_val ->
28602869
let ptr_str = generate_c_value ctx ptr_val in
28612870
(* Use the proper kernel bpf_obj_drop(ptr) macro *)
@@ -3067,6 +3076,8 @@ let collect_registers_in_function ir_func =
30673076
collect_in_value instance_val; collect_in_value struct_ops_val
30683077
| IRObjectNew (dest_val, _) ->
30693078
collect_in_value dest_val
3079+
| IRObjectNewWithFlag (dest_val, _, flag_val) ->
3080+
collect_in_value dest_val; collect_in_value flag_val
30703081
| IRObjectDelete ptr_val ->
30713082
collect_in_value ptr_val
30723083
in

src/evaluator.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,11 @@ and eval_expression ctx expr =
754754
(* For evaluator, object allocation returns a mock pointer value *)
755755
(* This is just for testing - real allocation happens in generated code *)
756756
PointerValue (Random.int 1000000)
757+
758+
| NewWithFlag (_, _) ->
759+
(* For evaluator, object allocation with flag returns a mock pointer value *)
760+
(* This is just for testing - real allocation happens in generated code *)
761+
PointerValue (Random.int 1000000)
757762

758763
(** Evaluate statements *)
759764
and eval_statements ctx stmts =

src/ir.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ and ir_instr_desc =
297297
| IRMapStore of ir_value * ir_value * ir_value * map_store_type
298298
| IRMapDelete of ir_value * ir_value
299299
| IRObjectNew of ir_value * ir_type (* target_pointer, object_type *)
300+
| IRObjectNewWithFlag of ir_value * ir_type * ir_value (* target_pointer, object_type, flag_expr *)
300301
| IRObjectDelete of ir_value (* pointer_to_delete *)
301302
| IRConfigFieldUpdate of ir_value * ir_value * string * ir_value (* map, key, field, value *)
302303
| IRStructFieldAssignment of ir_value * string * ir_value (* object, field, value *)
@@ -899,6 +900,8 @@ let rec string_of_ir_instruction instr =
899900
Printf.sprintf "delete(%s, %s)" (string_of_ir_value map) (string_of_ir_value key)
900901
| IRObjectNew (dest, obj_type) ->
901902
Printf.sprintf "%s = object_new(%s)" (string_of_ir_value dest) (string_of_ir_type obj_type)
903+
| IRObjectNewWithFlag (dest, obj_type, flag_expr) ->
904+
Printf.sprintf "%s = object_new(%s, %s)" (string_of_ir_value dest) (string_of_ir_type obj_type) (string_of_ir_value flag_expr)
902905
| IRObjectDelete ptr ->
903906
Printf.sprintf "object_delete(%s)" (string_of_ir_value ptr)
904907
| IRConfigFieldUpdate (map, key, field, value) ->

src/ir_generator.ml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,20 @@ let rec lower_expression ctx (expr : Ast.expr) =
937937
emit_instruction ctx alloc_instr;
938938

939939
result_val
940+
941+
| Ast.NewWithFlag (typ, flag_expr) ->
942+
(* Object allocation with GFP flag - only valid in kernel context *)
943+
let ir_type = ast_type_to_ir_type typ in
944+
let result_reg = allocate_register ctx in
945+
let result_val = make_ir_value (IRRegister result_reg) (IRPointer (ir_type, make_bounds_info ())) expr.expr_pos in
946+
947+
(* Lower the flag expression *)
948+
let flag_val = lower_expression ctx flag_expr in
949+
950+
let alloc_instr = make_ir_instruction (IRObjectNewWithFlag (result_val, ir_type, flag_val)) expr.expr_pos in
951+
emit_instruction ctx alloc_instr;
952+
953+
result_val
940954

941955
(** Helper function to handle register() builtin function calls *)
942956
and handle_register_builtin_call ctx args expr_pos ?target_register ?target_type () =
@@ -1335,7 +1349,7 @@ and lower_statement ctx stmt =
13351349

13361350
(* Handle function call and new expression declarations elegantly by proper instruction ordering *)
13371351
(match expr_opt with
1338-
| Some expr when (match expr.expr_desc with Ast.Call _ | Ast.New _ -> true | _ -> false) ->
1352+
| Some expr when (match expr.expr_desc with Ast.Call _ | Ast.New _ | Ast.NewWithFlag _ -> true | _ -> false) ->
13391353
(* For function calls and new expressions: emit declaration first, then operation with assignment *)
13401354
let target_type = match typ_opt with
13411355
| Some ast_type -> resolve_type_alias ctx reg ast_type
@@ -1348,6 +1362,9 @@ and lower_statement ctx stmt =
13481362
| Ast.New typ ->
13491363
let ir_type = ast_type_to_ir_type typ in
13501364
IRPointer (ir_type, make_bounds_info ())
1365+
| Ast.NewWithFlag (typ, _) ->
1366+
let ir_type = ast_type_to_ir_type typ in
1367+
IRPointer (ir_type, make_bounds_info ())
13511368
| _ -> IRU32))
13521369
in
13531370

@@ -1385,6 +1402,13 @@ and lower_statement ctx stmt =
13851402
let result_val = make_ir_value (IRRegister reg) target_type expr.Ast.expr_pos in
13861403
let alloc_instr = make_ir_instruction (IRObjectNew (result_val, ir_type)) expr.Ast.expr_pos in
13871404
emit_instruction ctx alloc_instr
1405+
| Ast.NewWithFlag (typ, flag_expr) ->
1406+
(* Handle new expression with flag: emit allocation instruction with flag *)
1407+
let ir_type = ast_type_to_ir_type typ in
1408+
let result_val = make_ir_value (IRRegister reg) target_type expr.Ast.expr_pos in
1409+
let flag_val = lower_expression ctx flag_expr in
1410+
let alloc_instr = make_ir_instruction (IRObjectNewWithFlag (result_val, ir_type, flag_val)) expr.Ast.expr_pos in
1411+
emit_instruction ctx alloc_instr
13881412
| _ -> ()) (* Shouldn't happen due to our guard *)
13891413
| _ ->
13901414
(* Non-function call declarations: use existing logic *)

src/kernel_module_codegen.ml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,14 @@ let rec generate_statement_translation stmt =
179179
sprintf " %s;" (generate_expression_translation expr)
180180
| Break -> " break;"
181181
| Continue -> " continue;"
182+
| Delete (DeletePointer ptr_expr) ->
183+
(* Translate pointer deletion to kfree() *)
184+
let ptr_str = generate_expression_translation ptr_expr in
185+
sprintf " kfree(%s);" ptr_str
186+
| Delete (DeleteMapEntry (map_expr, key_expr)) ->
187+
(* Map deletion not supported in kernel modules - this should be caught earlier *)
188+
sprintf " /* Map deletion not supported in kernel modules: delete %s[%s] */"
189+
(generate_expression_translation map_expr) (generate_expression_translation key_expr)
182190
| _ -> " /* TODO: Implement statement translation */"
183191

184192
(** Generate expression translation *)
@@ -262,6 +270,15 @@ and generate_expression_translation expr =
262270
sprintf "%s->%s" (generate_expression_translation obj) field
263271
| ArrayAccess (array, index) ->
264272
sprintf "%s[%s]" (generate_expression_translation array) (generate_expression_translation index)
273+
| New typ ->
274+
(* Basic allocation with GFP_KERNEL (default for kernel context) *)
275+
let c_type = kernelscript_type_to_c_type typ in
276+
sprintf "kmalloc(sizeof(%s), GFP_KERNEL)" c_type
277+
| NewWithFlag (typ, flag_expr) ->
278+
(* Allocation with specific GFP flag *)
279+
let c_type = kernelscript_type_to_c_type typ in
280+
let flag_str = generate_expression_translation flag_expr in
281+
sprintf "kmalloc(sizeof(%s), %s)" c_type flag_str
265282
| _ -> "/* TODO: Implement expression translation */"
266283

267284
(** Generate function implementation for regular kernel module functions *)

src/main.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,10 @@ let compile_source input_file output_dir _verbose generate_makefile btf_vmlinux_
740740
(List.length filtered_btf_declarations)
741741
(List.length btf_declarations - List.length filtered_btf_declarations);
742742

743-
let symbol_table = Symbol_table.build_symbol_table ~project_name:base_name ~builtin_asts:[filtered_btf_declarations] compilation_ast in
743+
(* Add stdlib builtin types to the symbol table *)
744+
let stdlib_builtin_declarations = Stdlib.get_builtin_types () in
745+
let all_builtin_declarations = stdlib_builtin_declarations @ filtered_btf_declarations in
746+
let symbol_table = Symbol_table.build_symbol_table ~project_name:base_name ~builtin_asts:[all_builtin_declarations] compilation_ast in
744747

745748
Printf.printf "✅ Symbol table created successfully with BTF types\n\n";
746749

src/parse.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ let validate_ast ast =
8686
| SingleExpr expr -> validate_expr expr
8787
| Block stmts -> List.for_all validate_stmt stmts
8888
) arms
89-
| New _ -> true (* New expressions are always syntactically valid *)
89+
| New _ -> true
90+
| NewWithFlag (_, flag_expr) -> validate_expr flag_expr (* New expressions are always syntactically valid *)
9091

9192
and validate_stmt stmt =
9293
match stmt.stmt_desc with

src/parser.mly

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ primary_expression:
433433
| primary_expression DOT IDENTIFIER { make_expr (FieldAccess ($1, $3)) (make_pos ()) }
434434
| primary_expression ARROW IDENTIFIER { make_expr (ArrowAccess ($1, $3)) (make_pos ()) }
435435
| NEW bpf_type LPAREN RPAREN { make_expr (New $2) (make_pos ()) }
436+
| NEW bpf_type LPAREN expression RPAREN { make_expr (NewWithFlag ($2, $4)) (make_pos ()) }
436437

437438
function_call:
438439
| IDENTIFIER LPAREN argument_list RPAREN

0 commit comments

Comments
 (0)