-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathformat.py
More file actions
421 lines (348 loc) · 13.7 KB
/
format.py
File metadata and controls
421 lines (348 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
import tree_sitter_cfengine as tscfengine
from tree_sitter import Language, Parser, Node
from cfbs.pretty import pretty_file
def format_json_file(filename):
    """Pretty-print a JSON file in place using cfbs.pretty.pretty_file.

    A message is printed when pretty_file() returns a truthy value
    (presumably meaning the file was changed — confirm in cfbs.pretty).

    Args:
        filename: path to the file; must end in ".json".
    """
    assert filename.endswith(".json")
    changed = pretty_file(filename)
    if not changed:
        return
    print(f"JSON file '{filename}' was reformatted")
def text(node: Node):
    """Return the source text of *node* decoded as UTF-8.

    An empty string is returned when ``node.text`` is empty or None.
    """
    raw = node.text
    return raw.decode("utf-8") if raw else ""
class Formatter:
    """Accumulate formatted output text in ``self.buffer``.

    Lines are separated lazily: every print() after the first one starts by
    emitting a newline, so the buffer never ends with a trailing newline of
    its own.
    """

    def __init__(self):
        # True until the first print() call.
        self.empty = True
        # Node most recently handed to update_previous().
        self.previous = None
        # Output produced so far.
        self.buffer = ""

    def _write(self, message, end="\n"):
        """Append *message* followed by *end* to the buffer."""
        self.buffer = self.buffer + message + end

    def print_lines(self, lines, indent):
        """print() each entry of *lines* with the same *indent*."""
        for entry in lines:
            self.print(entry, indent)

    def print(self, string, indent):
        """Start a new line containing *string* indented by *indent* spaces.

        *string* may be a str or a tree-sitter node (converted via text()).
        """
        content = string if type(string) is str else text(string)
        separator = "" if self.empty else "\n"
        self._write(separator + " " * indent + content, end="")
        self.empty = False

    def print_same_line(self, string):
        """Append *string* (str or node) to the current line."""
        self._write(string if type(string) is str else text(string), end="")

    def update_previous(self, node):
        """Record *node* as the latest node and return the one it replaces."""
        replaced, self.previous = self.previous, node
        return replaced
def stringify_parameter_list(parts):
    """Join parameter-list tokens into a compact, formatted string.

    *parts* holds the text of the structural tokens of a bundle/body
    parameter_list ("(", identifiers, "," and ")"); comments are expected to
    have been removed by the caller. A "," immediately followed by ")" is
    discarded. A space is emitted after every ",", and between two tokens
    unless the first is "(" or the second is "," or ")".

    Example: ["(", "a", ",", "b", ",", ")"] -> "(a, b)"
    """
    tokens = [
        token
        for position, token in enumerate(parts)
        if not (
            token == ","
            and position + 1 < len(parts)
            and parts[position + 1] == ")"
        )
    ]

    pieces = []
    last = None
    for token in tokens:
        after_comma = last == ","
        between_words = bool(last) and last != "(" and token not in (",", ")")
        if after_comma or between_words:
            pieces.append(" ")
        pieces.append(token)
        last = token
    return "".join(pieces)
def stringify_single_line_nodes(nodes):
    """Flatten a sequence of tree-sitter nodes into one line of text.

    Each node is flattened recursively with stringify_single_line_node().
    Between consecutive nodes one space is added for each rule that applies:
    the previous node is ","; the current node is "=>"; the previous node is
    "=>". No other whitespace is inserted (so nothing after "(" or before
    ")").

    Used to render lists, calls and attributes on a single line, and by
    maybe_split_generic_list() to try a one-line rendering before splitting.
    """
    line = ""
    before = None
    for node in nodes:
        flat = stringify_single_line_node(node)
        if before:
            gap = (
                (before.type == ",")
                + (node.type == "=>")
                + (before.type == "=>")
            )
            line += " " * gap
        line += flat
        before = node
    return line
def stringify_single_line_node(node):
    """Flatten *node* to one line: leaves yield their text, others recurse."""
    kids = node.children
    return stringify_single_line_nodes(kids) if kids else text(node)
def split_generic_value(node, indent, line_length):
    """Split a single list element that does not fit on one line.

    Calls and lists are split across lines (split_rval_call() /
    split_rval_list()); any other node is returned as a single flattened
    line. This is exactly the dispatch split_rval() performs, so delegate to
    it rather than keeping a second copy of the same branches.

    Returns:
        A non-empty list of lines; the first line carries no indentation.
    """
    return split_rval(node, indent, line_length)
def split_generic_list(middle, indent, line_length):
    """Lay out list/call elements one per line, indented by *indent*.

    A "," node is appended to the preceding line instead of getting its own
    line. An element whose indented one-line form is not shorter than
    *line_length* is split further via split_generic_value().
    """
    pad = " " * indent
    out = []
    for element in middle:
        if element.type == "," and out:
            out[-1] += ","
            continue
        candidate = pad + stringify_single_line_node(element)
        if len(candidate) < line_length:
            out.append(candidate)
            continue
        head, *rest = split_generic_value(element, indent, line_length)
        out.append(pad + head)
        out.extend(rest)
    return out
def maybe_split_generic_list(nodes, indent, line_length):
    """Render *nodes* on one indented line if it is shorter than
    *line_length*; otherwise fall back to split_generic_list()."""
    one_line = " " * indent + stringify_single_line_nodes(nodes)
    if len(one_line) >= line_length:
        return split_generic_list(nodes, indent, line_length)
    return [one_line]
def split_rval_list(node, indent, line_length):
    """Render a ``{ ... }`` list node across several lines.

    The opening "{" is returned unindented (the caller prefixes it), the
    elements are indented by ``indent + 2`` and the last child is indented
    by *indent*.
    """
    assert node.type == "list"
    kids = node.children
    assert kids[0].type == "{"
    opening = text(kids[0])
    closing = " " * indent + text(kids[-1])
    body = maybe_split_generic_list(kids[1:-1], indent + 2, line_length)
    return [opening] + body + [closing]
def split_rval_call(node, indent, line_length):
    """Render a function-call node across several lines.

    The first line is ``name(`` without indentation (the caller prefixes
    it). The arguments are indented by ``indent + 2``. The last child
    (presumably ")") is indented by *indent*.
    """
    assert node.type == "call"
    kids = node.children
    assert kids[0].type == "calling_identifier"
    assert kids[1].type == "("
    head = f"{text(kids[0])}("
    tail = " " * indent + text(kids[-1])
    args = maybe_split_generic_list(kids[2:-1], indent + 2, line_length)
    return [head] + args + [tail]
def split_rval(node, indent, line_length):
    """Split a list or call rval over several lines.

    Any other node is flattened to a single line.
    """
    splitter = {"list": split_rval_list, "call": split_rval_call}.get(node.type)
    if splitter is None:
        return [stringify_single_line_node(node)]
    return splitter(node, indent, line_length)
def maybe_split_rval(node, indent, offset, line_length):
    """Keep the rval on one line if it fits after *offset* columns.

    Args:
        node: the rval node.
        indent: indentation used for continuation lines if splitting.
        offset: columns already occupied on the first line.
        line_length: the one-line form must be strictly shorter than this.
    """
    flat = stringify_single_line_node(node)
    fits = offset + len(flat) < line_length
    return [flat] if fits else split_rval(node, indent, line_length)
def attempt_split_attribute(node, indent, line_length):
    """Render an ``lval => rval`` attribute, splitting a list/call rval.

    The ``lval => `` prefix is glued onto the first rval line. Attributes
    whose rval is neither a list nor a call are returned as one indented
    line regardless of length.
    """
    assert len(node.children) == 3
    lval, arrow, rval = node.children[0], node.children[1], node.children[2]
    if rval.type not in ("list", "call"):
        return [" " * indent + stringify_single_line_node(node)]
    prefix = f"{' ' * indent}{text(lval)} {text(arrow)} "
    lines = maybe_split_rval(rval, indent, len(prefix), line_length)
    lines[0] = prefix + lines[0]
    return lines
def stringify(node, indent, line_length):
    """Render *node* as a list of already-indented output lines.

    Only attribute nodes are ever split. For them one column is reserved for
    the ";" or "," that follows. Any other node is returned as a single line,
    even if it is too long.
    """
    single_line = " " * indent + stringify_single_line_node(node)
    is_attribute = node.type == "attribute"
    budget = line_length - 1 if is_attribute else line_length
    if is_attribute and len(single_line) >= budget:
        return attempt_split_attribute(node, indent, budget)
    return [single_line]
# Node types for class-guarded sections (a class expression followed by the
# promises/attributes it guards).
CLASS_GUARD_TYPES = {
    "class_guarded_promises",
    "class_guarded_body_attributes",
    "class_guarded_promise_block_attributes",
}
# Node types that autoformat() indents two columns deeper than their parent.
INDENTED_TYPES = {
    "bundle_section",
    *CLASS_GUARD_TYPES,
    "promise",
    "half_promise",
    "attribute",
}
# Block node types whose header line autoformat() prints via
# format_block_header().
BLOCK_TYPES = {"bundle_block", "promise_block", "body_block"}
def can_single_line_promise(node, indent, line_length):
    """Check if a promise node can be formatted on a single line.

    autoformat() renders a single-line promise as exactly
    ``<promiser> <attribute>;``. A promise therefore qualifies only if:

    * it is a ``promise`` node;
    * it has no children other than the promiser, attributes and ";".
      Anything else, such as an embedded comment, would be silently dropped
      by that rendering.
    * it has exactly one attribute;
    * it is not continued by a following ``half_promise`` sibling;
    * the rendered line, starting at column *indent*, is at most
      *line_length* characters.

    Args:
        node: tree-sitter node to check.
        indent: column at which the promise line would start.
        line_length: maximum allowed line length (inclusive).

    Returns:
        True if the promise can be printed on one line without losing content.
    """
    if node.type != "promise":
        return False
    children = node.children
    # Refuse single-line output when it would lose children (e.g. comments).
    lossless_types = {"promiser", "attribute", ";"}
    if any(c.type not in lossless_types for c in children):
        return False
    attr_children = [c for c in children if c.type == "attribute"]
    next_sib = node.next_named_sibling
    has_continuation = next_sib and next_sib.type == "half_promise"
    if len(attr_children) != 1 or has_continuation:
        return False
    promiser_node = next((c for c in children if c.type == "promiser"), None)
    if not promiser_node:
        return False
    line = (
        text(promiser_node) + " " + stringify_single_line_node(attr_children[0]) + ";"
    )
    return indent + len(line) <= line_length
def format_block_header(node, fmt):
    """Print the header line of a bundle/body/promise block.

    All children except the last form the header. Header words are joined
    with single spaces. A parameter_list is glued onto the preceding word
    via stringify_parameter_list(). Comments found in the header (including
    inside the parameter_list) are printed on their own lines after it.
    Comments that are just "#" are dropped unless they sit between two
    non-empty comments.

    A blank line is printed before the header unless nothing has been
    printed yet or the block's previous named sibling is a comment.

    Returns:
        The children of the block's last child (its body).
    """
    words = []
    comments = []
    for child in node.children[:-1]:
        if child.type == "comment":
            comments.append(text(child))
            continue
        if child.type != "parameter_list":
            words.append(text(child))
            continue
        tokens = []
        for token in child.children:
            if token.type == "comment":
                comments.append(text(token))
            else:
                tokens.append(text(token))
        words[-1] += stringify_parameter_list(tokens)

    if not fmt.empty:
        before = node.prev_named_sibling
        follows_comment = before and before.type == "comment"
        if not follows_comment:
            fmt.print("", 0)
    fmt.print(" ".join(words), 0)

    def is_bare(comment):
        return comment.strip() == "#"

    for position, comment in enumerate(comments):
        if is_bare(comment):
            sandwiched = (
                0 < position < len(comments) - 1
                and not is_bare(comments[position - 1])
                and not is_bare(comments[position + 1])
            )
            if not sandwiched:
                continue
        fmt.print(comment, 0)
    return node.children[-1].children
def needs_blank_line_before(child, indent, line_length):
    """Decide whether a blank line should precede *child*.

    Nodes without a previous named sibling never get one. Otherwise, by the
    child's node type:

    * bundle_section: only after another bundle_section.
    * promise after a promise/half_promise: yes, unless both are promises
      that fit on a single line (checked at ``indent + 2``).
    * class-guard node: after a promise, half_promise or
      class_guarded_promises.
    * comment after a promise, half_promise or class-guard node: only when
      the comment's parent is a bundle_section or class_guarded_promises.
    """
    prev = child.prev_named_sibling
    if not prev:
        return False
    kind = child.type
    prev_kind = prev.type

    if kind == "bundle_section":
        return prev_kind == "bundle_section"

    if kind == "promise" and prev_kind in ("promise", "half_promise"):
        if prev_kind != "promise":
            return True
        nested = indent + 2
        return not (
            can_single_line_promise(prev, nested, line_length)
            and can_single_line_promise(child, nested, line_length)
        )

    if kind in CLASS_GUARD_TYPES:
        return prev_kind in ("promise", "half_promise", "class_guarded_promises")

    follows_promise_like = (
        prev_kind in ("promise", "half_promise") or prev_kind in CLASS_GUARD_TYPES
    )
    if kind == "comment" and follows_promise_like:
        parent = child.parent
        return parent and parent.type in ("bundle_section", "class_guarded_promises")

    return False
def get_comment_indent(node, indent):
    """Return the indentation for a comment node.

    The comment is aligned with the nearest non-comment named sibling,
    looking forward first and backward only if nothing follows. If that
    sibling's type is in INDENTED_TYPES, the comment is indented by
    ``indent + 2``; otherwise by *indent*.
    """

    def skip_comments(sibling, step):
        while sibling and sibling.type == "comment":
            sibling = step(sibling)
        return sibling

    anchor = skip_comments(node.next_named_sibling, lambda s: s.next_named_sibling)
    if anchor is None:
        anchor = skip_comments(
            node.prev_named_sibling, lambda s: s.prev_named_sibling
        )
    if anchor and anchor.type in INDENTED_TYPES:
        return indent + 2
    return indent
def is_empty_comment(node):
    """Return True for a bare "#" comment that should be dropped.

    A bare "#" is kept (False) when it sits between two comment siblings,
    where it acts as a spacer inside a comment paragraph.
    """
    if text(node).strip() != "#":
        return False

    def is_comment(sibling):
        return bool(sibling) and sibling.type == "comment"

    before = node.prev_named_sibling
    after = node.next_named_sibling
    return not (is_comment(before) and is_comment(after))
def autoformat(node, fmt, line_length, macro_indent, indent=0):
    """Recursively print *node* and its descendants into *fmt*.

    Args:
        node: tree-sitter node to format.
        fmt: Formatter accumulating the output.
        line_length: maximum line length used for splitting attributes and
            for deciding whether a promise fits on one line.
        macro_indent: indentation applied to the node visited immediately
            after an ``@else`` macro.
        indent: indentation (in spaces) inherited from the parent.
    """
    # fmt.previous is the node visited just before this one, at any depth.
    previous = fmt.update_previous(node)
    if previous and previous.type == "macro" and text(previous).startswith("@else"):
        indent = macro_indent
    if node.type == "macro":
        # Macros are always printed at column 0.
        fmt.print(node, 0)
        # NOTE(review): both assignments below are to locals and are followed
        # immediately by `return`, so they have no effect. macro_indent is
        # passed down unchanged, so it is always the value the top-level
        # caller supplied. Confirm whether this state was meant to live on
        # `fmt` instead.
        if text(node).startswith("@if"):
            macro_indent = indent
        elif text(node).startswith("@else"):
            indent = macro_indent
        return
    children = node.children
    if node.type in BLOCK_TYPES:
        # Print the block's header line and continue with its body's children.
        children = format_block_header(node, fmt)
    if node.type in INDENTED_TYPES:
        indent += 2
    if node.type == "attribute":
        # stringify() returns fully indented lines, hence indent=0 here.
        lines = stringify(node, indent, line_length)
        fmt.print_lines(lines, indent=0)
        return
    if node.type == "promise" and can_single_line_promise(node, indent, line_length):
        # Single-line form: `<promiser> <attribute>;`.
        promiser_node = next(c for c in children if c.type == "promiser")
        attr_node = next(c for c in children if c.type == "attribute")
        line = text(promiser_node) + " " + stringify_single_line_node(attr_node) + ";"
        fmt.print(line, indent)
        return
    if children:
        # Interior node: recurse, inserting blank lines where requested.
        for child in children:
            if needs_blank_line_before(child, indent, line_length):
                fmt.print("", 0)
            autoformat(child, fmt, line_length, macro_indent, indent)
        return
    if node.type in {",", ";"}:
        # Separators/terminators stay on the current line.
        fmt.print_same_line(node)
        return
    if node.type == "comment":
        # Bare "#" comments may be dropped; others are aligned with neighbours.
        if is_empty_comment(node):
            return
        fmt.print(node, get_comment_indent(node, indent))
        return
    # Any other leaf starts a new line at the current indentation.
    fmt.print(node, indent)
def format_policy_file(filename, line_length):
    """Reformat a CFEngine policy file in place.

    The file is parsed with the tree-sitter CFEngine grammar and rendered by
    autoformat(). The file is rewritten only when the result differs from its
    current contents.

    Args:
        filename: path to the policy file; must end in ".cf".
        line_length: maximum line length used when splitting long lines.
    """
    assert filename.endswith(".cf")
    cfengine = Language(tscfengine.language())
    parser = Parser(cfengine)
    macro_indent = 0
    fmt = Formatter()
    with open(filename, "rb") as f:
        original_data = f.read()
    tree = parser.parse(original_data)
    root_node = tree.root_node
    assert root_node.type == "source_file"
    autoformat(root_node, fmt, line_length, macro_indent)
    new_data = fmt.buffer + "\n"
    if new_data != original_data.decode("utf-8"):
        # Write exactly what was compared above. Use UTF-8, the encoding the
        # original was decoded with, rather than the locale default, which
        # can mangle non-ASCII text. Use newline="" so text mode does not
        # translate "\n" (e.g. to "\r\n" on Windows); otherwise every later
        # run would see a difference and rewrite the file again.
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.write(new_data)
        print(f"Policy file '{filename}' was reformatted")
def format_policy_fin_fout(fin, fout, line_length):
    """Format CFEngine policy text from stream *fin* and write it to *fout*.

    The whole of *fin* is read as text, parsed with the tree-sitter CFEngine
    grammar and rendered by autoformat(). The result, plus a trailing
    newline, is always written to *fout*.

    Args:
        fin: readable text stream containing the policy.
        fout: writable text stream receiving the formatted policy.
        line_length: maximum line length used when splitting long lines.
    """
    parser = Parser(Language(tscfengine.language()))
    formatter = Formatter()
    source = fin.read().encode("utf-8")
    tree = parser.parse(source)
    root = tree.root_node
    assert root.type == "source_file"
    autoformat(root, formatter, line_length, 0)
    fout.write(formatter.buffer + "\n")