Skip to content

Commit b8f9541

Browse files
committed
Refactor type definitions in AST and IR to remove kernel-defined flags from structs and enums. Update related functions and tests to reflect these changes, ensuring consistency across the codebase.
1 parent 441c2ae commit b8f9541

18 files changed

Lines changed: 213 additions & 167 deletions

src/ast.ml

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ type map_flag =
5252

5353
(** Type definitions for structs, enums, and type aliases *)
5454
type type_def =
55-
| StructDef of string * (string * bpf_type) list * bool
56-
| EnumDef of string * (string * int option) list * bool
55+
| StructDef of string * (string * bpf_type) list
56+
| EnumDef of string * (string * int option) list
5757
| TypeAlias of string * bpf_type
5858

5959
(** BPF type system with extended type definitions *)
@@ -267,7 +267,6 @@ and struct_def = {
267267
struct_fields: (string * bpf_type) list;
268268
struct_attributes: attribute list; (* Added attributes for @struct_ops etc. *)
269269
struct_pos: position;
270-
kernel_defined: bool; (* NEW: Mark if this struct is kernel-defined *)
271270
}
272271

273272
(** Program definition with local maps and structs *)
@@ -434,11 +433,11 @@ let make_attributed_function attrs func pos = {
434433

435434
let make_type_def def = def
436435

437-
let make_enum_def name values = EnumDef (name, values, false) (* Default to user-defined *)
436+
let make_enum_def name values = EnumDef (name, values)
438437

439-
let make_kernel_enum_def name values = EnumDef (name, values, true) (* Mark as kernel-defined *)
438+
let make_kernel_enum_def name values = EnumDef (name, values)
440439

441-
let make_kernel_struct_def name fields = StructDef (name, fields, true) (* Mark as kernel-defined *)
440+
let make_kernel_struct_def name fields = StructDef (name, fields)
442441

443442
let make_type_alias name bpf_type = TypeAlias (name, bpf_type)
444443

@@ -466,7 +465,6 @@ let make_struct_def ?(attributes=[]) name fields pos = {
466465
struct_fields = fields;
467466
struct_attributes = attributes;
468467
struct_pos = pos;
469-
kernel_defined = false;
470468
}
471469

472470
let make_config_field name field_type default pos = {
@@ -806,11 +804,11 @@ let string_of_declaration = function
806804
| GlobalFunction func -> string_of_function func
807805
| TypeDef td ->
808806
let type_str = match td with
809-
| StructDef (name, fields, _) ->
807+
| StructDef (name, fields) ->
810808
Printf.sprintf "struct %s {\n %s\n}" name
811809
(String.concat "\n " (List.map (fun (name, typ) ->
812810
Printf.sprintf "%s: %s;" name (string_of_bpf_type typ)) fields))
813-
| EnumDef (name, values, _) ->
811+
| EnumDef (name, values) ->
814812
Printf.sprintf "enum %s {\n %s\n}" name
815813
(String.concat ",\n " (List.map (fun (name, opt) ->
816814
match opt with

src/ebpf_c_codegen.ml

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ let rec calculate_type_size ir_type =
207207
(* These types should never appear in field assignments due to type checking *)
208208
| IRVoid ->
209209
failwith "calculate_type_size: IRVoid should not appear in field assignments"
210-
| IRStruct (struct_name, _, _) ->
210+
| IRStruct (struct_name, _) ->
211211
failwith ("calculate_type_size: IRStruct should not appear in field assignments, got: " ^ struct_name)
212-
| IREnum (enum_name, _, _) ->
212+
| IREnum (enum_name, _) ->
213213
failwith ("calculate_type_size: IREnum should not appear in field assignments, got: " ^ enum_name)
214214
| IRResult (_, _) ->
215215
failwith "calculate_type_size: IRResult should not appear in field assignments"
@@ -315,8 +315,8 @@ let rec ebpf_type_from_ir_type = function
315315
| IRStr size -> sprintf "str_%d_t" size
316316
| IRPointer (inner_type, _) -> sprintf "%s*" (ebpf_type_from_ir_type inner_type)
317317
| IRArray (inner_type, size, _) -> sprintf "%s[%d]" (ebpf_type_from_ir_type inner_type) size
318-
| IRStruct (name, _, _) -> sprintf "struct %s" name
319-
| IREnum (name, _, _) -> sprintf "enum %s" name
318+
| IRStruct (name, _) -> sprintf "struct %s" name
319+
| IREnum (name, _) -> sprintf "enum %s" name
320320

321321
| IRResult (ok_type, _err_type) -> ebpf_type_from_ir_type ok_type (* simplified to ok type *)
322322
| IRTypeAlias (name, _) -> name (* Use the alias name directly *)
@@ -567,7 +567,7 @@ let collect_enum_definitions ?symbol_table ir_multi_prog =
567567
let enum_map = Hashtbl.create 16 in
568568

569569
let rec collect_from_type = function
570-
| IREnum (name, values, _) -> Hashtbl.replace enum_map name values
570+
| IREnum (name, values) -> Hashtbl.replace enum_map name values
571571
| IRPointer (inner_type, _) -> collect_from_type inner_type
572572
| IRArray (inner_type, _, _) -> collect_from_type inner_type
573573

@@ -655,7 +655,7 @@ let collect_enum_definitions ?symbol_table ir_multi_prog =
655655
let global_symbols = Symbol_table.get_global_symbols st in
656656
List.iter (fun symbol ->
657657
match symbol.Symbol_table.kind with
658-
| Symbol_table.TypeDef (Ast.EnumDef (enum_name, enum_values, _kernel_defined)) ->
658+
| Symbol_table.TypeDef (Ast.EnumDef (enum_name, enum_values)) ->
659659
let processed_values = List.map (fun (const_name, opt_value) ->
660660
(const_name, Option.value ~default:0 opt_value)
661661
) enum_values in
@@ -726,7 +726,7 @@ let collect_struct_definitions_from_multi_program ir_multi_prog =
726726

727727
let rec collect_from_type ir_type =
728728
match ir_type with
729-
| IRStruct (name, struct_fields, _) ->
729+
| IRStruct (name, struct_fields) ->
730730
if not (List.mem_assoc name !struct_defs) then (
731731
(* Only collect structs that actually have fields - ignore empty structs that are likely type aliases *)
732732
if struct_fields <> [] then
@@ -852,19 +852,21 @@ let collect_struct_definitions_from_multi_program ir_multi_prog =
852852

853853
(** Generate struct definitions *)
854854
let generate_struct_definitions ctx struct_defs =
855-
(* Filter out kernel-defined structs based on IR kernel_defined flag *)
856-
let user_defined_structs = List.filter (fun (_struct_name, fields) ->
857-
(* Check if this struct itself is kernel-defined or has kernel-defined fields *)
855+
(* Filter out kernel-defined structs using centralized kernel type knowledge *)
856+
let user_defined_structs = List.filter (fun (struct_name, fields) ->
857+
(* Check if this struct name is a well-known kernel type *)
858+
let is_kernel_type = Kernel_types.is_well_known_ebpf_type struct_name in
859+
860+
(* Check if this struct has kernel-defined field types *)
858861
let has_kernel_field = List.exists (fun (_field_name, field_type) ->
859862
match field_type with
860-
| IRStruct (_, _, kernel_defined) -> kernel_defined
861-
| IREnum (_, _, kernel_defined) -> kernel_defined
863+
| IRStruct (name, _) -> Kernel_types.is_well_known_ebpf_type name
864+
| IREnum (name, _) -> Kernel_types.is_well_known_ebpf_type name
862865
| _ -> false
863866
) fields in
864867

865-
(* Only filter based on kernel_defined flag from IR, not struct name *)
866-
(* User-defined structs should be generated regardless of their name *)
867-
not has_kernel_field
868+
(* Filter out kernel types and structs with kernel-defined fields *)
869+
not is_kernel_type && not has_kernel_field
868870
) struct_defs in
869871

870872
if user_defined_structs <> [] then (
@@ -1741,7 +1743,7 @@ let generate_c_expression ctx ir_expr =
17411743
| PacketData ->
17421744
(* Packet data field access - use bpf_dynptr_from_xdp *)
17431745
(match obj_val.val_type with
1744-
| IRPointer (IRStruct (struct_name, _, _), _) ->
1746+
| IRPointer (IRStruct (struct_name, _), _) ->
17451747
(* Note: For field ACCESS (not assignment), we use sizeof(__typeof(field))
17461748
which is calculated by the C compiler, so we don't need calculate_type_size here *)
17471749
let field_size = sprintf "sizeof(__typeof(((%s*)0)->%s))"
@@ -1754,7 +1756,7 @@ let generate_c_expression ctx ir_expr =
17541756
| _ when is_map_value_parameter obj_val ->
17551757
(* Map value field access - use bpf_dynptr_from_mem *)
17561758
(match obj_val.val_type with
1757-
| IRPointer (IRStruct (struct_name, _, _), _) ->
1759+
| IRPointer (IRStruct (struct_name, _), _) ->
17581760
(* Note: For field ACCESS (not assignment), we use sizeof(__typeof(field))
17591761
which is calculated by the C compiler, so we don't need calculate_type_size here *)
17601762
let field_size = sprintf "sizeof(__typeof(((%s*)0)->%s))"
@@ -2054,7 +2056,7 @@ let generate_ringbuf_operation ctx ringbuf_val op =
20542056

20552057
(* Calculate proper size based on the result type *)
20562058
let size = match result_val.val_type with
2057-
| IRPointer (IRStruct (struct_name, _, _), _) ->
2059+
| IRPointer (IRStruct (struct_name, _), _) ->
20582060
(* Use sizeof for struct types *)
20592061
sprintf "sizeof(struct %s)" struct_name
20602062
| IRPointer (elem_type, _) ->
@@ -2066,7 +2068,7 @@ let generate_ringbuf_operation ctx ringbuf_val op =
20662068
(match other_type with
20672069
| IRU32 -> "IRU32"
20682070
| IRU64 -> "IRU64"
2069-
| IRStruct (name, _, _) -> "IRStruct " ^ name
2071+
| IRStruct (name, _) -> "IRStruct " ^ name
20702072
| IRVoid -> "IRVoid"
20712073
| _ -> "unknown type"))
20722074
in
@@ -2318,7 +2320,7 @@ let generate_truthy_conversion ctx ir_value =
23182320
| IRPointer (_, _) ->
23192321
(* Pointers: null is falsy, non-null is truthy *)
23202322
sprintf "(%s != NULL)" (generate_c_value ctx ir_value)
2321-
| IREnum (_, _, _) ->
2323+
| IREnum (_, _) ->
23222324
(* Enums: based on numeric value *)
23232325
sprintf "(%s != 0)" (generate_c_value ctx ir_value)
23242326
| _ ->
@@ -2508,10 +2510,10 @@ let rec generate_c_instruction ctx ir_instr =
25082510
(* Check if this is a dynptr-backed pointer first *)
25092511
(match Hashtbl.find_opt ctx.dynptr_backed_pointers obj_str with
25102512
| Some dynptr_var ->
2511-
(* This is a dynptr-backed pointer - use bpf_dynptr_write *)
2512-
let field_size = calculate_type_size value_val.val_type in
2513-
(match obj_val.val_type with
2514-
| IRPointer (IRStruct (struct_name, _, _), _) ->
2513+
(* This is a dynptr-backed pointer - use bpf_dynptr_write *)
2514+
let field_size = calculate_type_size value_val.val_type in
2515+
(match obj_val.val_type with
2516+
| IRPointer (IRStruct (struct_name, _), _) ->
25152517
let full_struct_name = sprintf "struct %s" struct_name in
25162518
emit_line ctx (sprintf "{ %s __tmp_val = %s;" (ebpf_type_from_ir_type value_val.val_type) value_str);
25172519
emit_line ctx (sprintf " bpf_dynptr_write(&%s, __builtin_offsetof(%s, %s), &__tmp_val, %d, 0); }"
@@ -2524,8 +2526,8 @@ let rec generate_c_instruction ctx ir_instr =
25242526
(match detect_memory_region_enhanced obj_val with
25252527
| PacketData ->
25262528
(* Packet data field assignment - use dynptr API for safe write *)
2527-
(match obj_val.val_type with
2528-
| IRPointer (IRStruct (struct_name, _, _), _) ->
2529+
(match obj_val.val_type with
2530+
| IRPointer (IRStruct (struct_name, _), _) ->
25292531
let field_size = calculate_type_size value_val.val_type in
25302532
let full_struct_name = sprintf "struct %s" struct_name in
25312533
emit_line ctx (sprintf "{ struct bpf_dynptr __pkt_dynptr; bpf_dynptr_from_xdp(&__pkt_dynptr, ctx);");
@@ -2537,7 +2539,7 @@ let rec generate_c_instruction ctx ir_instr =
25372539
| _ when is_map_value_parameter obj_val ->
25382540
(* Map value field assignment - use dynptr API *)
25392541
(match obj_val.val_type with
2540-
| IRPointer (IRStruct (struct_name, _, _), _) ->
2542+
| IRPointer (IRStruct (struct_name, _), _) ->
25412543
let field_size = calculate_type_size value_val.val_type in
25422544
let full_struct_name = sprintf "struct %s" struct_name in
25432545
emit_line ctx (sprintf "{ struct bpf_dynptr __mem_dynptr; bpf_dynptr_from_mem(%s, sizeof(%s), 0, &__mem_dynptr);" obj_str full_struct_name);
@@ -3336,8 +3338,8 @@ let generate_c_function ctx ir_func =
33363338
| (_, IRPointer (IRContext XdpCtx, _)) :: _ -> Some "xdp"
33373339
| (_, IRPointer (IRContext TcCtx, _)) :: _ -> Some "tc"
33383340
| (_, IRPointer (IRContext KprobeCtx, _)) :: _ -> Some "kprobe"
3339-
| (_, IRPointer (IRStruct ("__sk_buff", _, _), _)) :: _ -> Some "tc" (* Handle __sk_buff as TC context *)
3340-
| (_, IRPointer (IRStruct ("xdp_md", _, _), _)) :: _ -> Some "xdp" (* Handle xdp_md as XDP context *)
3341+
| (_, IRPointer (IRStruct ("__sk_buff", _), _)) :: _ -> Some "tc" (* Handle __sk_buff as TC context *)
3342+
| (_, IRPointer (IRStruct ("xdp_md", _), _)) :: _ -> Some "xdp" (* Handle xdp_md as XDP context *)
33413343
| _ -> None));
33423344

33433345
let return_type_str =
@@ -3383,7 +3385,7 @@ let generate_c_function ctx ir_func =
33833385
| (_, IRPointer (IRContext TcCtx, _)) :: _ -> "SEC(\"tc\")"
33843386
| (_, IRPointer (IRContext KprobeCtx, _)) :: _ -> "SEC(\"kprobe\")"
33853387
| (_, IRPointer (IRContext TracepointCtx, _)) :: _ -> "SEC(\"tracepoint\")"
3386-
| (_, IRPointer (IRStruct ("__sk_buff", _, _), _)) :: _ -> "SEC(\"tc\")" (* Handle __sk_buff as TC context *)
3388+
| (_, IRPointer (IRStruct ("__sk_buff", _), _)) :: _ -> "SEC(\"tc\")" (* Handle __sk_buff as TC context *)
33873389
| _ -> "SEC(\"prog\")"
33883390
else ""
33893391
in

src/import_resolver.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,15 @@ let extract_exportable_symbols ast =
141141

142142
| TypeDef type_def ->
143143
(match type_def with
144-
| StructDef (name, _fields, _) ->
144+
| StructDef (name, _fields) ->
145145
let struct_type = Struct name in
146146
symbols := {
147147
symbol_name = name;
148148
symbol_type = struct_type;
149149
symbol_kind = `Type;
150150
is_public = true;
151151
} :: !symbols
152-
| EnumDef (name, _, _) ->
152+
| EnumDef (name, _) ->
153153
let enum_type = Enum name in
154154
symbols := {
155155
symbol_name = name;

src/ir.ml

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ and ir_struct_def = {
114114
struct_alignment: int; (* Memory alignment requirements *)
115115
struct_size: int; (* Total struct size in bytes *)
116116
struct_pos: ir_position;
117-
kernel_defined: bool; (* NEW: Mark if this struct is kernel-defined *)
118117
}
119118

120119
(** Userspace configuration in IR *)
@@ -135,8 +134,8 @@ and ir_type =
135134
| IRStr of int (* Fixed-size string str<N> *)
136135
| IRPointer of ir_type * bounds_info
137136
| IRArray of ir_type * int * bounds_info
138-
| IRStruct of string * (string * ir_type) list * bool (* NEW: bool for kernel_defined *)
139-
| IREnum of string * (string * int) list * bool (* NEW: bool for kernel_defined *)
137+
| IRStruct of string * (string * ir_type) list
138+
| IREnum of string * (string * int) list
140139
| IRResult of ir_type * ir_type
141140
| IRContext of context_type
142141
| IRAction of action_type
@@ -604,7 +603,6 @@ let make_ir_struct_def name fields alignment size pos = {
604603
struct_alignment = alignment;
605604
struct_size = size;
606605
struct_pos = pos;
607-
kernel_defined = false;
608606
}
609607

610608
let make_ir_config_item key value config_type = {
@@ -707,8 +705,8 @@ let rec ast_type_to_ir_type = function
707705
let bounds = make_bounds_info ~nullable:true () in
708706
IRPointer (ast_type_to_ir_type t, bounds)
709707
| Struct "__sk_buff" -> IRContext TcCtx (* Map __sk_buff to TC context *)
710-
| Struct name -> IRStruct (name, [], false) (* Fields filled by symbol table, default to user-defined *)
711-
| Enum name -> IREnum (name, [], false) (* Values filled by symbol table, default to user-defined *)
708+
| Struct name -> IRStruct (name, []) (* Fields filled by symbol table *)
709+
| Enum name -> IREnum (name, []) (* Values filled by symbol table *)
712710
| Option t ->
713711
let bounds = make_bounds_info ~nullable:true () in
714712
IRPointer (ast_type_to_ir_type t, bounds)
@@ -719,7 +717,7 @@ let rec ast_type_to_ir_type = function
719717
| LsmContext -> IRContext LsmCtx
720718
| CgroupSkbContext -> IRContext CgroupSkbCtx
721719
| Xdp_action -> IRAction Xdp_actionType
722-
| UserType name -> IRStruct (name, [], false) (* Resolved by type checker *)
720+
| UserType name -> IRStruct (name, []) (* Resolved by type checker *)
723721
| Function (param_types, return_type) ->
724722
(* Function types are represented as proper function pointers *)
725723
let ir_param_types = List.map ast_type_to_ir_type param_types in
@@ -747,17 +745,17 @@ let rec ast_type_to_ir_type_with_context symbol_table ast_type =
747745
| Symbol_table.TypeDef (Ast.TypeAlias (_, underlying_type)) ->
748746
(* Create IRTypeAlias to preserve the alias name *)
749747
IRTypeAlias (name, ast_type_to_ir_type underlying_type)
750-
| Symbol_table.TypeDef (Ast.StructDef (_, fields, kernel_defined)) ->
748+
| Symbol_table.TypeDef (Ast.StructDef (_, fields)) ->
751749
(* Resolve struct fields properly with type aliases preserved *)
752750
let ir_fields = List.map (fun (field_name, field_type) ->
753751
(field_name, ast_type_to_ir_type_with_context symbol_table field_type)
754752
) fields in
755-
IRStruct (name, ir_fields, kernel_defined)
756-
| Symbol_table.TypeDef (Ast.EnumDef (_, values, kernel_defined)) ->
753+
IRStruct (name, ir_fields)
754+
| Symbol_table.TypeDef (Ast.EnumDef (_, values)) ->
757755
let ir_values = List.map (fun (enum_name, opt_value) ->
758756
(enum_name, Option.value ~default:0 opt_value)
759757
) values in
760-
IREnum (name, ir_values, kernel_defined)
758+
IREnum (name, ir_values)
761759
| _ -> ast_type_to_ir_type ast_type)
762760
| None ->
763761
(* Fallback to regular conversion *)
@@ -770,17 +768,17 @@ let rec ast_type_to_ir_type_with_context symbol_table ast_type =
770768
| Symbol_table.TypeDef (Ast.TypeAlias (_, underlying_type)) ->
771769
(* Create IRTypeAlias to preserve the alias name *)
772770
IRTypeAlias (name, ast_type_to_ir_type underlying_type)
773-
| Symbol_table.TypeDef (Ast.StructDef (_, fields, kernel_defined)) ->
771+
| Symbol_table.TypeDef (Ast.StructDef (_, fields)) ->
774772
(* Resolve struct fields properly with type aliases preserved *)
775773
let ir_fields = List.map (fun (field_name, field_type) ->
776774
(field_name, ast_type_to_ir_type_with_context symbol_table field_type)
777775
) fields in
778-
IRStruct (name, ir_fields, kernel_defined)
779-
| Symbol_table.TypeDef (Ast.EnumDef (_, values, kernel_defined)) ->
776+
IRStruct (name, ir_fields)
777+
| Symbol_table.TypeDef (Ast.EnumDef (_, values)) ->
780778
let ir_values = List.map (fun (enum_name, opt_value) ->
781779
(enum_name, Option.value ~default:0 opt_value)
782780
) values in
783-
IREnum (name, ir_values, kernel_defined)
781+
IREnum (name, ir_values)
784782
| _ -> ast_type_to_ir_type ast_type)
785783
| None ->
786784
(* Fallback to regular conversion *)
@@ -798,11 +796,11 @@ let rec ast_type_to_ir_type_with_context symbol_table ast_type =
798796
(match Symbol_table.lookup_symbol symbol_table name with
799797
| Some symbol ->
800798
(match symbol.kind with
801-
| Symbol_table.TypeDef (Ast.EnumDef (_, values, kernel_defined)) ->
799+
| Symbol_table.TypeDef (Ast.EnumDef (_, values)) ->
802800
let ir_values = List.map (fun (enum_name, opt_value) ->
803801
(enum_name, Option.value ~default:0 opt_value)
804802
) values in
805-
IREnum (name, ir_values, kernel_defined)
803+
IREnum (name, ir_values)
806804
| _ -> ast_type_to_ir_type ast_type)
807805
| None -> ast_type_to_ir_type ast_type)
808806
| _ -> ast_type_to_ir_type ast_type
@@ -835,8 +833,8 @@ let rec string_of_ir_type = function
835833
| IRStr size -> Printf.sprintf "str<%d>" size
836834
| IRPointer (t, _) -> Printf.sprintf "*%s" (string_of_ir_type t)
837835
| IRArray (t, size, _) -> Printf.sprintf "[%s; %d]" (string_of_ir_type t) size
838-
| IRStruct (name, _, _) -> Printf.sprintf "struct %s" name
839-
| IREnum (name, _, _) -> Printf.sprintf "enum %s" name
836+
| IRStruct (name, _) -> Printf.sprintf "struct %s" name
837+
| IREnum (name, _) -> Printf.sprintf "enum %s" name
840838
| IRResult (t1, t2) -> Printf.sprintf "result (%s, %s)" (string_of_ir_type t1) (string_of_ir_type t2)
841839
| IRTypeAlias (name, _) -> Printf.sprintf "type %s" name
842840
| IRStructOps (name, _) -> Printf.sprintf "struct_ops %s" name

0 commit comments

Comments
 (0)