Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 145 additions & 4 deletions document_clusterer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Sequence, cast
from typing import Dict, List, Sequence, TypedDict, cast

import numpy as np
from numpy.typing import NDArray
Expand All @@ -20,6 +20,18 @@
IntArray = NDArray[np.int_]


class TermCount(TypedDict):
term: str
count: int


class ClusterResult(TypedDict):
cluster: str
document_count: int
documents: list[str]
top_terms: list[TermCount]


def env_path(var_name: str, default: str) -> Path:
return Path(os.getenv(var_name, default))

Expand All @@ -28,7 +40,22 @@ def env_int(var_name: str, default: int) -> int:
return int(os.getenv(var_name, str(default)))


def _ensure_file_exists(path: Path, description: str) -> None:
if not path.exists():
raise FileNotFoundError(f"{description} does not exist: {path}")
if not path.is_file():
raise ValueError(f"{description} is not a file: {path}")


def _ensure_dir_exists(path: Path, description: str) -> None:
if not path.exists():
raise FileNotFoundError(f"{description} does not exist: {path}")
if not path.is_dir():
raise NotADirectoryError(f"{description} is not a directory: {path}")


def load_documents(data_path: Path) -> List[CleanedDocument]:
_ensure_file_exists(data_path, "Input data file")
LOGGER.info("Loading documents from %s", data_path)
with data_path.open("r", encoding="utf-8") as infile:
loaded = json.load(infile)
Expand Down Expand Up @@ -223,6 +250,50 @@ def save_summaries(summaries: Dict[int, list[tuple[str, int]]], output_dir: Path
outfile.write("\n")


def save_cluster_results(
cluster_to_documents: Dict[str, List[str]],
summaries: Dict[int, list[tuple[str, int]]],
output_dir: Path,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
json_path = output_dir / "cluster_results.json"
csv_path = output_dir / "cluster_results.csv"

LOGGER.info("Writing cluster results to %s and %s", json_path, csv_path)
formatted: list[ClusterResult] = []
for label, documents in sorted(cluster_to_documents.items(), key=lambda item: item[0]):
try:
numeric_label = int(label)
except ValueError:
numeric_label = None
top_terms = summaries.get(numeric_label, []) if numeric_label is not None else []
formatted.append(
{
"cluster": label,
"document_count": len(documents),
"documents": documents,
"top_terms": [{"term": term, "count": count} for term, count in top_terms],
}
)

with json_path.open("w", encoding="utf-8") as outfile:
json.dump(formatted, outfile, ensure_ascii=False, indent=2)

with csv_path.open("w", encoding="utf-8", newline="") as csvfile:
fieldnames = ["cluster", "document_count", "documents", "top_terms"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for entry in formatted:
writer.writerow(
{
"cluster": entry["cluster"],
"document_count": entry["document_count"],
"documents": ";".join(entry["documents"]),
"top_terms": ";".join(f'{term["term"]}:{term["count"]}' for term in entry["top_terms"]),
}
)


def copy_clusters(document_clusters: Dict[str, List[str]], stories_dir: Path, output_dir: Path) -> None:
output_dir.mkdir(parents=True, exist_ok=True)

Expand All @@ -245,11 +316,58 @@ def copy_if_exists(src: Path, dst: Path) -> None:
dst.write_bytes(src.read_bytes())


def write_cluster_directories(
cluster_to_documents: Dict[str, List[str]],
summaries: Dict[int, list[tuple[str, int]]],
stories_dir: Path,
output_dir: Path,
*,
copy_documents: bool,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
LOGGER.info("Writing per-cluster directories to %s (copy_documents=%s)", output_dir, copy_documents)

for cluster_label, documents in cluster_to_documents.items():
cluster_dir = output_dir / cluster_label
cluster_dir.mkdir(parents=True, exist_ok=True)

try:
numeric_label = int(cluster_label)
except ValueError:
numeric_label = None

summary_path = cluster_dir / "summary.json"
summary_content = {
"cluster": cluster_label,
"document_count": len(documents),
"documents": documents,
"top_terms": [
{"term": term, "count": count}
for term, count in (summaries.get(numeric_label, []) if numeric_label is not None else [])
],
}
summary_path.write_text(json.dumps(summary_content, ensure_ascii=False, indent=2), encoding="utf-8")

if not copy_documents:
continue

for document_name in documents:
srcfile = stories_dir / document_name
dstfile = cluster_dir / document_name
LOGGER.debug("Copying %s -> %s", srcfile, dstfile)
copy_if_exists(srcfile, dstfile)


def cluster_documents(
data_path: Path,
stories_dir: Path,
output_dir: Path,
*,
assignments_output_dir: Path | None = None,
results_output_dir: Path | None = None,
cluster_output_dir: Path | None = None,
create_cluster_dirs: bool = True,
copy_cluster_documents: bool = True,
model_name: str = "all-MiniLM-L6-v2",
cluster_method: str = "kmeans",
cluster_count: int | None = 10,
Expand All @@ -264,6 +382,17 @@ def cluster_documents(
summary_top_n: int = 10,
assignments_basename: str = "cluster_assignments",
) -> Dict[str, List[str]]:
LOGGER.info("Preparing to cluster documents")
_ensure_file_exists(data_path, "Input data file")
_ensure_dir_exists(stories_dir, "Stories directory")

assignments_root = assignments_output_dir or output_dir
results_root = results_output_dir or output_dir
clusters_root = cluster_output_dir or output_dir

for destination in (assignments_root, results_root, clusters_root, output_dir):
destination.mkdir(parents=True, exist_ok=True)

documents = load_documents(data_path)
if not documents:
raise ValueError("No documents found to cluster.")
Expand All @@ -290,18 +419,30 @@ def cluster_documents(
save_assignments(
documents,
labels,
output_dir=output_dir,
output_dir=assignments_root,
basename=assignments_basename,
reduced_embeddings=reduced_embeddings,
)

summaries = summarize_clusters(documents, labels, top_n=summary_top_n)
save_summaries(summaries, output_dir)
save_summaries(summaries, results_root)

cluster_to_documents: Dict[str, List[str]] = defaultdict(list)
for document, label in zip(documents, labels):
cluster_label = str(label) if label != -1 else "noise"
cluster_to_documents[cluster_label].append(document["filename"])

copy_clusters(cluster_to_documents, stories_dir, output_dir)
save_cluster_results(cluster_to_documents, summaries, results_root)

if create_cluster_dirs:
write_cluster_directories(
cluster_to_documents,
summaries,
stories_dir,
clusters_root,
copy_documents=copy_cluster_documents,
)
else:
LOGGER.info("Skipping per-cluster directory creation")

return cluster_to_documents