diff --git a/src/cfengine_cli/commands.py b/src/cfengine_cli/commands.py index d8d1afe..ec3aec9 100644 --- a/src/cfengine_cli/commands.py +++ b/src/cfengine_cli/commands.py @@ -50,7 +50,7 @@ def deploy() -> int: return r -def _format_filename(filename, line_length, check): +def _format_filename(filename: str, line_length: int, check: bool) -> int: if filename.startswith("./."): return 0 if filename.endswith(".json"): @@ -60,7 +60,7 @@ def _format_filename(filename, line_length, check): raise UserError(f"Unrecognized file format: {filename}") -def _format_dirname(directory, line_length, check): +def _format_dirname(directory: str, line_length: int, check: bool) -> int: ret = 0 for filename in find(directory, extension=".json"): ret |= _format_filename(filename, line_length, check) diff --git a/src/cfengine_cli/format.py b/src/cfengine_cli/format.py index 58615b9..4d260e4 100644 --- a/src/cfengine_cli/format.py +++ b/src/cfengine_cli/format.py @@ -1,72 +1,108 @@ +from __future__ import annotations + +from typing import IO + import tree_sitter_cfengine as tscfengine from tree_sitter import Language, Parser, Node from cfbs.pretty import pretty_file, pretty_check_file +# Node types that increase indentation by 2 when entered +INDENTED_TYPES = { + "bundle_section", + "class_guarded_promises", + "class_guarded_body_attributes", + "class_guarded_promise_block_attributes", + "promise", + "half_promise", + "attribute", +} + +CLASS_GUARD_TYPES = { + "class_guarded_promises", + "class_guarded_body_attributes", + "class_guarded_promise_block_attributes", +} + +BLOCK_TYPES = {"bundle_block", "promise_block", "body_block"} -def format_json_file(filename, check): +PROMISER_PARTS = {"promiser", "->", "stakeholder"} + + +def format_json_file(filename: str, check: bool) -> int: + """Reformat a JSON file in place using cfbs pretty-printer.""" assert filename.endswith(".json") if check: r = not pretty_check_file(filename) if r: print(f"JSON file '{filename}' needs reformatting") - return r + return int(r) r = pretty_file(filename) if r: print(f"JSON file '{filename}' was reformatted") - return r + return int(r) -def text(node: Node): +def text(node: Node) -> str: + """Extract the UTF-8 text content of a tree-sitter node.""" if not node.text: return "" return node.text.decode("utf-8") class Formatter: - def __init__(self): - self.empty = True - self.previous = None - self.buffer = "" + """Accumulates formatted output line-by-line into a string buffer.""" + + def __init__(self) -> None: + self.empty: bool = True + self.previous: Node | None = None + self.buffer: str = "" - def _write(self, message, end="\n"): - # print(message, end=end) + def _write(self, message: str, end: str = "\n") -> None: + """Append text to the buffer with the given line ending.""" self.buffer += message + end - def print_lines(self, lines, indent): + def print_lines(self, lines: list[str], indent: int) -> None: + """Print multiple pre-formatted lines.""" for line in lines: self.print(line, indent) - def print(self, string, indent): - if type(string) is not str: + def print(self, string: str | Node, indent: int) -> None: + """Print a string or node on a new line with the given indentation.""" + if not isinstance(string, str): string = text(string) if not self.empty: self._write("\n", end="") self._write(" " * indent + string, end="") self.empty = False - def print_same_line(self, string): - if type(string) is not str: + def print_same_line(self, string: str | Node) -> None: + """Append text to the current line without a preceding newline.""" + if not isinstance(string, str): string = text(string) self._write(string, end="") - def update_previous(self, node): + def blank_line(self) -> None: + """Insert a blank separator line.""" + self.print("", 0) + + def update_previous(self, node: Node) -> Node | None: + """Set the previously-visited node, returning the old value.""" tmp = self.previous self.previous = node return tmp -def stringify_parameter_list(parts): - """Join pre-extracted string tokens into a formatted parameter list. +# --------------------------------------------------------------------------- +# Stringify helpers — flatten nodes into single-line strings +# --------------------------------------------------------------------------- + - Used when formatting bundle/body headers. Comments are - stripped from the parameter_list node before this function is called, - so `parts` contains only the structural tokens: "(", identifiers, "," - separators, and ")". The function removes any trailing comma before - ")", then joins the tokens with appropriate spacing (space after each - comma, no space after "(" or before ")"). +def stringify_parameter_list(parts: list[str]) -> str: + """Join string tokens into a formatted parameter list. + Removes trailing commas and adds spacing after commas. Example: ["(", "a", ",", "b", ",", ")"] -> "(a, b)" """ # Remove trailing comma before closing paren @@ -87,19 +123,10 @@ def stringify_parameter_list(parts): return result -def stringify_single_line_nodes(nodes): - """Join a list of tree-sitter nodes into a single-line string. - - Operates on the direct child nodes of a CFEngine syntax construct - (e.g. a list, call, or attribute). Each child is recursively - flattened via stringify_single_line_node(). Spacing rules: - - A space is inserted after each "," separator. - - A space is inserted before and after "=>" (fat arrow). - - No extra space otherwise (e.g. no space after "(" or before ")"). +def stringify_single_line_nodes(nodes: list[Node]) -> str: + """Join tree-sitter nodes into a single-line string with CFEngine spacing. - Used by stringify_single_line_node() to recursively flatten any node with - children, and by maybe_split_generic_list() to attempt a single-line - rendering before falling back to multi-line splitting. + Inserts spaces after ",", around "=>", and inside "{ }". """ result = "" previous = None @@ -120,13 +147,20 @@ def stringify_single_line_nodes(nodes): return result -def stringify_single_line_node(node): +def stringify_single_line_node(node: Node) -> str: + """Recursively flatten a node and its children into a single-line string.""" if not node.children: return text(node) return stringify_single_line_nodes(node.children) -def split_generic_value(node, indent, line_length): +# --------------------------------------------------------------------------- +# List / rval splitting — multi-line formatting for long values +# --------------------------------------------------------------------------- + + +def split_generic_value(node: Node, indent: int, line_length: int) -> list[str]: + """Split a value node (call, list, or other) into multi-line strings.""" if node.type == "call": return split_rval_call(node, indent, line_length) if node.type == "list": @@ -134,8 +168,9 @@ def split_generic_value(node, indent, line_length): return [stringify_single_line_node(node)] -def split_generic_list(middle, indent, line_length): - elements = [] +def split_generic_list(middle: list[Node], indent: int, line_length: int) -> list[str]: + """Split list elements into one-per-line strings, each pre-indented.""" + elements: list[str] = [] for element in middle: if elements and element.type == ",": elements[-1] = elements[-1] + "," @@ -150,14 +185,18 @@ def split_generic_list(middle, indent, line_length): return elements -def maybe_split_generic_list(nodes, indent, line_length): +def maybe_split_generic_list( + nodes: list[Node], indent: int, line_length: int +) -> list[str]: + """Try a single-line rendering; fall back to split_generic_list if too long.""" string = " " * indent + stringify_single_line_nodes(nodes) if len(string) < line_length: return [string] return split_generic_list(nodes, indent, line_length) -def split_rval_list(node, indent, line_length): +def split_rval_list(node: Node, indent: int, line_length: int) -> list[str]: + """Split a list rval ({ ... }) into multi-line form.""" assert node.type == "list" assert node.children[0].type == "{" first = text(node.children[0]) @@ -167,7 +206,8 @@ def split_rval_list(node, indent, line_length): return [first, *elements, last] -def split_rval_call(node, indent, line_length): +def split_rval_call(node: Node, indent: int, line_length: int) -> list[str]: + """Split a function call rval (func(...)) into multi-line form.""" assert node.type == "call" assert node.children[0].type == "calling_identifier" assert node.children[1].type == "(" @@ -178,7 +218,8 @@ def split_rval_call(node, indent, line_length): return [first, *elements, last] -def split_rval(node, indent, line_length): +def split_rval(node: Node, indent: int, line_length: int) -> list[str]: + """Split an rval node into multi-line form based on its type.""" if node.type == "list": return split_rval_list(node, indent, line_length) if node.type == "call": @@ -186,14 +227,23 @@ def split_rval(node, indent, line_length): return [stringify_single_line_node(node)] -def maybe_split_rval(node, indent, offset, line_length): +def maybe_split_rval( + node: Node, indent: int, offset: int, line_length: int +) -> list[str]: + """Return single-line rval if it fits at offset, otherwise split it.""" line = stringify_single_line_node(node) if len(line) + offset < line_length: return [line] return split_rval(node, indent, line_length) -def attempt_split_attribute(node, indent, line_length): +# --------------------------------------------------------------------------- +# Attribute formatting +# --------------------------------------------------------------------------- + + +def attempt_split_attribute(node: Node, indent: int, line_length: int) -> list[str]: + """Split an attribute node, wrapping the rval if it's a list or call.""" assert len(node.children) == 3 lval = node.children[0] arrow = node.children[1] @@ -208,7 +258,8 @@ def attempt_split_attribute(node, indent, line_length): return [" " * indent + stringify_single_line_node(node)] -def stringify(node, indent, line_length): +def stringify(node: Node, indent: int, line_length: int) -> list[str]: + """Return a node as pre-indented line(s), splitting if it exceeds line_length.""" single_line = " " * indent + stringify_single_line_node(node) # Reserve 1 char for trailing ; or , after attributes effective_length = line_length - 1 if node.type == "attribute" else line_length @@ -219,31 +270,48 @@ def stringify(node, indent, line_length): return [single_line] -def has_stakeholder(children): - return any(c.type == "stakeholder" for c in children) +# --------------------------------------------------------------------------- +# Stakeholder helpers +# --------------------------------------------------------------------------- -def stakeholder_has_comments(children): +def _get_stakeholder_list(children: list[Node]) -> Node | None: + """Return the list node inside a promise's stakeholder, or None.""" stakeholder = next((c for c in children if c.type == "stakeholder"), None) if not stakeholder: + return None + return next((c for c in stakeholder.children if c.type == "list"), None) + + +def _stakeholder_has_comments(children: list[Node]) -> bool: + """Check if the stakeholder list contains any comment nodes.""" + list_node = _get_stakeholder_list(children) + if not list_node: return False - for child in stakeholder.children: - if child.type == "list": - return any(c.type == "comment" for c in child.children) + return any(c.type == "comment" for c in list_node.children) + + +def _has_trailing_comma(middle: list[Node]) -> bool: + """Check if a list's middle nodes end with a trailing comma.""" + for node in reversed(middle): + if node.type == ",": + return True + if node.type != "comment": + return False return False -def promiser_prefix(children): - """Build the promiser text (without stakeholder).""" +def _promiser_text(children: list[Node]) -> str | None: + """Return the raw promiser string from promise children, or None.""" promiser_node = next((c for c in children if c.type == "promiser"), None) if not promiser_node: return None return text(promiser_node) -def promiser_line(children): - """Build the promiser prefix: promiser + optional '-> stakeholder'.""" - prefix = promiser_prefix(children) +def _promiser_line_with_stakeholder(children: list[Node]) -> str | None: + """Build the full promiser line including '-> { stakeholder }', or None.""" + prefix = _promiser_text(children) if not prefix: return None arrow = next((c for c in children if c.type == "->"), None) @@ -253,58 +321,32 @@ def promiser_line(children): return prefix -def stakeholder_needs_splitting(children, indent, line_length): - """Check if the stakeholder list needs to be split across multiple lines.""" - if stakeholder_has_comments(children): +def _stakeholder_needs_splitting( + children: list[Node], indent: int, line_length: int +) -> bool: + """Check if the stakeholder list must be split (comments or too long).""" + if _stakeholder_has_comments(children): return True - prefix = promiser_line(children) - if not prefix: + line = _promiser_line_with_stakeholder(children) + if not line: return False - return indent + len(prefix) > line_length + return indent + len(line) > line_length -def split_stakeholder(children, indent, has_attributes, line_length): - """Split a stakeholder list across multiple lines. +def _format_stakeholder_elements( + middle: list[Node], indent: int, line_length: int +) -> list[str]: + """Format the elements between { and } of a stakeholder list. - Returns (opening_line, element_lines, closing_str) where: - - opening_line: 'promiser -> {' to print at promise indent - - element_lines: pre-indented element strings - - closing_str: '}' or '};' pre-indented at the appropriate level + Uses trailing-comma heuristic: lists with a trailing comma are always + split one-per-line; without, they may stay on a single line. """ - prefix = promiser_prefix(children) - assert prefix is not None - opening = prefix + " -> {" - stakeholder = next(c for c in children if c.type == "stakeholder") - list_node = next(c for c in stakeholder.children if c.type == "list") - middle = list_node.children[1:-1] # between { and } - element_indent = indent + 4 - has_comments = stakeholder_has_comments(children) - if has_attributes or has_comments: - close_indent = indent + 2 - else: - close_indent = indent - elements = format_stakeholder_elements(middle, element_indent, line_length) - return opening, elements, close_indent - - -def has_trailing_comma(middle): - """Check if a list's middle nodes end with a trailing comma.""" - for node in reversed(middle): - if node.type == ",": - return True - if node.type != "comment": - return False - return False - - -def format_stakeholder_elements(middle, indent, line_length): - """Format the middle elements of a stakeholder list.""" - has_comments = any(n.type == "comment" for n in middle) - if not has_comments: - if has_trailing_comma(middle): + if not any(n.type == "comment" for n in middle): + if _has_trailing_comma(middle): return split_generic_list(middle, indent, line_length) return maybe_split_generic_list(middle, indent, line_length) - elements = [] + # Comments present — format element-by-element to preserve them + elements: list[str] = [] for node in middle: if node.type == ",": if elements: @@ -323,36 +365,243 @@ def format_stakeholder_elements(middle, indent, line_length): return elements -def can_single_line_promise(node, indent, line_length): - """Check if a promise node can be formatted on a single line.""" +# --------------------------------------------------------------------------- +# Promise formatting +# --------------------------------------------------------------------------- + + +def _has_stakeholder(children: list[Node]) -> bool: + """Check if promise children include a stakeholder node.""" + return any(c.type == "stakeholder" for c in children) + + +def can_single_line_promise(node: Node, indent: int, line_length: int) -> bool: + """Check if a promise can be formatted entirely on one line. + + Returns False for multi-attribute promises, promises with a + half_promise continuation, or stakeholder+attribute combinations. + """ if node.type != "promise": return False children = node.children - attr_children = [c for c in children if c.type == "attribute"] + attrs = [c for c in children if c.type == "attribute"] next_sib = node.next_named_sibling - has_continuation = next_sib and next_sib.type == "half_promise" - if len(attr_children) > 1 or has_continuation: + if len(attrs) > 1 or (next_sib and next_sib.type == "half_promise"): return False - # Promises with stakeholder + attributes are always multi-line - if has_stakeholder(children) and attr_children: + if _has_stakeholder(children) and attrs: return False - # Stakeholders that need splitting can't be single-lined - if has_stakeholder(children) and stakeholder_needs_splitting( + if _has_stakeholder(children) and _stakeholder_needs_splitting( children, indent, line_length ): return False - prefix = promiser_line(children) - if not prefix: + line = _promiser_line_with_stakeholder(children) + if not line: return False - if attr_children: - line = prefix + " " + stringify_single_line_node(attr_children[0]) + ";" + if attrs: + line += " " + stringify_single_line_node(attrs[0]) + ";" else: - line = prefix + ";" + line += ";" return indent + len(line) <= line_length -def autoformat(node, fmt, line_length, macro_indent, indent=0): +def _format_promise( + node: Node, + children: list[Node], + fmt: Formatter, + indent: int, + line_length: int, + macro_indent: int, +) -> bool: + """Format a promise node. Returns True if handled, False to fall through.""" + # Single-line promise + if can_single_line_promise(node, indent, line_length): + prefix = _promiser_line_with_stakeholder(children) + assert prefix is not None + attr = next((c for c in children if c.type == "attribute"), None) + if attr: + line = prefix + " " + stringify_single_line_node(attr) + ";" + else: + line = prefix + ";" + fmt.print(line, indent) + return True + + # Multi-line with split stakeholder + if _has_stakeholder(children) and _stakeholder_needs_splitting( + children, indent, line_length + ): + attrs = [c for c in children if c.type == "attribute"] + promiser = _promiser_text(children) + assert promiser is not None + fmt.print(promiser + " -> {", indent) + + list_node = _get_stakeholder_list(children) + assert list_node is not None + middle = list_node.children[1:-1] + element_indent = indent + 4 + elements = _format_stakeholder_elements(middle, element_indent, line_length) + fmt.print_lines(elements, indent=0) + + has_comments = _stakeholder_has_comments(children) + close_indent = indent + 2 if (attrs or has_comments) else indent + if attrs: + fmt.print("}", close_indent) + _format_remaining_children(children, fmt, indent, line_length, macro_indent) + else: + fmt.print("};", close_indent) + return True + + # Multi-line with inline stakeholder + prefix = _promiser_line_with_stakeholder(children) + if prefix: + fmt.print(prefix, indent) + _format_remaining_children(children, fmt, indent, line_length, macro_indent) + return True + + return False + + +def _format_remaining_children( + children: list[Node], + fmt: Formatter, + indent: int, + line_length: int, + macro_indent: int, +) -> None: + """Format promise children, skipping promiser/arrow/stakeholder parts.""" + for child in children: + if child.type in PROMISER_PARTS: + continue + autoformat(child, fmt, line_length, macro_indent, indent) + + +# --------------------------------------------------------------------------- +# Block header formatting (bundle, body, promise blocks) +# --------------------------------------------------------------------------- + + +def _format_block_header(node: Node, fmt: Formatter) -> list[Node]: + """Format a block header line and return the body's children for further processing.""" + header_parts: list[str] = [] + header_comments: list[str] = [] + for x in node.children[0:-1]: + if x.type == "comment": + header_comments.append(text(x)) + elif x.type == "parameter_list": + parts: list[str] = [] + for p in x.children: + if p.type == "comment": + header_comments.append(text(p)) + else: + parts.append(text(p)) + header_parts[-1] = header_parts[-1] + stringify_parameter_list(parts) + else: + header_parts.append(text(x)) + line = " ".join(header_parts) + if not fmt.empty: + prev_sib = node.prev_named_sibling + if not (prev_sib and prev_sib.type == "comment"): + fmt.blank_line() + fmt.print(line, 0) + for i, comment in enumerate(header_comments): + if comment.strip() == "#": + prev_is_comment = i > 0 and header_comments[i - 1].strip() != "#" + next_is_comment = ( + i + 1 < len(header_comments) and header_comments[i + 1].strip() != "#" + ) + if not (prev_is_comment and next_is_comment): + continue + fmt.print(comment, 0) + return node.children[-1].children + + +# --------------------------------------------------------------------------- +# Blank line logic +# --------------------------------------------------------------------------- + + +def _needs_blank_line_before(child: Node, indent: int, line_length: int) -> bool: + """Check if a blank separator line should precede this child node.""" + prev = child.prev_named_sibling + if not prev: + return False + + if child.type == "bundle_section": + return prev.type == "bundle_section" + + if child.type == "promise" and prev.type in {"promise", "half_promise"}: + promise_indent = indent + 2 + both_single = ( + prev.type == "promise" + and can_single_line_promise(prev, promise_indent, line_length) + and can_single_line_promise(child, promise_indent, line_length) + ) + return not both_single + + if child.type in CLASS_GUARD_TYPES: + return prev.type in {"promise", "half_promise", "class_guarded_promises"} + + if child.type == "comment": + if prev.type not in {"promise", "half_promise"} | CLASS_GUARD_TYPES: + return False + parent = child.parent + return bool( + parent and parent.type in {"bundle_section", "class_guarded_promises"} + ) + + return False + + +# --------------------------------------------------------------------------- +# Comment formatting +# --------------------------------------------------------------------------- + + +def _is_empty_comment(node: Node) -> bool: + """Check if a bare '#' comment should be dropped (not between other comments).""" + if text(node).strip() != "#": + return False + prev = node.prev_named_sibling + nxt = node.next_named_sibling + return not (prev and prev.type == "comment" and nxt and nxt.type == "comment") + + +def _skip_comments(sibling: Node | None, direction: str = "next") -> Node | None: + """Walk past adjacent comment siblings to find the nearest non-comment.""" + while sibling and sibling.type == "comment": + sibling = ( + sibling.next_named_sibling + if direction == "next" + else sibling.prev_named_sibling + ) + return sibling + + +def _comment_indent(node: Node, indent: int) -> int: + """Compute indentation for a leaf comment based on its nearest non-comment neighbor.""" + nearest = _skip_comments(node.next_named_sibling, "next") + if nearest is None: + nearest = _skip_comments(node.prev_named_sibling, "prev") + if nearest and nearest.type in INDENTED_TYPES: + return indent + 2 + return indent + + +# --------------------------------------------------------------------------- +# Main recursive formatter +# --------------------------------------------------------------------------- + + +def autoformat( + node: Node, + fmt: Formatter, + line_length: int, + macro_indent: int, + indent: int = 0, +) -> None: + """Recursively format a tree-sitter node tree into the Formatter buffer.""" previous = fmt.update_previous(node) + + # Macro handling if previous and previous.type == "macro" and text(previous).startswith("@else"): indent = macro_indent if node.type == "macro": @@ -362,187 +611,51 @@ def autoformat(node, fmt, line_length, macro_indent, indent=0): elif text(node).startswith("@else"): indent = macro_indent return + + # Block header (bundle/body/promise blocks) children = node.children - if node.type in ["bundle_block", "promise_block", "body_block"]: - header_parts = [] - header_comments = [] - for x in node.children[0:-1]: - if x.type == "comment": - header_comments.append(text(x)) - elif x.type == "parameter_list": - parts = [] - for p in x.children: - if p.type == "comment": - header_comments.append(text(p)) - else: - parts.append(text(p)) - # Append directly to previous part (no space before parens) - header_parts[-1] = header_parts[-1] + stringify_parameter_list(parts) - else: - header_parts.append(text(x)) - line = " ".join(header_parts) - if not fmt.empty: - prev_sib = node.prev_named_sibling - if not (prev_sib and prev_sib.type == "comment"): - fmt.print("", 0) - fmt.print(line, 0) - for i, comment in enumerate(header_comments): - if comment.strip() == "#": - prev_is_comment = i > 0 and header_comments[i - 1].strip() != "#" - next_is_comment = ( - i + 1 < len(header_comments) - and header_comments[i + 1].strip() != "#" - ) - if not (prev_is_comment and next_is_comment): - continue - fmt.print(comment, 0) - children = node.children[-1].children - if node.type in [ - "bundle_section", - "class_guarded_promises", - "class_guarded_body_attributes", - "class_guarded_promise_block_attributes", - "promise", - "half_promise", - "attribute", - ]: + if node.type in BLOCK_TYPES: + children = _format_block_header(node, fmt) + + # Indentation + if node.type in INDENTED_TYPES: indent += 2 + + # Attribute — stringify and return if node.type == "attribute": - lines = stringify(node, indent, line_length) - fmt.print_lines(lines, indent=0) + fmt.print_lines(stringify(node, indent, line_length), indent=0) return + + # Promise — delegate to promise formatter if node.type == "promise": - if can_single_line_promise(node, indent, line_length): - prefix = promiser_line(children) - assert prefix is not None - attr_node = next((c for c in children if c.type == "attribute"), None) - if attr_node: - line = prefix + " " + stringify_single_line_node(attr_node) + ";" - else: - line = prefix + ";" - fmt.print(line, indent) - return - # Multi-line promise with stakeholder that needs splitting - attr_children = [c for c in children if c.type == "attribute"] - if has_stakeholder(children) and stakeholder_needs_splitting( - children, indent, line_length - ): - opening, elements, close_indent = split_stakeholder( - children, indent, bool(attr_children), line_length - ) - fmt.print(opening, indent) - fmt.print_lines(elements, indent=0) - if attr_children: - fmt.print("}", close_indent) - else: - fmt.print("};", close_indent) - return - for child in children: - if child.type in {"promiser", "->", "stakeholder"}: - continue - autoformat(child, fmt, line_length, macro_indent, indent) - return - # Multi-line promise: print promiser (with stakeholder) then recurse for rest - prefix = promiser_line(children) - if prefix: - fmt.print(prefix, indent) - for child in children: - if child.type in {"promiser", "->", "stakeholder"}: - continue - autoformat(child, fmt, line_length, macro_indent, indent) + if _format_promise(node, children, fmt, indent, line_length, macro_indent): return + + # Interior node with children — recurse if children: for child in children: - # Blank line between bundle sections - if child.type == "bundle_section": - prev = child.prev_named_sibling - if prev and prev.type == "bundle_section": - fmt.print("", 0) - # Blank line between promises in a section - elif child.type == "promise": - prev = child.prev_named_sibling - if prev and prev.type in ["promise", "half_promise"]: - # Skip blank line between consecutive single-line promises - promise_indent = indent + 2 - both_single = ( - prev.type == "promise" - and can_single_line_promise(prev, promise_indent, line_length) - and can_single_line_promise(child, promise_indent, line_length) - ) - if not both_single: - fmt.print("", 0) - elif child.type in [ - "class_guarded_promises", - "class_guarded_body_attributes", - "class_guarded_promise_block_attributes", - ]: - prev = child.prev_named_sibling - if prev and prev.type in [ - "promise", - "half_promise", - "class_guarded_promises", - ]: - fmt.print("", 0) - elif child.type == "comment": - prev = child.prev_named_sibling - if prev and prev.type in [ - "promise", - "half_promise", - "class_guarded_promises", - "class_guarded_body_attributes", - "class_guarded_promise_block_attributes", - ]: - parent = child.parent - if parent and parent.type in [ - "bundle_section", - "class_guarded_promises", - ]: - fmt.print("", 0) + if _needs_blank_line_before(child, indent, line_length): + fmt.blank_line() autoformat(child, fmt, line_length, macro_indent, indent) return - if node.type in [",", ";"]: + + # Leaf nodes + if node.type in {",", ";"}: fmt.print_same_line(node) - return - if node.type == "comment": - if text(node).strip() == "#": - prev = node.prev_named_sibling - nxt = node.next_named_sibling - if not (prev and prev.type == "comment" and nxt and nxt.type == "comment"): - return - comment_indent = indent - next_sib = node.next_named_sibling - while next_sib and next_sib.type == "comment": - next_sib = next_sib.next_named_sibling - if next_sib is None: - prev_sib = node.prev_named_sibling - while prev_sib and prev_sib.type == "comment": - prev_sib = prev_sib.prev_named_sibling - if prev_sib and prev_sib.type in [ - "bundle_section", - "class_guarded_promises", - "class_guarded_body_attributes", - "class_guarded_promise_block_attributes", - "promise", - "half_promise", - "attribute", - ]: - comment_indent = indent + 2 - elif next_sib.type in [ - "bundle_section", - "class_guarded_promises", - "class_guarded_body_attributes", - "class_guarded_promise_block_attributes", - "promise", - "half_promise", - "attribute", - ]: - comment_indent = indent + 2 - fmt.print(node, comment_indent) - return - fmt.print(node, indent) + elif node.type == "comment": + if not _is_empty_comment(node): + fmt.print(node, _comment_indent(node, indent)) + else: + fmt.print(node, indent) + + +# --------------------------------------------------------------------------- +# Public entry points +# --------------------------------------------------------------------------- -def format_policy_file(filename, line_length, check): +def format_policy_file(filename: str, line_length: int, check: bool) -> int: + """Format a .cf policy file in place, writing only if content changed.""" assert filename.endswith(".cf") PY_LANGUAGE = Language(tscfengine.language()) @@ -570,7 +683,10 @@ def format_policy_file(filename, line_length, check): return 0 -def format_policy_fin_fout(fin, fout, line_length, check): +def format_policy_fin_fout( + fin: IO[str], fout: IO[str], line_length: int, check: bool +) -> int: + """Format CFEngine policy read from fin, writing the result to fout.""" PY_LANGUAGE = Language(tscfengine.language()) parser = Parser(PY_LANGUAGE) diff --git a/tests/shell/004-format-check.sh b/tests/shell/004-format-check.sh new file mode 100755 index 0000000..f05acc4 --- /dev/null +++ b/tests/shell/004-format-check.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -e +set -x + +# Setup: create a temp directory for test files +tmpdir=$(mktemp -d) +trap "rm -rf $tmpdir" EXIT + +write_formatted() { + printf 'bundle agent main\n{\n vars:\n "v" string => "hello";\n}\n' > "$1" +} + +write_unformatted() { + printf 'bundle agent main { vars: "v" string => "hello"; }\n' > "$1" +} + +# Case 1: format without --check on already-formatted file -> exit 0 +write_formatted "$tmpdir/good.cf" +cfengine format "$tmpdir/good.cf" + +# Case 2: format without --check on unformatted file -> exit 0 (reformats it) +write_unformatted "$tmpdir/bad.cf" +cfengine format "$tmpdir/bad.cf" +# Verify it was actually reformatted to the correct output +write_formatted "$tmpdir/expected.cf" +diff "$tmpdir/expected.cf" "$tmpdir/bad.cf" + +# Case 3: --check on already-formatted file -> exit 0 +write_formatted "$tmpdir/good2.cf" +cfengine format --check "$tmpdir/good2.cf" + +# Case 4: --check on unformatted file -> exit 1 +write_unformatted "$tmpdir/bad2.cf" +cp "$tmpdir/bad2.cf" "$tmpdir/bad2_orig.cf" +if cfengine format --check "$tmpdir/bad2.cf"; then + echo "FAIL: expected exit code 1 for --check on unformatted file" + exit 1 +fi +# Verify the file was NOT modified +diff "$tmpdir/bad2_orig.cf" "$tmpdir/bad2.cf" diff --git a/tests/unit/test_format.py b/tests/unit/test_format.py index b1b91a0..71c93f7 100644 --- a/tests/unit/test_format.py +++ b/tests/unit/test_format.py @@ -1,19 +1,221 @@ -from cfengine_cli.format import stringify_parameter_list, stringify_single_line_nodes +import io + +import tree_sitter_cfengine as tscfengine +from tree_sitter import Language, Parser, Node + +from cfengine_cli.format import ( + Formatter, + text, + stringify_parameter_list, + stringify_single_line_nodes, + stringify_single_line_node, + split_generic_list, + maybe_split_generic_list, + split_rval_list, + split_rval_call, + split_rval, + maybe_split_rval, + split_generic_value, + attempt_split_attribute, + stringify, + can_single_line_promise, + autoformat, + format_policy_fin_fout, + _get_stakeholder_list, + _stakeholder_has_comments, + _has_trailing_comma, + _promiser_text, + _promiser_line_with_stakeholder, + _stakeholder_needs_splitting, + _format_stakeholder_elements, + _has_stakeholder, + _is_empty_comment, + _skip_comments, + _comment_indent, + _needs_blank_line_before, +) + +# --------------------------------------------------------------------------- +# MockNode — lightweight stand-in for tree-sitter Node +# --------------------------------------------------------------------------- class MockNode: - """Minimal stand-in for a tree-sitter Node used by stringify_single_line_nodes.""" + """Minimal stand-in for a tree-sitter Node.""" - def __init__(self, node_type, node_text=None, children=None): + def __init__( + self, + node_type, + node_text=None, + children=None, + next_named_sibling=None, + prev_named_sibling=None, + parent=None, + ): self.type = node_type self.text = node_text.encode("utf-8") if node_text is not None else None self.children = children or [] + self.next_named_sibling = next_named_sibling + self.prev_named_sibling = prev_named_sibling + self.parent = parent def _leaf(node_type, node_text=None): return MockNode(node_type, node_text or node_type) +# --------------------------------------------------------------------------- +# Real parser helper — parse CFEngine code into a tree-sitter tree +# --------------------------------------------------------------------------- + +_LANGUAGE = Language(tscfengine.language()) +_PARSER = Parser(_LANGUAGE) + + +def _parse(code: str) -> Node: + """Parse CFEngine source and return the root node.""" + tree = _PARSER.parse(code.encode("utf-8")) + return tree.root_node + + +def _format(code: str, line_length: int = 80) -> str: + """Format CFEngine source via format_policy_fin_fout and return the result.""" + fin = io.StringIO(code) + fout = io.StringIO() + format_policy_fin_fout(fin, fout, line_length, False) + return fout.getvalue() + + +def _find(root: Node, node_type: str) -> Node: + """Find the first descendant of the given type (depth-first).""" + if root.type == node_type: + return root + for child in root.children: + found = _find_opt(child, node_type) + if found: + return found + raise ValueError(f"No node of type {node_type!r} found") + + +def _find_opt(root: Node, node_type: str): + """Find the first descendant of the given type, or None.""" + if root.type == node_type: + return root + for child in root.children: + found = _find_opt(child, node_type) + if found: + return found + return None + + +def _find_all(root: Node, node_type: str) -> list[Node]: + """Find all descendants of the given type (depth-first).""" + results = [] + if root.type == node_type: + results.append(root) + for child in root.children: + results.extend(_find_all(child, node_type)) + return results + + +# --------------------------------------------------------------------------- +# text() +# --------------------------------------------------------------------------- + + +def test_text_returns_decoded_string(): + node = _leaf("identifier", "hello") + assert text(node) == "hello" + + +def test_text_returns_empty_for_none(): + node = MockNode("identifier", node_text=None) + node.text = None + assert text(node) == "" + + +# --------------------------------------------------------------------------- +# Formatter class +# --------------------------------------------------------------------------- + + +def test_formatter_empty_initial(): + fmt = Formatter() + assert fmt.empty is True + assert fmt.buffer == "" + assert fmt.previous is None + + +def test_formatter_print(): + fmt = Formatter() + fmt.print("hello", 0) + assert fmt.buffer == "hello" + assert fmt.empty is False + + +def test_formatter_print_with_indent(): + fmt = Formatter() + fmt.print("hello", 4) + assert fmt.buffer == " hello" + + +def test_formatter_print_multiple_lines(): + fmt = Formatter() + fmt.print("line1", 0) + fmt.print("line2", 2) + assert fmt.buffer == "line1\n line2" + + +def test_formatter_print_node(): + fmt = Formatter() + node = _leaf("identifier", "world") + fmt.print(node, 0) + assert fmt.buffer == "world" + + +def test_formatter_print_same_line(): + fmt = Formatter() + fmt.print("hello", 0) + fmt.print_same_line(";") + assert fmt.buffer == "hello;" + + +def test_formatter_print_same_line_node(): + fmt = Formatter() + fmt.print("x", 0) + fmt.print_same_line(_leaf(";")) + assert fmt.buffer == "x;" + + +def test_formatter_blank_line(): + fmt = Formatter() + fmt.print("a", 0) + fmt.blank_line() + fmt.print("b", 0) + assert fmt.buffer == "a\n\nb" + + +def test_formatter_print_lines(): + fmt = Formatter() + fmt.print_lines([" a", " b", " c"], indent=0) + assert fmt.buffer == " a\n b\n c" + + +def test_formatter_update_previous(): + fmt = Formatter() + n1 = _leaf("a", "a") + n2 = _leaf("b", "b") + assert fmt.update_previous(n1) is None + assert fmt.previous is n1 + assert fmt.update_previous(n2) is n1 + assert fmt.previous is n2 + + +# --------------------------------------------------------------------------- +# stringify_parameter_list +# --------------------------------------------------------------------------- + + def test_stringify_parameter_list(): assert stringify_parameter_list([]) == "" assert stringify_parameter_list(["foo"]) == "foo" @@ -28,6 +230,11 @@ def test_stringify_parameter_list(): assert stringify_parameter_list(parts) == "(x, y, z)" +# --------------------------------------------------------------------------- +# stringify_single_line_nodes / stringify_single_line_node +# --------------------------------------------------------------------------- + + def test_stringify_single_line_nodes(): assert stringify_single_line_nodes([]) == "" assert stringify_single_line_nodes([_leaf("identifier", "foo")]) == "foo" @@ -68,3 +275,604 @@ def test_stringify_single_line_nodes(): nodes = [_leaf("identifier", "x"), _leaf("=>"), inner] assert stringify_single_line_nodes(nodes) == 'x => func("arg")' + + +def test_stringify_single_line_node_leaf(): + assert stringify_single_line_node(_leaf("identifier", "foo")) == "foo" + + +def test_stringify_single_line_node_with_children(): + node = MockNode( + "attribute", + children=[ + _leaf("attribute_name", "string"), + _leaf("=>"), + _leaf("quoted_string", '"value"'), + ], + ) + assert stringify_single_line_node(node) == 'string => "value"' + + +# --------------------------------------------------------------------------- +# split_generic_list / maybe_split_generic_list +# --------------------------------------------------------------------------- + + +def test_split_generic_list_basic(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("string", '"b"')] + result = split_generic_list(nodes, 4, 80) + assert result == [' "a",', ' "b"'] + + +def test_maybe_split_generic_list_fits(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("string", '"b"')] + result = maybe_split_generic_list(nodes, 4, 80) + assert result == [' "a", "b"'] + + +def test_maybe_split_generic_list_too_long(): + nodes = [ + _leaf("string", '"aaaaaaaaaaaaaaaaaaaaaaaaa"'), + _leaf(","), + _leaf("string", '"bbbbbbbbbbbbbbbbbbbbbbbbb"'), + ] + result = maybe_split_generic_list(nodes, 4, 40) + assert len(result) == 2 + assert result[0].strip().startswith('"a') + assert result[1].strip().startswith('"b') + + +# --------------------------------------------------------------------------- +# split_rval_list / split_rval_call / split_rval +# --------------------------------------------------------------------------- + + +def test_split_rval_list(): + root = _parse('bundle agent x { vars: "v" slist => { "a", "b" }; }') + list_node = _find(root, "list") + result = split_rval_list(list_node, 6, 20) + assert result[0] == "{" + assert any('"a"' in line for line in result) + assert any('"b"' in line for line in result) + assert result[-1].strip() == "}" + + +def test_split_rval_call(): + root = _parse('bundle agent x { vars: "v" string => concat("a", "b"); }') + call_node = _find(root, "call") + result = split_rval_call(call_node, 6, 20) + assert result[0] == "concat(" + assert result[-1].strip() == ")" + + +def test_split_rval_dispatches_list(): + root = _parse('bundle agent x { vars: "v" slist => { "a", "b" }; }') + list_node = _find(root, "list") + result = split_rval(list_node, 6, 20) + assert result[0] == "{" + + +def test_split_rval_dispatches_call(): + root = _parse('bundle agent x { vars: "v" string => concat("a", "b"); }') + call_node = _find(root, "call") + result = split_rval(call_node, 6, 20) + assert result[0] == "concat(" + + +def test_split_rval_fallback(): + root = _parse('bundle agent x { vars: "v" string => "hello"; }') + string_node = _find(root, "quoted_string") + result = split_rval(string_node, 6, 80) + assert result == ['"hello"'] + + +def test_maybe_split_rval_fits(): + root = _parse('bundle agent x { vars: "v" string => "hello"; }') + string_node = _find(root, "quoted_string") + result = maybe_split_rval(string_node, 6, 10, 80) + assert result == ['"hello"'] + + +def test_maybe_split_rval_too_long(): + root = _parse('bundle agent x { vars: "v" slist => { "a", "b" }; }') + list_node = _find(root, "list") + result = maybe_split_rval(list_node, 6, 999, 80) + assert result[0] == "{" + + +def test_split_generic_value_call(): + root = _parse('bundle agent x { vars: "v" string => concat("a", "b"); }') + call_node = _find(root, "call") + result = split_generic_value(call_node, 6, 20) + assert result[0] == "concat(" + + +def test_split_generic_value_list(): + root = _parse('bundle agent x { vars: "v" slist => { "a", "b" }; }') + list_node = _find(root, "list") + result = split_generic_value(list_node, 6, 20) + assert result[0] == "{" + + +def test_split_generic_value_other(): + node = _leaf("quoted_string", '"hello"') + result = split_generic_value(node, 6, 80) + assert result == ['"hello"'] + + +# --------------------------------------------------------------------------- +# attempt_split_attribute / stringify +# --------------------------------------------------------------------------- + + +def test_attempt_split_attribute_with_list(): + root = _parse('bundle agent x { vars: "v" slist => { "a", "b" }; }') + attr = _find(root, "attribute") + result = attempt_split_attribute(attr, 6, 20) + assert len(result) > 1 + assert "slist => {" in result[0] + + +def test_attempt_split_attribute_with_string(): + root = _parse('bundle agent x { vars: "v" string => "hello"; }') + attr = _find(root, "attribute") + result = attempt_split_attribute(attr, 6, 80) + assert len(result) == 1 + assert 'string => "hello"' in result[0] + + +def test_stringify_short_attribute(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + attr = _find(root, "attribute") + result = stringify(attr, 6, 80) + assert len(result) == 1 + assert result[0] == ' string => "hi"' + + +def test_stringify_long_attribute_splits(): + root = _parse('bundle agent x { vars: "v" slist => { "aaa", "bbb" }; }') + attr = _find(root, "attribute") + result = stringify(attr, 6, 30) + assert len(result) > 1 + + +def test_stringify_non_attribute(): + node = _leaf("identifier", "hello") + result = stringify(node, 4, 80) + assert result == [" hello"] + + +# --------------------------------------------------------------------------- +# Stakeholder helpers +# --------------------------------------------------------------------------- + + +def test_get_stakeholder_list_present(): + root = _parse('bundle agent x { packages: "p" -> { "a", "b" } comment => "c"; }') + promise = _find(root, "promise") + list_node = _get_stakeholder_list(promise.children) + assert list_node is not None + assert list_node.type == "list" + + +def test_get_stakeholder_list_absent(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + promise = _find(root, "promise") + assert _get_stakeholder_list(promise.children) is None + + +def test_stakeholder_has_comments(): + root = _parse( + 'bundle agent x { packages: "p" -> {\n# comment\n"a" } comment => "c"; }' + ) + promise = _find(root, "promise") + assert _stakeholder_has_comments(promise.children) is True + + +def test_stakeholder_no_comments(): + root = _parse('bundle agent x { packages: "p" -> { "a", "b" } comment => "c"; }') + promise = _find(root, "promise") + assert _stakeholder_has_comments(promise.children) is False + + +def test_has_trailing_comma_true(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("string", '"b"'), _leaf(",")] + assert _has_trailing_comma(nodes) is True + + +def test_has_trailing_comma_false(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("string", '"b"')] + assert _has_trailing_comma(nodes) is False + + +def test_has_trailing_comma_empty(): + assert _has_trailing_comma([]) is False + + +def test_has_trailing_comma_comment_after(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("comment", "# x")] + assert _has_trailing_comma(nodes) is True + + +def test_promiser_text(): + root = _parse('bundle agent x { vars: "myvar" string => "hi"; }') + promise = _find(root, "promise") + assert _promiser_text(promise.children) == '"myvar"' + + +def test_promiser_text_absent(): + assert _promiser_text([_leaf("attribute", "x")]) is None + + +def test_promiser_line_with_stakeholder(): + root = _parse('bundle agent x { packages: "p" -> { "a", "b" } comment => "c"; }') + promise = _find(root, "promise") + line = _promiser_line_with_stakeholder(promise.children) + assert line is not None + assert line.startswith('"p"') + assert "-> { " in line + + +def test_promiser_line_without_stakeholder(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + promise = _find(root, "promise") + line = _promiser_line_with_stakeholder(promise.children) + assert line == '"v"' + + +def test_stakeholder_needs_splitting_with_comments(): + root = _parse( + 'bundle agent x { packages: "p" -> {\n# comment\n"a" } comment => "c"; }' + ) + promise = _find(root, "promise") + assert _stakeholder_needs_splitting(promise.children, 4, 80) is True + + +def test_stakeholder_needs_splitting_long_line(): + root = _parse( + 'bundle agent x { packages: "long_package" -> { "very long reason", "TICKET-1234" } comment => "c"; }' + ) + promise = _find(root, "promise") + assert _stakeholder_needs_splitting(promise.children, 4, 40) is True + + +def test_stakeholder_no_splitting_needed(): + root = _parse('bundle agent x { packages: "p" -> { "a" } comment => "c"; }') + promise = _find(root, "promise") + assert _stakeholder_needs_splitting(promise.children, 4, 80) is False + + +def test_has_stakeholder_true(): + root = _parse('bundle agent x { packages: "p" -> { "a" } comment => "c"; }') + promise = _find(root, "promise") + assert _has_stakeholder(promise.children) is True + + +def test_has_stakeholder_false(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + promise = _find(root, "promise") + assert _has_stakeholder(promise.children) is False + + +def test_format_stakeholder_elements_no_trailing_comma(): + nodes = [_leaf("string", '"a"'), _leaf(","), _leaf("string", '"b"')] + result = _format_stakeholder_elements(nodes, 8, 80) + assert len(result) == 1 + assert '"a", "b"' in result[0] + + +def test_format_stakeholder_elements_trailing_comma(): + nodes = [ + _leaf("string", '"a"'), + _leaf(","), + _leaf("string", '"b"'), + _leaf(","), + ] + result = _format_stakeholder_elements(nodes, 8, 80) + assert len(result) == 2 + + +def test_format_stakeholder_elements_with_comments(): + nodes = [ + _leaf("comment", "# note"), + _leaf("string", '"a"'), + _leaf(","), + _leaf("string", '"b"'), + ] + result = _format_stakeholder_elements(nodes, 8, 80) + assert any("# note" in line for line in result) + assert any('"a"' in line for line in result) + + +# --------------------------------------------------------------------------- +# can_single_line_promise +# --------------------------------------------------------------------------- + + +def test_can_single_line_promise_simple(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + promise = _find(root, "promise") + assert can_single_line_promise(promise, 4, 80) is True + + +def test_can_single_line_promise_too_long(): + root = _parse('bundle agent x { vars: "v" string => "hi"; }') + promise = _find(root, "promise") + assert can_single_line_promise(promise, 4, 10) is False + + +def test_can_single_line_promise_multi_attr(): + root = _parse('bundle agent x { vars: "v" if => "linux", string => "hi"; }') + promise = _find(root, "promise") + assert can_single_line_promise(promise, 4, 80) is False + + +def test_can_single_line_promise_not_a_promise(): + node = _leaf("attribute", "x") + assert can_single_line_promise(node, 4, 80) is False + + +def test_can_single_line_promise_with_stakeholder_and_attr(): + root = _parse('bundle agent x { packages: "p" -> { "a" } comment => "c"; }') + promise = _find(root, "promise") + assert can_single_line_promise(promise, 4, 200) is False + + +def test_can_single_line_promise_bare_promiser(): + root = _parse('bundle agent x { packages: "binutils"; }') + promise = _find(root, "promise") + assert can_single_line_promise(promise, 4, 80) is True + + +# --------------------------------------------------------------------------- +# Comment helpers +# --------------------------------------------------------------------------- + + +def test_is_empty_comment_bare_hash(): + node = MockNode("comment", "#") + node.prev_named_sibling = None + node.next_named_sibling = None + assert _is_empty_comment(node) is True + + +def test_is_empty_comment_real_comment(): + node = MockNode("comment", "# real comment") + node.prev_named_sibling = None + node.next_named_sibling = None + assert _is_empty_comment(node) is False + + +def test_is_empty_comment_between_comments(): + a = MockNode("comment", "# above") + b = MockNode("comment", "#") + c = MockNode("comment", "# below") + b.prev_named_sibling = a + b.next_named_sibling = c + a.type = "comment" + c.type = "comment" + assert _is_empty_comment(b) is False + + +def test_skip_comments_forward(): + c1 = MockNode("comment", "# a") + c2 = MockNode("comment", "# b") + target = MockNode("promise", '"x"') + c1.next_named_sibling = c2 + c2.next_named_sibling = target + assert _skip_comments(c1, "next") is target + + +def test_skip_comments_backward(): + target = MockNode("promise", '"x"') + c1 = MockNode("comment", "# a") + c2 = MockNode("comment", "# b") + c2.prev_named_sibling = c1 + c1.prev_named_sibling = target + assert _skip_comments(c2, "prev") is target + + +def test_skip_comments_none(): + assert _skip_comments(None, "next") is None + + +def test_skip_comments_no_non_comment(): + c1 = MockNode("comment", "# a") + c1.next_named_sibling = None + assert _skip_comments(c1, "next") is None + + +def test_comment_indent_next_is_promise(): + target = MockNode("promise", '"x"') + target.next_named_sibling = None + node = MockNode("comment", "# note") + node.next_named_sibling = target + node.prev_named_sibling = None + assert _comment_indent(node, 4) == 6 + + +def test_comment_indent_next_is_not_indented(): + target = MockNode("{", "{") + target.next_named_sibling = None + node = MockNode("comment", "# note") + node.next_named_sibling = target + node.prev_named_sibling = None + assert _comment_indent(node, 4) == 4 + + +def test_comment_indent_no_neighbors(): + node = MockNode("comment", "# note") + node.next_named_sibling = None + node.prev_named_sibling = None + assert _comment_indent(node, 4) == 4 + + +def test_comment_indent_prev_is_attribute(): + target = MockNode("attribute", "x") + target.prev_named_sibling = None + node = MockNode("comment", "# note") + node.next_named_sibling = None + node.prev_named_sibling = target + assert _comment_indent(node, 4) == 6 + + +# --------------------------------------------------------------------------- +# _needs_blank_line_before +# --------------------------------------------------------------------------- + + +def test_needs_blank_line_no_prev(): + node = MockNode("promise", '"x"') + node.prev_named_sibling = None + assert _needs_blank_line_before(node, 4, 80) is False + + +def test_needs_blank_line_bundle_sections(): + prev = MockNode("bundle_section", "vars:") + child = MockNode("bundle_section", "classes:") + child.prev_named_sibling = prev + assert _needs_blank_line_before(child, 0, 80) is True + + +def test_needs_blank_line_class_guard_after_promise(): + prev = MockNode("promise", '"x"') + child = MockNode("class_guarded_promises", "linux::") + child.prev_named_sibling = prev + assert _needs_blank_line_before(child, 4, 80) is True + + +def test_needs_blank_line_unrelated_types(): + prev = MockNode("{", "{") + child = MockNode("}", "}") + child.prev_named_sibling = prev + assert _needs_blank_line_before(child, 0, 80) is False + + +# --------------------------------------------------------------------------- +# autoformat / format_policy_fin_fout — integration tests +# --------------------------------------------------------------------------- + + +def test_format_hello_world(): + result = _format('bundle agent main\n{\nvars:\n"hello" string => "world";\n}') + assert "bundle agent main" in result + assert " vars:" in result + assert '"hello" string => "world";' in result + + +def test_format_idempotent(): + code = 'bundle agent main\n{\n vars:\n "v" string => "hi";\n}\n' + result = _format(code) + assert result == code + + +def test_format_indentation(): + code = 'bundle agent main\n{\nvars:\n"v"\nstring => "hi";\n}' + result = _format(code) + for line in result.strip().split("\n"): + if line.startswith("bundle") or line.startswith("{") or line.startswith("}"): + continue + assert line.startswith(" "), f"Expected indentation: {line!r}" + + +def test_format_multiple_bundles(): + code = "bundle agent a { } bundle agent b { }" + result = _format(code) + assert "bundle agent a" in result + assert "bundle agent b" in result + assert "\n\n" in result # blank line between bundles + + +def test_format_class_guard(): + code = 'bundle agent x { vars: linux:: "v" string => "hi"; }' + result = _format(code) + assert "linux::" in result + + +def test_format_comment_preserved(): + code = 'bundle agent x {\n# my comment\nvars:\n"v" string => "hi";\n}' + result = _format(code) + assert "# my comment" in result + + +def test_format_empty_comment_removed(): + code = 'bundle agent x {\nvars:\n#\n"v" string => "hi";\n}' + result = _format(code) + lines = [l.strip() for l in result.strip().split("\n")] + assert "#" not in lines + + +def test_format_stakeholder_inline(): + code = 'bundle agent x { packages: "p" -> { "a" }; }' + result = _format(code) + assert '"p" -> { "a" };' in result + + +def test_format_stakeholder_split(): + code = ( + "bundle agent x { packages: " + '"python3-rpm-macros" -> { "very long reason text here", "TICKET-1234" } ' + 'comment => "c"; }' + ) + result = _format(code, line_length=50) + assert "-> {" in result + lines = result.strip().split("\n") + assert any("}" in line and "comment" not in line for line in lines) + + +def test_format_stakeholder_with_attributes_multiline(): + code = 'bundle agent x { packages: "p" -> { "a", "b" } comment => "c"; }' + result = _format(code) + lines = result.strip().split("\n") + promiser_line = next(l for l in lines if '"p"' in l) + attr_line = next(l for l in lines if "comment" in l) + assert promiser_line != attr_line + + +def test_format_single_line_promises_grouped(): + code = ( + "bundle agent x\n" + "{\n" + " packages:\n" + ' "a" package_policy => "delete";\n' + ' "b" package_policy => "delete";\n' + "}\n" + ) + result = _format(code) + assert result == code # should be idempotent, no blank lines between + + +def test_format_multi_line_promise_separated(): + code = ( + 'bundle agent x { vars: "a" if => "linux", string => "x"; "b" string => "y"; }' + ) + result = _format(code) + assert "\n\n" in result # blank line between multi-line and next promise + + +def test_format_body_block(): + code = 'body common control { inputs => { "a.cf" }; }' + result = _format(code) + assert "body common control" in result + assert "inputs" in result + + +def test_format_long_list_wraps(): + code = ( + 'bundle agent x { vars: "v" slist => ' + '{ "aaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbb", "ccccccccccccccccc" }; }' + ) + result = _format(code, line_length=50) + lines = result.strip().split("\n") + assert len(lines) > 3 # should have wrapped + + +def test_format_line_length_respected(): + code = ( + 'bundle agent x { vars: "v" slist => ' + '{ "aaa", "bbb", "ccc", "ddd", "eee", "fff" }; }' + ) + result = _format(code, line_length=40) + for line in result.strip().split("\n"): + # Allow slight overshoot for long strings that can't be split + assert len(line) <= 80, f"Line too long: {line!r}"