def _docstring_node(func: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.Expr | None:
    """Return the statement node that holds ``func``'s docstring, if any.

    A docstring is recognised only when the first statement of the body is a
    bare expression wrapping a string constant. The node (rather than the
    text) is returned so callers can read its ``lineno`` / ``end_lineno``.

    Args:
        func: The (async) function definition to inspect.

    Returns:
        The ``ast.Expr`` docstring statement, or ``None`` when the function
        has an empty body or does not start with a string literal.
    """
    if not func.body:
        return None
    first_stmt = func.body[0]
    if not isinstance(first_stmt, ast.Expr):
        return None
    literal = first_stmt.value
    if isinstance(literal, ast.Constant) and isinstance(literal.value, str):
        return first_stmt
    return None
def _has_meaningful_return(func: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
    """True iff the method's return annotation promises an actual value.

    The following count as "no meaningful return":

    * no annotation at all,
    * ``None``,
    * ``NoReturn`` / ``Never``, bare or attribute-qualified
      (e.g. ``typing.NoReturn``),
    * any of the above written as a quoted (string) annotation.

    Args:
        func: The (async) function definition to inspect.

    Returns:
        ``True`` when a ``Returns:`` section should be expected in the
        docstring, ``False`` otherwise.
    """
    no_value_names = ("NoReturn", "Never")

    ret = func.returns
    if ret is None:  # no annotation at all
        return False

    # A quoted annotation (`-> "None"`) is a string constant in the AST;
    # look at the expression inside the quotes instead of the string itself.
    if isinstance(ret, ast.Constant) and isinstance(ret.value, str):
        try:
            ret = ast.parse(ret.value.strip(), mode="eval").body
        except SyntaxError:
            # Not a parseable expression: assume it names a real type.
            return True

    if isinstance(ret, ast.Constant) and ret.value is None:  # `-> None`
        return False
    # `-> NoReturn` / `-> Never`, bare or qualified.
    if isinstance(ret, ast.Name) and ret.id in no_value_names:
        return False
    if isinstance(ret, ast.Attribute) and ret.attr in no_value_names:
        return False
    return True


def _has_returns_section(docstring: str | None) -> bool:
    """Tell whether ``docstring`` contains a return/yield section header.

    A header is a line that, once stripped, is exactly ``Returns:``,
    ``Return:``, ``Yields:`` or ``Yield:``.

    Args:
        docstring: The docstring text, or ``None`` / empty when absent.

    Returns:
        ``True`` if such a header line is present, ``False`` otherwise.
    """
    if not docstring:
        return False
    headers = {"Returns:", "Return:", "Yields:", "Yield:"}
    return any(line.strip() in headers for line in docstring.splitlines())
def _stale_entry_ranges(
    lines: list[str], doc_start: int, doc_end: int, stale: set[str]
) -> tuple[list[tuple[int, int]], list[str]]:
    """Locate the raw source lines of stale entries inside one docstring.

    Args:
        lines: The file's lines (0-indexed, without line terminators).
        doc_start: Index of the first line of the docstring statement.
        doc_end: Index of the last line of the docstring statement (inclusive).
        stale: Argument names that are documented but not in the signature.

    Returns:
        ``(ranges, removed)`` where ``ranges`` is a list of
        ``(start, end_exclusive)`` line-index pairs to delete and ``removed``
        lists the names whose entries those ranges cover. Both are empty when
        no Args/Arguments/Parameters header or no indented entry is found.
    """
    # Locate the Args/Arguments/Parameters header in raw source.
    args_idx: int | None = None
    for i in range(doc_start, doc_end + 1):
        if lines[i].strip() in {"Args:", "Arguments:", "Parameters:"}:
            args_idx = i
            break
    if args_idx is None:
        return [], []
    header_indent = len(lines[args_idx]) - len(lines[args_idx].lstrip())

    # First non-empty line after the header sets the per-entry indent.
    entry_indent: int | None = None
    for i in range(args_idx + 1, doc_end + 1):
        if lines[i].strip():
            entry_indent = len(lines[i]) - len(lines[i].lstrip())
            break
    if entry_indent is None or entry_indent <= header_indent:
        return [], []

    ranges: list[tuple[int, int]] = []
    removed: list[str] = []
    current_name: str | None = None
    current_start = -1
    section_end = doc_end

    # Each entry spans from its header line up to (but not including) the
    # next entry header or the end of the section.
    for i in range(args_idx + 1, doc_end + 1):
        line = lines[i]
        stripped = line.strip()
        if not stripped:
            continue
        indent = len(line) - len(line.lstrip())

        # Any non-blank line dedented back to the header's level ends the
        # section -- whether it is another section header, free-form prose or
        # the closing quotes. Stopping here keeps such text from being
        # swallowed into (and deleted with) the last entry.
        if indent <= header_indent:
            section_end = i
            break

        if indent == entry_indent:
            m = _ARG_HEADER_RE.match(stripped)
            if m:
                if current_name in stale:
                    ranges.append((current_start, i))
                    removed.append(current_name)
                current_name = m.group(1)
                current_start = i

    if current_name in stale:
        end = section_end
        # Trailing blank lines belong to inter-section spacing (or the blank
        # line before the closing quotes), not to this entry.
        while end > current_start + 1 and not lines[end - 1].strip():
            end -= 1
        # An empty range means the entry shares its line with the closing
        # quotes; nothing can be dropped safely, so nothing is reported.
        if end > current_start:
            ranges.append((current_start, end))
            removed.append(current_name)

    return ranges, removed


def fix_file(path: Path, kind: str) -> list[str]:
    """Remove stale arg entries (documented but not in signature) in-place.

    Arguments that are in the signature but missing from the docstring are
    NOT added (no placeholders). The file is rewritten only when at least
    one entry is actually removed.

    Args:
        path: The Python source file to fix.
        kind: ``"model"`` to process ``forward`` methods of model classes;
            any other value processes ``__call__`` of pipeline classes.

    Returns:
        A list of ``"ClassName.method: removed name1, name2"`` strings naming
        the entries that were really deleted. Empty when nothing changed or
        when the file cannot be decoded or parsed.
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE

    # Reading happens inside the ``try`` so an undecodable file is skipped
    # instead of aborting the run (``UnicodeDecodeError`` is a ``ValueError``,
    # as is the null-byte error older parsers raise).
    try:
        source = path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        return []

    # Split on "\n" only. ``str.splitlines`` also breaks on form feeds and
    # Unicode separators that the parser does not treat as line ends, which
    # would shift these indices away from ``lineno`` / ``end_lineno``.
    lines = source.split("\n")
    # (start_idx, end_idx_exclusive) ranges of lines to drop.
    deletions: list[tuple[int, int]] = []
    summaries: list[str] = []

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue
        # Method must start with a string docstring expression.
        docstring_expr = _docstring_node(method)
        if docstring_expr is None:
            continue

        sig_set = set(_signature_arg_names(method))
        documented = _extract_documented_args(ast.get_docstring(method))
        stale = documented - sig_set
        if not stale:
            continue

        ranges, removed = _stale_entry_ranges(
            lines,
            docstring_expr.lineno - 1,  # 0-indexed
            docstring_expr.end_lineno - 1,  # 0-indexed, inclusive
            stale,
        )
        if not ranges:
            continue
        deletions.extend(ranges)
        # Report only what was located and removed, not every stale name.
        summaries.append(f"{node.name}.{method_name}: removed {', '.join(sorted(set(removed)))}")

    if not deletions:
        return []

    # Delete bottom-up so earlier indices stay valid.
    for start, end in sorted(deletions, reverse=True):
        del lines[start:end]
    path.write_text("\n".join(lines), encoding="utf-8")
    return summaries
Missing-in-docstring entries are NOT auto-added " + "(no placeholders) and will still be reported." + ), + ) args = parser.parse_args() targets: list[tuple[Path, str]] = [] @@ -253,6 +431,17 @@ def main() -> int: for p in pipeline_files: targets.append((p, "pipeline")) + if args.fix: + fix_summaries: list[str] = [] + for path, kind in targets: + for summary in fix_file(path, kind): + fix_summaries.append(f"{path.relative_to(REPO_ROOT)}: {summary}") + if fix_summaries: + print("Removed stale docstring entries:") + print("\n".join(f" {s}" for s in fix_summaries)) + else: + print("No stale docstring entries to remove.") + all_errors: list[str] = [] for path, kind in targets: all_errors.extend(check_file(path, kind)) @@ -263,6 +452,14 @@ def main() -> int: f"\nFound {len(all_errors)} docstring/signature mismatch(es).", file=sys.stderr, ) + if not args.fix and any("documents argument(s) not in the signature" in e for e in all_errors): + print( + "Hint: run `python utils/check_forward_call_docstrings.py --fix` " + "to remove the stale argument entries flagged above. " + "(Missing-in-docstring entries must be added manually — the tool " + "never inserts placeholders.)", + file=sys.stderr, + ) return 1 print("All forward/__call__ arguments are documented.")