diff --git a/.coverage b/.coverage
index cb8c5ebe..8380a2ac 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/load_corpus.py b/load_corpus.py
index 668e7524..3936a8ac 100755
--- a/load_corpus.py
+++ b/load_corpus.py
@@ -1,117 +1,219 @@
from superstyl.load import load_corpus
-from superstyl.load_from_config import load_corpus_from_config
+from superstyl.config import Config
import json
# TODO: eliminate features that occur only n times ?
# Do the Moisl Selection ?
-# TODO: document the new 'lemma' feat for TEI loading
if __name__ == '__main__':
import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument('-s', nargs='+', help="paths to files or to json config file", required=True)
- parser.add_argument('--json', action='store_true', help="indicates that the path provided with -s is a JSON config file, "
- "containing all the options to load the corpus/features")
- parser.add_argument('-o', action='store', help="optional base name of output files", type=str, default=False)
- parser.add_argument('-f', action="store", help="optional list of features, either in json (generated by"
- " Superstyl) or simple txt (one word per line)", default=False)
- parser.add_argument('-t', action='store', help="types of features (words, chars, affixes - "
- "as per Sapkota et al. 2015 -, as well as lemma or pos, met_line, "
- "met_syll (those four last only for TEI files with proper annotation)"
- , type=str,
- default="words", choices=["words", "chars", "affixes", "pos", "lemma", "met_line", "met_syll"])
- parser.add_argument('-n', action='store', help="n grams lengths (default 1)", default=1, type=int)
- parser.add_argument('-k', action='store', help="How many most frequent?", default=5000, type=int)
- parser.add_argument('--freqs', action='store', help="relative, absolute or binarised freqs",
+ parser = argparse.ArgumentParser(
+ description="Load a corpus and extract features for stylometric analysis."
+ )
+ parser.add_argument('-s',
+ nargs='+',
+ help="paths to files or to json config file",
+ required=True
+ )
+ parser.add_argument('--json',
+ action='store_true',
+ help="indicates that the path provided with -s is a JSON config file, "
+ "containing all the options to load the corpus/features"
+ )
+ parser.add_argument('-o',
+ action='store',
+ help="optional base name of output files",
+ type=str,
+ default=False
+ )
+ # Feature list
+ parser.add_argument('-f',
+ action="store",
+ help="optional list of features, either in json (generated by"
+ " Superstyl) or simple txt (one word per line)",
+ default=False
+ )
+ parser.add_argument('-t',
+ action='store',
+ help="types of features (words, chars, affixes - "
+ "as per Sapkota et al. 2015 -, as well as lemma or pos, met_line, "
+ "met_syll (those four last only for TEI files with proper annotation)",
+ type=str, default="words",
+ choices=["words", "chars", "affixes", "pos", "lemma", "met_line", "met_syll"]
+ )
+ parser.add_argument('-n',
+ action='store',
+ help="n grams lengths (default 1)",
+ default=1,
+ type=int)
+ parser.add_argument('-k',
+ action='store',
+ help="How many most frequent features?",
+ default=5000,
+ type=int
+ )
+ parser.add_argument('--freqs',
+ action='store',
+ help="relative, absolute or binarised freqs",
default="relative",
choices=["relative", "absolute", "binary"]
)
- parser.add_argument('-x', action='store', help="format (txt, xml, tei, or txm) WARNING: only txt is fully implemented",
- default="txt",
+ parser.add_argument('-x',
+ action='store',
+ help="format (txt, xml, tei, or txm) WARNING: only txt is fully implemented",
+ default="txt",
choices=["txt", "xml", "tei", 'txm']
)
- parser.add_argument('--sampling', action='store_true', help="Sample the texts?", default=False)
- parser.add_argument('--sample_units', action='store', help="Units of length for sampling "
- "(words, verses; default: words)",
- choices=["words", "verses"],
- default="words", type=str)
- parser.add_argument('--sample_size', action='store', help="Size for sampling (default: 3000)", default=3000, type=int)
- parser.add_argument('--sample_step', action='store', help="Step for sampling with overlap (default is no overlap)", default=None, type=int)
- parser.add_argument('--max_samples', action='store', help="Maximum number of (randomly selected) samples per class, e.g. author (default is all)",
- default=None, type=int)
- parser.add_argument('--samples_random', action='store_true',
- help="Should random sampling with replacement be performed instead of continuous sampling (default: false)",
- default=False)
- parser.add_argument('--keep_punct', action='store_true', help="whether to keep punctuation and caps (default is False)",
- default=False)
- parser.add_argument('--keep_sym', action='store_true',
- help="if true, same as keep_punct, plus no Unidecode, and numbers are kept as well (default is False)",
- default=False)
- parser.add_argument('--no_ascii', action='store_true',
- help="disables the conversion to ascii as per the Unidecode module. Useful for non Latin alphabet (default is conversion to ASCII)",
- default=False)
- parser.add_argument('--identify_lang', action='store_true',
- help="if true, should the language of each text be guessed, using langdetect (default is False)",
- default=False)
- parser.add_argument('--embedding', action="store", help="optional path to a word2vec embedding in txt format to compute frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)",
- default=False)
- parser.add_argument('--neighbouring_size', action="store", help="size of semantic neighbouring in the embedding (n closest neighbours)",
- default=10, type=int)
- parser.add_argument('--culling', action="store",
- help="percentage value for culling, meaning in what percentage of samples should a feature be present to be retained (default is 0, meaning no culling)",
- default=0, type=float)
-
+ parser.add_argument('--sampling',
+ action='store_true',
+ help="Sample the texts?",
+ default=False
+ )
+ parser.add_argument('--sample_units',
+ action='store',
+ help="Units of length for sampling (words, verses; default: words)",
+ choices=["words", "verses"],
+ default="words",
+ type=str
+ )
+ parser.add_argument('--sample_size',
+ action='store',
+ help="Size for sampling (default: 3000)",
+ default=3000,
+ type=int
+ )
+ parser.add_argument('--sample_step',
+ action='store',
+ help="Step for sampling with overlap (default is no overlap)",
+ default=None,
+ type=int
+ )
+ parser.add_argument('--max_samples',
+ action='store',
+ help="Maximum number of (randomly selected) samples per class, e.g. author (default is all)",
+ default=None,
+ type=int
+ )
+ parser.add_argument('--samples_random',
+ action='store_true',
+ help="Should random sampling with replacement be performed " \
+ "instead of continuous sampling (default: false)",
+ default=False
+ )
+ parser.add_argument('--keep_punct',
+ action='store_true',
+ help="whether to keep punctuation and caps (default is False)",
+ default=False
+ )
+ parser.add_argument('--keep_sym',
+ action='store_true',
+ help="if true, same as keep_punct, plus no Unidecode, "
+ "and numbers are kept as well (default is False)",
+ default=False
+ )
+ parser.add_argument('--no_ascii',
+ action='store_true',
+ help="disables the conversion to ascii as per the Unidecode module. " \
+ "Useful for non Latin alphabet (default is conversion to ASCII)",
+ default=False
+ )
+ parser.add_argument('--identify_lang',
+ action='store_true',
+ help="if true, should the language of each text be guessed, " \
+ "using langdetect (default is False)",
+ default=False
+ )
+ parser.add_argument('--embedding',
+ action="store",
+ help="optional path to a word2vec embedding in txt format to compute " \
+ "frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)",
+ default=False
+ )
+ parser.add_argument('--neighbouring_size',
+ action="store",
+ help="size of semantic neighbouring in the embedding (n closest neighbours)",
+ default=10,
+ type=int
+ )
+ parser.add_argument('--culling',
+ action="store",
+ help="percentage value for culling, meaning in what " \
+ "percentage of samples should a feature be present " \
+ "to be retained (default is 0, meaning no culling)",
+ default=0,
+ type=float)
args = parser.parse_args()
+ # Load feature list if provided
+ my_feats = None
if args.f:
with open(args.f, 'r') as f:
- if args.f.split(".")[-1] == "json":
+ if args.f.endswith(".json"):
print(".......loading preexisting feature list from json.......")
my_feats = json.loads(f.read())
-
- elif args.f.split(".")[-1] == "txt":
+ elif args.f.endswith(".txt"):
print(".......loading preexisting feature list from txt.......")
my_feats = [[feat.rstrip(), 0] for feat in f.readlines()]
-
else:
print(".......unknown feature list format. Ignoring.......")
- my_feats = None
-
- else:
- my_feats = None
- if args.json:
- corpus, my_feats = load_corpus_from_config(args.s)
+ elif args.config:
+ # Load from new-style JSON config file
+ config = Config.from_json(args.config)
+ # Override paths if provided via CLI
+ if args.s:
+ config.corpus.paths = args.s
+
else:
- corpus, my_feats = load_corpus(args.s, feat_list=my_feats, feats=args.t, n=args.n, k=args.k,
- freqsType=args.freqs, format=args.x,
- sampling=args.sampling, units=args.sample_units,
- size=args.sample_size, step=args.sample_step, max_samples=args.max_samples,
- samples_random=args.samples_random,
- keep_punct=args.keep_punct, keep_sym=args.keep_sym, no_ascii=args.no_ascii,
- identify_lang=args.identify_lang,
- embedding=args.embedding, neighbouring_size=args.neighbouring_size,
- culling=args.culling
- )
-
- print(".......saving results.......")
-
+ if not args.s:
+ parser.error("-s (paths) is required when not using --config")
+
+ config = Config.from_kwargs(
+ data_paths=args.s,
+ feats=args.t,
+ n=args.n,
+ k=args.k,
+ freqsType=args.freqs,
+ format=args.x,
+ sampling=args.sampling,
+ units=args.sample_units,
+ size=args.sample_size,
+ step=args.sample_step,
+ max_samples=args.max_samples,
+ samples_random=args.samples_random,
+ keep_punct=args.keep_punct,
+ keep_sym=args.keep_sym,
+ no_ascii=args.no_ascii,
+ identify_lang=args.identify_lang,
+ embedding=args.embedding,
+ neighbouring_size=args.neighbouring_size,
+ culling=args.culling
+ )
+
+ # Inject my_feats if provided
+ if my_feats and config.features:
+ config.features[0].feat_list = my_feats
+
+ # Load corpus
+ corpus, my_feats = load_corpus(config=config)
+
+ # Determine output file names
if args.o:
feat_file = args.o + "_feats.json"
corpus_file = args.o + ".csv"
-
else:
- feat_file = "feature_list_{}{}grams{}mf.json".format(args.t, args.n, args.k)
- corpus_file = "feats_tests_n{}_k_{}.csv".format(args.n, args.k)
+ feat_file = f"feature_list_{args.t}{args.n}grams{args.k}mf.json"
+ corpus_file = f"feats_tests_n{args.n}_k_{args.k}.csv"
- #if not args.f and :
+ # Save results
+ print(".......saving results.......")
+
with open(feat_file, "w") as out:
out.write(json.dumps(my_feats, ensure_ascii=False, indent=0))
- print("Features list saved to " + feat_file)
+ print(f"Features list saved to {feat_file}")
+ # Save corpus
corpus.to_csv(corpus_file)
- print("Corpus saved to " + corpus_file)
-
-
+ print(f"Corpus saved to {corpus_file}")
\ No newline at end of file
diff --git a/split.py b/split.py
index 6ee63c9d..3094fa90 100755
--- a/split.py
+++ b/split.py
@@ -1,3 +1,7 @@
+"""
+Command-line tool for splitting datasets.
+"""
+
import superstyl.preproc.select as sel
@@ -12,28 +16,28 @@
default=False)
parser.add_argument('-m', action="store", help="path to metadata file", required=False)
parser.add_argument('-e', action="store", help="path to excludes file", required=False)
- parser.add_argument('--lang', action="store", help="analyse only file in this language (optional, for initial split only)", required=False)
- parser.add_argument('--nosplit', action="store_true", help="no split (do not provide split file)", default=False)
+ parser.add_argument('--lang', action="store",
+ help="analyse only file in this language (optional, for initial split only)",
+ required=False)
+ parser.add_argument('--nosplit', action="store_true",
+ help="no split (do not provide split file)",
+ default=False)
+ parser.add_argument('--split_ratio', action="store", type=float,
+ help="validation split ratio (default: 0.1 = 10%%)",
+ default=0.1)
args = parser.parse_args()
- if args.nosplit:
- sel.read_clean(path=args.path,
- metadata_path=args.m,
- excludes_path=args.e,
- savesplit="split_nosplit.json",
- lang=args.lang
- )
+ if args.s:
+ # Apply existing selection
+ sel.apply_selection(path=args.path, presplit_path=args.s)
else:
-
- if not args.s:
- # to create initial selection
- sel.read_clean_split(path=args.path,
- metadata_path=args.m,
- excludes_path=args.e,
- savesplit="split.json",
- lang=args.lang
- )
-
- else:
- # to load and apply a selection
- sel.apply_selection(path=args.path, presplit_path=args.s)
+ # Create new selection (with or without split)
+ sel.read_clean(
+ path=args.path,
+ metadata_path=args.m,
+ excludes_path=args.e,
+ savesplit="split_nosplit.json" if args.nosplit else "split.json",
+ lang=args.lang,
+ split=not args.nosplit,
+ split_ratio=args.split_ratio
+ )
diff --git a/superstyl/__init__.py b/superstyl/__init__.py
index 940f61fc..67ab2010 100755
--- a/superstyl/__init__.py
+++ b/superstyl/__init__.py
@@ -1,2 +1,36 @@
from superstyl.load import load_corpus
-from superstyl.svm import train_svm
\ No newline at end of file
+from superstyl.svm import train_svm, plot_rolling, plot_coefficients
+from superstyl.config import (
+ Config,
+ CorpusConfig,
+ FeatureConfig,
+ SamplingConfig,
+ NormalizationConfig,
+ SVMConfig,
+ get_config,
+ set_config,
+ reset_config
+)
+
+__all__ = [
+ # Main functions
+ 'load_corpus',
+ 'load_corpus_with_config',
+ 'train_svm',
+ 'train_svm_with_config',
+ 'plot_rolling',
+ 'plot_coefficients',
+
+ # Configuration classes
+ 'Config',
+ 'CorpusConfig',
+ 'FeatureConfig',
+ 'SamplingConfig',
+ 'NormalizationConfig',
+ 'SVMConfig',
+
+ # Configuration management
+ 'get_config',
+ 'set_config',
+ 'reset_config',
+]
\ No newline at end of file
diff --git a/superstyl/config.py b/superstyl/config.py
new file mode 100644
index 00000000..4a0aa6fa
--- /dev/null
+++ b/superstyl/config.py
@@ -0,0 +1,362 @@
+from dataclasses import dataclass, field, fields
+from typing import Optional, List, Any, Dict, Type, TypeVar
+import json
+
+
+T = TypeVar('T', bound='BaseConfig')
+
+
+@dataclass
+class BaseConfig:
+ """
+ Base configuration class providing common functionality.
+
+ All configuration classes inherit from this to share:
+ - to_dict() serialization
+ - from_dict() deserialization
+ - Validation hooks
+ """
+
+ def to_dict(self) -> Dict[str, Any]:
+ result = {}
+ for f in fields(self):
+ value = getattr(self, f.name)
+ if isinstance(value, BaseConfig):
+ result[f.name] = value.to_dict()
+ elif isinstance(value, list) and value and isinstance(value[0], BaseConfig):
+ result[f.name] = [v.to_dict() for v in value]
+ else:
+ result[f.name] = value
+ return result
+
+ @classmethod
+ def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
+ return cls(**data)
+
+ def validate(self) -> None:
+ pass
+
+
+@dataclass
+class NormalizationConfig(BaseConfig):
+ """
+ Configuration for text normalization.
+ """
+ keep_punct: bool = False
+ keep_sym: bool = False
+ no_ascii: bool = False
+
+
+@dataclass
+class SamplingConfig(BaseConfig):
+ """
+ Configuration for text sampling.
+ """
+ enabled: bool = False
+ units: str = "words"
+ size: int = 3000
+ step: Optional[int] = None
+ max_samples: Optional[int] = None
+ random: bool = False
+
+ def __post_init__(self):
+ self.validate()
+
+ def validate(self) -> None:
+ if self.units not in ["words", "verses"]:
+ raise ValueError(f"Invalid sampling units: {self.units}.")
+ if self.random and self.step is not None:
+ raise ValueError("Random sampling is not compatible with step.")
+ if self.random and self.max_samples is None:
+ raise ValueError("Random sampling needs max_samples.")
+
+
+@dataclass
+class FeatureConfig(BaseConfig):
+ """
+ Configuration for feature extraction.
+ """
+ name: Optional[str] = None # For multi-feature identification
+ type: str = "words"
+ n: int = 1
+ k: int = 5000
+ freq_type: str = "relative"
+ feat_list: Optional[List] = None
+ feat_list_path: Optional[str] = None # Path to load feat_list from
+ embedding: Optional[str] = None
+ neighbouring_size: int = 10
+ culling: float = 0
+
+ VALID_TYPES = ["words", "chars", "affixes", "lemma", "pos", "met_line", "met_syll"]
+ VALID_FREQ_TYPES = ["relative", "absolute", "binary"]
+
+ def __post_init__(self):
+ self.validate()
+ self._load_feat_list_if_needed()
+
+ def validate(self) -> None:
+ if self.type not in self.VALID_TYPES:
+ raise ValueError(f"Invalid feature type: {self.type}.")
+ if self.freq_type not in self.VALID_FREQ_TYPES:
+ raise ValueError(f"Invalid frequency type: {self.freq_type}.")
+ if self.n < 1:
+ raise ValueError("n must be a positive integer.")
+
+ def _load_feat_list_if_needed(self) -> None:
+ """
+ Load feature list from file if path is specified.
+ """
+ if self.feat_list_path and self.feat_list is None:
+ with open(self.feat_list_path, 'r') as f:
+ if self.feat_list_path.endswith('.json'):
+ self.feat_list = json.load(f)
+ elif self.feat_list_path.endswith('.txt'):
+ self.feat_list = [[feat.strip(), 0] for feat in f.readlines()]
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> 'FeatureConfig':
+ # Filter out unknown keys for backward compatibility
+ valid_keys = {f.name for f in fields(cls)}
+ filtered_data = {k: v for k, v in data.items() if k in valid_keys}
+ return cls(**filtered_data)
+
+
+@dataclass
+class CorpusConfig(BaseConfig):
+ """
+ Configuration for corpus loading.
+ """
+ paths: List[str] = field(default_factory=list)
+ format: str = "txt"
+ identify_lang: bool = False
+
+ VALID_FORMATS = ["txt", "xml", "tei", "txm"]
+
+ def __post_init__(self):
+ self.validate()
+
+ def validate(self) -> None:
+ if self.format not in self.VALID_FORMATS:
+ raise ValueError(f"Invalid format: {self.format}.")
+
+
+@dataclass
+class SVMConfig(BaseConfig):
+ """
+ Configuration for SVM training.
+ """
+ cross_validate: Optional[str] = None
+ k: int = 0
+ dim_reduc: Optional[str] = None
+ norms: bool = True
+ balance: Optional[str] = None
+ class_weights: bool = False
+ kernel: str = "LinearSVC"
+ final_pred: bool = False
+ get_coefs: bool = False
+ plot_rolling: bool = False
+ plot_smoothing: int = 3
+
+ VALID_CV = [None, "leave-one-out", "k-fold", "group-k-fold"]
+ VALID_DIM_REDUC = [None, "pca"]
+ VALID_BALANCE = [None, "downsampling", "Tomek", "upsampling", "SMOTE", "SMOTETomek"]
+ VALID_KERNELS = ["LinearSVC", "linear", "sigmoid", "rbf", "poly"]
+
+ def __post_init__(self):
+ self.validate()
+
+ def validate(self) -> None:
+ if self.cross_validate not in self.VALID_CV:
+ raise ValueError(f"Invalid cross_validate: {self.cross_validate}.")
+ if self.dim_reduc not in self.VALID_DIM_REDUC:
+ raise ValueError(f"Invalid dim_reduc: {self.dim_reduc}.")
+ if self.balance not in self.VALID_BALANCE:
+ raise ValueError(f"Invalid balance: {self.balance}.")
+ if self.kernel not in self.VALID_KERNELS:
+ raise ValueError(f"Invalid kernel: {self.kernel}.")
+
+
+@dataclass
+class Config(BaseConfig):
+ """
+ Main configuration class for SuperStyl.
+
+ Aggregates all sub-configurations and provides factory methods
+ to create configurations from various sources (CLI, JSON, dict).
+ """
+ corpus: CorpusConfig = field(default_factory=CorpusConfig)
+ features: List[FeatureConfig] = field(default_factory=lambda: [FeatureConfig()])
+ sampling: SamplingConfig = field(default_factory=SamplingConfig)
+ normalization: NormalizationConfig = field(default_factory=NormalizationConfig)
+ svm: SVMConfig = field(default_factory=SVMConfig)
+ output_prefix: Optional[str] = None
+
+ # Mapping from flat kwargs to nested config structure
+ # Format: 'kwarg_name': ('section', 'attr', optional_transform)
+ KWARGS_MAPPING = {
+ # Corpus
+ 'data_paths': ('corpus', 'paths', lambda x: x if isinstance(x, list) else [x]),
+ 'format': ('corpus', 'format', None),
+ 'identify_lang': ('corpus', 'identify_lang', None),
+
+ # Features (single feature mode)
+ 'feats': ('features', 'type', None),
+ 'n': ('features', 'n', None),
+ 'k': ('features', 'k', None),
+ 'freqsType': ('features', 'freq_type', None),
+ 'feat_list': ('features', 'feat_list', None),
+ 'embedding': ('features', 'embedding', lambda x: x if x else None),
+ 'neighbouring_size': ('features', 'neighbouring_size', None),
+ 'culling': ('features', 'culling', None),
+
+ # Sampling
+ 'sampling': ('sampling', 'enabled', None),
+ 'units': ('sampling', 'units', None),
+ 'size': ('sampling', 'size', None),
+ 'step': ('sampling', 'step', None),
+ 'max_samples': ('sampling', 'max_samples', None),
+ 'samples_random': ('sampling', 'random', None),
+
+ # Normalization
+ 'keep_punct': ('normalization', 'keep_punct', None),
+ 'keep_sym': ('normalization', 'keep_sym', None),
+ 'no_ascii': ('normalization', 'no_ascii', None),
+
+ # SVM
+ 'cross_validate': ('svm', 'cross_validate', None),
+ 'dim_reduc': ('svm', 'dim_reduc', None),
+ 'norms': ('svm', 'norms', None),
+ 'balance': ('svm', 'balance', None),
+ 'class_weights': ('svm', 'class_weights', None),
+ 'kernel': ('svm', 'kernel', None),
+ 'final_pred': ('svm', 'final_pred', None),
+ 'get_coefs': ('svm', 'get_coefs', None),
+ }
+
+ def validate(self) -> None:
+ """
+ Validate configuration consistency.
+ """
+ if not self.features:
+ raise ValueError("No features specified for extraction.")
+
+ if not self.corpus.paths:
+ raise ValueError("No paths specified for corpus loading.")
+
+ # Validate paths type
+ if not isinstance(self.corpus.paths, list):
+ raise TypeError("Paths in config must be either a list or a glob pattern string.")
+
+ for feat_config in self.features:
+ tei_only = ["lemma", "pos", "met_line", "met_syll"]
+ if feat_config.type in tei_only and self.corpus.format != "tei":
+ raise ValueError(f"{feat_config.type} requires TEI format.")
+ if feat_config.type in ["met_line", "met_syll"] and self.sampling.units not in ["verses"]:
+ raise ValueError(f"{feat_config.type} requires verses units.")
+
+ def save(self, path: str) -> None:
+ with open(path, 'w') as f:
+ json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+
+ @classmethod
+ def from_kwargs(cls, **kwargs) -> 'Config':
+ """
+ Build Config from flat kwargs (backward compatibility).
+
+ This allows the old-style function calls to work:
+ load_corpus(paths, feats="chars", n=3, keep_punct=True)
+ """
+ config_data = {
+ 'corpus': {},
+ 'features': {},
+ 'sampling': {},
+ 'normalization': {},
+ 'svm': {}
+ }
+
+ for kwarg_name, value in kwargs.items():
+ if value is None:
+ continue
+
+ if kwarg_name in cls.KWARGS_MAPPING:
+ section, attr, transform = cls.KWARGS_MAPPING[kwarg_name]
+ final_value = transform(value) if transform else value
+ config_data[section][attr] = final_value
+
+ # Build the config with proper nesting
+ return cls.from_dict(config_data)
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> 'Config':
+ kwargs = {}
+
+ if 'corpus' in data:
+ corpus_data = data['corpus'].copy()
+ # Handle 'paths' key variations
+ if 'paths' not in corpus_data and 'data_paths' in data:
+ corpus_data['paths'] = data['data_paths']
+ kwargs['corpus'] = CorpusConfig.from_dict(corpus_data) if corpus_data else CorpusConfig()
+
+ if 'features' in data:
+ features_data = data['features']
+ if isinstance(features_data, dict):
+ kwargs['features'] = [FeatureConfig.from_dict(features_data)]
+ elif isinstance(features_data, list):
+ if features_data:
+ kwargs['features'] = [FeatureConfig.from_dict(f) for f in features_data]
+ else:
+ kwargs['features'] = []
+
+ if 'sampling' in data and data['sampling']:
+ sampling_data = data['sampling'].copy()
+ # Handle alternative key names
+ if 'sample_size' in sampling_data:
+ sampling_data['size'] = sampling_data.pop('sample_size')
+ if 'samples_random' in sampling_data:
+ sampling_data['random'] = sampling_data.pop('samples_random')
+ if 'sample_step' in sampling_data:
+ sampling_data['step'] = sampling_data.pop('sample_step')
+ if 'sample_units' in sampling_data:
+ sampling_data['units'] = sampling_data.pop('sample_units')
+ kwargs['sampling'] = SamplingConfig.from_dict(sampling_data)
+
+ if 'normalization' in data and data['normalization']:
+ kwargs['normalization'] = NormalizationConfig.from_dict(data['normalization'])
+
+ if 'svm' in data and data['svm']:
+ kwargs['svm'] = SVMConfig.from_dict(data['svm'])
+
+ if 'output_prefix' in data:
+ kwargs['output_prefix'] = data['output_prefix']
+
+ return cls(**kwargs)
+
+ @classmethod
+ def from_json(cls, path: str) -> 'Config':
+ with open(path, 'r') as f:
+ data = json.load(f)
+
+ # Handle flat JSON format (paths at root level)
+ if 'paths' in data and 'corpus' not in data:
+ data['corpus'] = {'paths': data.pop('paths')}
+ if 'format' in data:
+ data['corpus']['format'] = data.pop('format')
+ if 'identify_lang' in data:
+ data['corpus']['identify_lang'] = data.pop('identify_lang')
+
+ return cls.from_dict(data)
+
+
+# Global configuration (optional singleton pattern)
+_current_config: Optional[Config] = None
+
+def get_config() -> Optional[Config]:
+ return _current_config
+
+def set_config(config: Config) -> None:
+ global _current_config
+ _current_config = config
+
+def reset_config() -> None:
+ global _current_config
+ _current_config = None
\ No newline at end of file
diff --git a/superstyl/load.py b/superstyl/load.py
index fe52683e..395b62ae 100644
--- a/superstyl/load.py
+++ b/superstyl/load.py
@@ -4,133 +4,208 @@
import superstyl.preproc.embedding as embed
import tqdm
import pandas
+from typing import Optional, List, Tuple, Union
+from superstyl.config import Config, FeatureConfig, NormalizationConfig
-def load_corpus(data_paths, feat_list=None, feats="words", n=1, k=5000, freqsType="relative", format="txt", sampling=False,
- units="words", size=3000, step=None, max_samples=None, samples_random=False, keep_punct=False, keep_sym=False,
- no_ascii=False,
- identify_lang=False, embedding=False, neighbouring_size=10, culling=0):
+
+def _load_single_feature(
+ myTexts: List[dict],
+ feat_config: FeatureConfig,
+ norm_config: NormalizationConfig,
+ use_provided_feat_list: bool = False,
+) -> Tuple[pandas.DataFrame, List]:
"""
- Main function to load a corpus from a collection of file, and an optional list of features to extract.
- :param data_paths: paths to the source files
- :param feat_list: an optional list of features (as created by load_corpus), default None
- :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'.
- Affixes are inspired by Sapkota et al. 2015, and include space_prefix, space_suffix, prefix, suffix, and,
- if keep_punct, punctuation n-grams. From TEI, pos, lemma, met_line or met_syll can
- be extracted; met_line is the prosodic (stress) annotation of a full verse; met_syll is a char n-gram of prosodic
- annotation
- :param n: n grams lengths (default 1)
- :param k: How many most frequent? The function takes the rank of k (if k is smaller than the total number of features),
- gets its frequencies, and only include features of superior or equal total frequencies.
- :param freqsType: return relative, absolute or binarised frequencies (default: relative)
- :param format: one of txt, xml or tei. /!\ only txt is fully implemented.
- :param sampling: whether to sample the texts, by cutting it into slices of a given length, until the last possible
- slice of this length, which means that often the end of the text will be eliminated (default False)
- :param units: units of length for sampling, one of 'words', 'verses' (default: words). 'verses' is only implemented
- for the 'tei' format
- :param size: the size of the samples (in units)
- :param step: step for sampling with overlap (default is step = size, which means no overlap).
- Reduce for overlapping slices
- :param max_samples: Maximum number of (randomly selected) samples per author/class (default is all)
- :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false)
- :param keep_punct: whether to keep punctuation and caps (default is False)
- :param keep_sym: same as keep_punct, and numbers are kept as well (default is False). /!\ does not
- actually keep symbols
- :param no_ascii: disables conversion to ASCII (default is conversion)
- :param identify_lang: if true, the language of each text will be guessed, using langdetect (default is False)
- :param embedding: optional path to a word2vec embedding in txt format to compute frequencies among a set of
- semantic neighbourgs (i.e., pseudo-paronyms)
- :param neighbouring_size: size of semantic neighbouring in the embedding (as per gensim most_similar,
- with topn=neighbouring_size)
- :param culling percentage value for culling, meaning in what percentage of samples should a feature be present to be retained (default is 0, meaning no culling)
- :return a pandas dataFrame of text metadata and feature frequencies; a global list of features with their frequencies
+ Extract features for a single FeatureConfig.
+ Internal function used by load_corpus.
+
+ Args:
+ use_provided_feat_list: If True and feat_config.feat_list is provided,
+ return that list instead of the computed one. Used for test sets
+ to ensure same features as training set.
"""
-
- if feats in ('lemma', 'pos', 'met_line', 'met_syll') and format != 'tei':
- raise ValueError("lemma, pos, met_line or met_syll are only possible with adequate tei format (@lemma, @pos, @met)")
-
- if feats in ('met_line', 'met_syll') and units != 'lines':
- raise ValueError("met_line or met_syll are only possible with tei format that includes lines and @met")
+ feats = feat_config.type
+ n = feat_config.n
+ k = feat_config.k
+ freqsType = feat_config.freq_type
+ provided_feat_list = feat_config.feat_list
+ embedding = feat_config.embedding
+ neighbouring_size = feat_config.neighbouring_size
+ culling = feat_config.culling
embeddedFreqs = False
if embedding:
print(".......loading embedding.......")
- relFreqs = False # we need absolute freqs as a basis for embedded frequencies
model = embed.load_embeddings(embedding)
embeddedFreqs = True
- freqsType = "absolute" #absolute freqs are required for embedding
+ freqsType = "absolute"
- print(".......loading texts.......")
-
- if sampling:
- myTexts = pipe.docs_to_samples(data_paths, feats=feats, format=format, units=units, size=size, step=step,
- max_samples=max_samples, samples_random=samples_random,
- keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii,
- identify_lang = identify_lang)
+ print(f".......getting features ({feats}, n={n}).......")
- else:
- myTexts = pipe.load_texts(data_paths, feats=feats, format=format, max_samples=max_samples, keep_punct=keep_punct,
- keep_sym=keep_sym, no_ascii=no_ascii, identify_lang=identify_lang)
-
- print(".......getting features.......")
-
- if feat_list is None:
+ if provided_feat_list is None:
feat_list = fex.get_feature_list(myTexts, feats=feats, n=n, freqsType=freqsType)
if k > len(feat_list):
- print("K Limit ignored because the size of the list is lower ({} < {})".format(len(feat_list), k))
+ print(f"K limit ignored ({len(feat_list)} < {k})")
else:
- # and now, cut at around rank k
val = feat_list[k-1][1]
feat_list = [m for m in feat_list if m[1] >= val]
-
+ else:
+ feat_list = provided_feat_list
print(".......getting counts.......")
- my_feats = [m[0] for m in feat_list] # keeping only the features without the frequencies
- myTexts = fex.get_counts(myTexts, feat_list=my_feats, feats=feats, n=n, freqsType=freqsType)
+ my_feats = [m[0] for m in feat_list]
+ # Copy myTexts to avoid mutating original for multi-feature
+ texts_copy = [dict(t) for t in myTexts]
+ texts_copy = fex.get_counts(texts_copy, feat_list=my_feats, feats=feats, n=n, freqsType=freqsType)
+
if embedding:
print(".......embedding counts.......")
- myTexts, my_feats = embed.get_embedded_counts(myTexts, my_feats, model, topn=neighbouring_size)
+ texts_copy, my_feats = embed.get_embedded_counts(texts_copy, my_feats, model, topn=neighbouring_size)
feat_list = [f for f in feat_list if f[0] in my_feats]
- unique_texts = [text["name"] for text in myTexts]
-
if culling > 0:
- print(".......Culling at " + str(culling) + "%.......")
- # Counting in how many sample the feat appear
- feats_doc_freq = fex.get_doc_frequency(myTexts)
- # Now selecting
- my_feats = [f for f in my_feats if (feats_doc_freq[f] / len(myTexts) * 100) > culling]
+ print(f".......Culling at {culling}%.......")
+ feats_doc_freq = fex.get_doc_frequency(texts_copy)
+ my_feats = [f for f in my_feats if (feats_doc_freq[f] / len(texts_copy) * 100) > culling]
feat_list = [f for f in feat_list if f[0] in my_feats]
print(".......feeding data frame.......")
loc = {}
-
- for t in tqdm.tqdm(myTexts):
+ for t in tqdm.tqdm(texts_copy):
text, local_freqs = count_process((t, my_feats), embeddedFreqs=embeddedFreqs)
loc[text["name"]] = local_freqs
- # Saving metadata for later
- metadata = pandas.DataFrame(columns=['author', 'lang'], index=unique_texts, data=
- [[t["aut"], t["lang"]] for t in myTexts])
+ feats_df = pandas.DataFrame.from_dict(loc, columns=list(my_feats), orient="index")
- # Free some space before doing this...
- del myTexts
+ # For test sets: return the provided feat_list unchanged
+ if use_provided_feat_list and provided_feat_list is not None:
+ return feats_df, provided_feat_list
+
+ return feats_df, feat_list
- # frequence based selection
- # WOW, pandas is a great tool, almost as good as using R
- # But confusing as well: boolean selection works on rows by default
- # were elsewhere it works on columns
- # take only rows where the number of values above 0 is superior to two
- # (i.e. appears in at least two texts)
- #feats = feats.loc[:, feats[feats > 0].count() > 2]
- feats = pandas.DataFrame.from_dict(loc, columns=list(my_feats), orient="index")
+def load_corpus(
+ config: Optional[Config] = None,
+ use_provided_feat_list: bool = False,
+ **kwargs
+) -> Tuple[pandas.DataFrame, Union[List, List[List]]]:
+ """
+ Load a corpus and extract features.
+
+ Can be called with:
+ 1. A Config object: load_corpus(config=my_config)
+ 2. Individual parameters (backward compatible):
+ load_corpus(data_paths=paths, feats="chars", n=3)
+
+ Args:
+ config: Configuration object. If None, built from kwargs.
+ use_provided_feat_list: If True and feat_list provided, return it unchanged.
+ Use for test sets to match training features.
+ **kwargs: Individual parameters for backward compatibility.
+ Supported: data_paths, feat_list, feats, n, k, freqsType,
+ format, sampling, units, size, step, max_samples, samples_random,
+ keep_punct, keep_sym, no_ascii, identify_lang, embedding,
+ neighbouring_size, culling
+
+ Returns:
+ - If single feature: (DataFrame, feat_list)
+ - If multiple features: (DataFrame with prefixed columns, list of feat_lists)
+ """
+ # Build config from kwargs if not provided
+ if config is None:
+ config = Config.from_kwargs(**kwargs)
+
+ # Validate configuration
+ config.validate()
+ data_paths = config.corpus.paths
+
+ # Handle string paths (single file or glob pattern)
+ if isinstance(data_paths, str):
+ import glob
+ # If it's a glob pattern, expand it
+ if '*' in data_paths or '?' in data_paths:
+ data_paths = sorted(glob.glob(data_paths))
+ else:
+ # Single file path - wrap in list
+ data_paths = [data_paths]
+
+ # Validate
+ for feat_config in config.features:
+ if feat_config.type in ('lemma', 'pos', 'met_line', 'met_syll') and config.corpus.format != 'tei':
+ raise ValueError(f"{feat_config.type} requires TEI format.")
+ if feat_config.type in ('met_line', 'met_syll') and config.sampling.units != 'verses':
+            raise ValueError(f"{feat_config.type} requires verses units.")
+ data_paths = config.corpus.paths
+
+ # Handle string paths (single file or glob pattern)
+ if isinstance(data_paths, str):
+ import glob
+ # If it's a glob pattern, expand it
+ if '*' in data_paths or '?' in data_paths:
+ data_paths = sorted(glob.glob(data_paths))
+ else:
+ # Single file path - wrap in list
+ data_paths = [data_paths]
- # Free some more
- del loc
+ # Validate
+ for feat_config in config.features:
+ if feat_config.type in ('lemma', 'pos', 'met_line', 'met_syll') and config.corpus.format != 'tei':
+ raise ValueError(f"{feat_config.type} requires TEI format.")
+ if feat_config.type in ('met_line', 'met_syll') and config.sampling.units != 'verses':
+ raise ValueError(f"{feat_config.type} requires verses units.")
- corpus = pandas.concat([metadata, feats], axis=1)
+ # Load texts once
+ print(".......loading texts.......")
- return corpus, feat_list
\ No newline at end of file
+ if config.sampling.enabled:
+ myTexts = pipe.docs_to_samples(
+ data_paths,
+ config=config
+ )
+ else:
+ myTexts = pipe.load_texts(
+ data_paths,
+ config=config
+ )
+
+ unique_texts = [text["name"] for text in myTexts]
+
+ # Build metadata
+ metadata = pandas.DataFrame(
+ columns=['author', 'lang'],
+ index=unique_texts,
+ data=[[t["aut"], t["lang"]] for t in myTexts]
+ )
+
+ # Single feature case
+ if len(config.features) == 1:
+ feat_config = config.features[0]
+ feats_df, feat_list = _load_single_feature(
+ myTexts, feat_config, config.normalization, use_provided_feat_list
+ )
+ corpus = pandas.concat([metadata, feats_df], axis=1)
+ return corpus, feat_list
+
+ # Multiple features case
+ print(f".......extracting {len(config.features)} feature sets.......")
+
+ all_feat_lists = []
+ merged_feats = metadata.copy()
+
+ for i, feat_config in enumerate(config.features):
+ prefix = feat_config.name or f"f{i+1}"
+ print(f".......processing {prefix}.......")
+
+ feats_df, feat_list = _load_single_feature(
+ myTexts, feat_config, config.normalization, use_provided_feat_list
+ )
+
+ # Prefix columns to avoid collisions
+ feats_df = feats_df.rename(columns={col: f"{prefix}_{col}" for col in feats_df.columns})
+
+ merged_feats = pandas.concat([merged_feats, feats_df], axis=1)
+ all_feat_lists.append(feat_list)
+
+ return merged_feats, all_feat_lists
\ No newline at end of file
diff --git a/superstyl/load_from_config.py b/superstyl/load_from_config.py
deleted file mode 100644
index 4bc7cca4..00000000
--- a/superstyl/load_from_config.py
+++ /dev/null
@@ -1,197 +0,0 @@
-import json
-import pandas as pd
-import os
-import glob
-
-from superstyl.load import load_corpus
-
-def load_corpus_from_config(config_path, is_test=False):
- """
- Load a corpus based on a JSON configuration file.
-
- Parameters:
- -----------
- config_path : str
- Path to the JSON configuration file
-
- Returns:
- --------
- tuple: (corpus, feat_list) - Same format as load_corpus function
- If multiple features are defined, returns the merged corpus and the combined feature list
- If only one feature is defined, returns that corpus and its feature list
- """
- # Load configuration
- if not config_path.endswith('.json'):
- raise ValueError(f"Unsupported configuration file format: {config_path}. Only JSON format is supported.")
-
- with open(config_path, 'r') as f:
- config = json.load(f)
-
- # Get corpus paths
-
- if 'paths' in config:
- if isinstance(config['paths'], list):
- paths = []
- for path in config['paths']:
- if '*' in path or '?' in path or '[' in path:
- expanded_paths = glob.glob(path)
- if not expanded_paths:
- print(f"Warning: No files found for pattern '{path}'")
- paths.extend(expanded_paths)
- else:
- paths.append(path)
- elif isinstance(config['paths'], str):
- if '*' in config['paths'] or '?' in config['paths'] or '[' in config['paths']:
- paths = glob.glob(config['paths'])
- if not paths:
- raise ValueError(f"No files found for glob pattern '{config['paths']}'")
- else:
- paths = [config['paths']]
- else:
- raise ValueError("Paths in config must be either a list or a glob pattern string")
- else:
- raise ValueError("No paths provided and no paths found in config")
-
- # Get sampling parameters
- sampling_params = config.get('sampling', {})
-
- # Use the first feature to create the base corpus with sampling
- feature_configs = config.get('features', [])
- if not feature_configs:
- raise ValueError("No features specified in the configuration")
-
- # If there's only one feature, we can simply return the result of load_corpus
- if len(feature_configs) == 1:
- feature_config = feature_configs[0]
- feature_name = feature_config.get('name', "f1")
-
- # Check for feature list file
- feat_list = None
- feat_list_path = feature_config.get('feat_list')
- if feat_list_path:
- if feat_list_path.endswith('.json'):
- with open(feat_list_path, 'r') as f:
- feat_list = json.load(f)
- elif feat_list_path.endswith('.txt'):
- with open(feat_list_path, 'r') as f:
- feat_list = [[feat.strip(), 0] for feat in f.readlines()]
-
- # Set up other parameters
- params = {
- 'feats': feature_config.get('type', 'words'),
- 'n': feature_config.get('n', 1),
- 'k': feature_config.get('k', 5000),
- 'freqsType': feature_config.get('freq_type', 'relative'),
- 'format': config.get('format', 'txt'),
- 'sampling': sampling_params.get('enabled', False),
- 'units': sampling_params.get('units', 'words'),
- 'size': sampling_params.get('sample_size', 3000),
- 'step': sampling_params.get('step', None),
- 'max_samples': sampling_params.get('max_samples', None),
- 'samples_random': sampling_params.get('samples_random', False),
- 'keep_punct': feature_config.get('keep_punct', False),
- 'keep_sym': feature_config.get('keep_sym', False),
- 'no_ascii': feature_config.get('no_ascii', False),
- 'identify_lang': feature_config.get('identify_lang', False),
- 'embedding': feature_config.get('embedding', None),
- 'neighbouring_size': feature_config.get('neighbouring_size', 10),
- 'culling': feature_config.get('culling', 0)
- }
-
- print(f"Loading corpus with {feature_name}...")
- corpus, features = load_corpus(paths, feat_list=feat_list, **params)
-
- return corpus, features
-
- # For multiple features, we need to process each one and merge the results
- corpora = {}
- feature_lists = {}
-
- # Process each feature configuration
- for i, feature_config in enumerate(feature_configs):
- feature_name = feature_config.get('name', f"f{i+1}")
-
- # Check for feature list file
- feat_list = None
- feat_list_path = feature_config.get('feat_list')
- print(feat_list_path)
- if feat_list_path:
- if feat_list_path.endswith('.json'):
- with open(feat_list_path, 'r') as f:
- feat_list = json.load(f)
- elif feat_list_path.endswith('.txt'):
- with open(feat_list_path, 'r') as f:
- feat_list = [[feat.strip(), 0] for feat in f.readlines()]
-
- # Set up other parameters
- params = {
- 'feats': feature_config.get('type', 'words'),
- 'n': feature_config.get('n', 1),
- 'k': feature_config.get('k', 5000),
- 'freqsType': feature_config.get('freq_type', 'relative'),
- 'format': config.get('format', 'txt'),
- 'sampling': sampling_params.get('enabled', False),
- 'units': sampling_params.get('units', 'words'),
- 'size': sampling_params.get('sample_size', 3000),
- 'step': sampling_params.get('step', None),
- 'max_samples': sampling_params.get('max_samples', None),
- 'samples_random': sampling_params.get('samples_random', False),
- 'keep_punct': config.get('keep_punct', False),
- 'keep_sym': config.get('keep_sym', False),
- 'no_ascii': config.get('no_ascii', False),
- 'identify_lang': config.get('identify_lang', False),
- 'embedding': feature_config.get('embedding', None),
- 'neighbouring_size': feature_config.get('neighbouring_size', 10),
- 'culling': feature_config.get('culling', 0)
- }
-
- print(f"Loading {feature_name}...")
-
- corpus, features = load_corpus(paths, feat_list=feat_list, **params)
-
- # Store corpus and features
- corpora[feature_name] = corpus
-
- if feat_list is not None and is_test:
- feature_lists[feature_name] = feat_list
- else:
- feature_lists[feature_name] = features
-
-
- # Create a merged dataset
- print("Creating merged dataset...")
- first_corpus_name = next(iter(corpora))
-
- # Start with metadata from the first corpus
- metadata = corpora[first_corpus_name][['author', 'lang']]
-
- # Create an empty DataFrame for the merged corpus
- merged = pd.DataFrame(index=metadata.index)
-
- # Add metadata
- merged = pd.concat([metadata, merged], axis=1)
-
- # Combine all features with prefixes to avoid name collisions
- all_features = []
-
- # Add features from each corpus
- for name, corpus in corpora.items():
- single_feature = []
-
- feature_cols = [col for col in corpus.columns if col not in ['author', 'lang']]
-
- # Rename columns to avoid duplicates
- renamed_cols = {col: f"{name}_{col}" for col in feature_cols}
- feature_df = corpus[feature_cols].rename(columns=renamed_cols)
-
- # Merge with the main DataFrame
- merged = pd.concat([merged, feature_df], axis=1)
-
- # Add features to the combined list with prefixes
- for feature in feature_lists[name]:
- single_feature.append((feature[0], feature[1]))
-
- all_features.append(single_feature)
- # Return the merged corpus and combined feature list
- return merged, all_features
-
diff --git a/superstyl/preproc/pipe.py b/superstyl/preproc/pipe.py
index 7c675b68..3812a43e 100755
--- a/superstyl/preproc/pipe.py
+++ b/superstyl/preproc/pipe.py
@@ -1,145 +1,53 @@
from lxml import etree
-import regex as re
-import unidecode
import nltk.tokenize
import random
-import langdetect
-import unicodedata
-
-def XML_to_text(path):
- """
- Get main text from xml file
- :param path: path to the file to transform
- :return: a tuple with auts, and string (the text).
- """
-
- myxsl = etree.XML('''
-
+from typing import List, Dict, Optional, Tuple
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+from superstyl.config import Config, NormalizationConfig, SamplingConfig
+from superstyl.preproc.utils import *
+
+
+# ============================================================================
+# Constants and Configuration
+# ============================================================================
+
+XSLT_TEMPLATES = {
+ 'xml_text': '''
+
+
+
+
+
+ ''',
-
-
-
-
-
-
-''')
- myxsl = etree.XSLT(myxsl)
-
- with open(path, 'r') as f:
- my_doc = etree.parse(f)
-
- auts = my_doc.findall("//author")
- auts = [a.text for a in auts]
-
- if not len(auts) == 1:
- print("Error: more or less than one author in" + path)
-
- if len(auts) == 0:
- auts = [None]
-
- if auts == [None]:
- aut = "unknown"
-
- else:
- aut = auts[0]
-
- return aut, re.sub(r"\s+", " ", str(myxsl(my_doc)))
-
-
-def txm_to_units(path, units="lines", feats="words"):
- """
- Extract units from TXM file
- :param path: path to TXM file
- :param units: units to extract ("lines"/"verses" or "words")
- :param feats: features to extract ("words", "lemma", or "pos")
- :return: list of extracted units
- """
- myxsl = etree.XML('''
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-''')
- myxsl = etree.XSLT(myxsl)
-
- with open(path, 'r') as f:
- my_doc = etree.parse(f)
-
- units_tokens = str(myxsl(my_doc, units=etree.XSLT.strparam(units), feats=etree.XSLT.strparam(feats))).splitlines()
- return units_tokens
-
-def tei_to_units(path, feats="words", units="lines"):
-
- if feats in ["met_syll", "met_line"]:
- feats = "met"
- myxsl = etree.XML('''
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ 'tei_units': '''
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -148,286 +56,419 @@ def tei_to_units(path, feats="words", units="lines"):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
-
-
-
- ''')
- myxsl = etree.XSLT(myxsl)
-
- with open(path, 'r') as f:
- my_doc = etree.parse(f)
-
- units_tokens = str(myxsl(my_doc, units=etree.XSLT.strparam(units), feats=etree.XSLT.strparam(feats))).splitlines()
- return units_tokens
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ''',
+
+ 'txm_units': '''
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ '''
+}
+
+
+
+
+class FileLoader(ABC):
+ """Abstract base class for file loaders."""
+
+ @abstractmethod
+ def load(self, path: str, **kwargs) -> Tuple[str, str]:
+ """Load file and return (author, text) tuple."""
+ pass
-def specialXML_to_text(path, format="tei", feats="words"):
- aut = path.split('/')[-1].split("_")[0]
- if format=="tei":
- units_tokens = tei_to_units(path, feats=feats, units="words")
- if format=="txm":
- units_tokens = txm_to_units(path, feats=feats, units="words")
+class TXTLoader(FileLoader):
+ """Loader for plain text files."""
+
+ def load(self, path: str, **kwargs) -> Tuple[str, str]:
+ with open(path, 'r') as f:
+ text = ' '.join(f.readlines())
+
+ author = extract_author_from_path(path)
+ return author, normalize_whitespace(text)
- return aut, re.sub(r"\s+", " ", str(' '.join(units_tokens)))
-def TXT_to_text(path):
- """
- Get main text from xml file
- :param path: path to the file to transform
- :return: a tuple with auts, and string (the text).
- """
+class XMLLoader(FileLoader):
+ """Loader for XML files."""
+
+ def __init__(self):
+ self.xslt = etree.XSLT(etree.XML(XSLT_TEMPLATES['xml_text']))
+
+ def load(self, path: str, **kwargs) -> Tuple[str, str]:
+ with open(path, 'r') as f:
+ doc = etree.parse(f)
+
+ authors = doc.findall("//author")
+ author_texts = [a.text for a in authors]
+
+ if len(author_texts) != 1:
+ print(f"Warning: Expected 1 author in {path}, found {len(author_texts)}")
+ author = author_texts[0] if author_texts else "unknown"
+ else:
+ author = author_texts[0]
+
+ text = str(self.xslt(doc))
+ return author, normalize_whitespace(text)
- with open(path, 'r') as f:
- #txt = [line.rstrip() for line in f if line.rstrip() != '']
- txt = f.readlines()
- # get author from filename (string before first _)
- aut = path.split('/')[-1].split("_")[0]
+class XMLUnitLoader(ABC):
+ """Base class for XML loaders that extract units."""
+
+ def __init__(self, template_name: str):
+ self.xslt = etree.XSLT(etree.XML(XSLT_TEMPLATES[template_name]))
+
+ def extract_units(self, path: str, units: str = "verses",
+ feats: str = "words") -> List[str]:
+ """Extract units from XML file."""
+ with open(path, 'r') as f:
+ doc = etree.parse(f)
+
+ params = self._get_xslt_params(units, feats)
+ result = str(self.xslt(doc, **params))
+ return result.splitlines()
+
+ @abstractmethod
+ def _get_xslt_params(self, units: str, feats: str) -> Dict:
+ """Get XSLT parameters for transformation."""
+ pass
- return aut, re.sub(r"\s+", " ", str(' '.join(txt)))
+class TEIUnitLoader(XMLUnitLoader):
+ """Loader for TEI files with unit extraction."""
+
+ def __init__(self):
+ super().__init__('tei_units')
+
+ def _get_xslt_params(self, units: str, feats: str) -> Dict:
+ feats_param = "met" if feats in ["met_syll", "met_line"] else feats
+ return {
+ 'units': etree.XSLT.strparam(units),
+ 'feats': etree.XSLT.strparam(feats_param)
+ }
+
+ def load(self, path: str, feats: str = "words", **kwargs) -> Tuple[str, str]:
+ """Load TEI file and return (author, text) tuple."""
+ author = extract_author_from_path(path)
+ units = self.extract_units(path, units="words", feats=feats)
+ text = normalize_whitespace(' '.join(units))
+ return author, text
-def detect_lang(string):
- """
- Get the language from a string
- :param string: a string, duh
- :return: the language
- """
- return langdetect.detect(string) # , k = 3)
+class TXMUnitLoader(XMLUnitLoader):
+ """Loader for TXM files with unit extraction."""
+
+ def __init__(self):
+ super().__init__('txm_units')
+
+ def _get_xslt_params(self, units: str, feats: str) -> Dict:
+ return {
+ 'units': etree.XSLT.strparam(units),
+ 'feats': etree.XSLT.strparam(feats)
+ }
+
+ def load(self, path: str, feats: str = "words", **kwargs) -> Tuple[str, str]:
+ """Load TXM file and return (author, text) tuple."""
+ author = extract_author_from_path(path)
+ units = self.extract_units(path, units="words", feats=feats)
+ text = normalize_whitespace(' '.join(units))
+ return author, text
-def normalise(text, keep_punct=False, keep_sym=False, no_ascii=False):
- """
- Function to normalise an input string. By defaults, it removes all but word chars, remove accents,
- and normalise space, and then normalise unicode.
- :param keep_punct: if true, in addition, also keeps Punctuation and case distinction
- :param keep_sym: if true, same as keep_punct, but keeps also N?umbers, Symbols, Marks, such as combining diacritics,
- as well as Private use characters, and no Unidecode is applied
- :param no_ascii: disables conversion to ascii
- """
- # Remove all but word chars, remove accents, and normalise space
- # and then normalise unicode
+# Loader factory
+LOADERS = {
+ 'txt': TXTLoader(),
+ 'xml': XMLLoader(),
+ 'tei': TEIUnitLoader(),
+ 'txm': TXMUnitLoader()
+}
- if keep_sym:
- out = re.sub(r"[^\p{L}\p{P}\p{N}\p{S}\p{M}\p{Co}]+", " ", text)
- else:
- if keep_punct:
- # Keep punctuation (and diacritics for now)
- out = re.sub(r"[^\p{L}\p{P}\p{M}]+", " ", text)
+def XML_to_text(path: str) -> Tuple[str, str]:
+ """Legacy function for XML loading."""
+ return LOADERS['xml'].load(path)
- else:
- #out = re.sub(r"[\W0-9]+", " ", text.lower())
- out = re.sub(r"[^\p{L}\p{M}]+", " ", text.lower())
- if no_ascii is not True:
- out = unidecode.unidecode(out)
+def TXT_to_text(path: str) -> Tuple[str, str]:
+ """Legacy function for TXT loading."""
+ return LOADERS['txt'].load(path)
- # Normalise unicode
- out = unicodedata.normalize("NFC", out)
- out = re.sub(r"\s+", " ", out).strip()
+def tei_to_units(path: str, feats: str = "words", units: str = "verses") -> List[str]:
+ """Legacy function for TEI unit extraction."""
+ return LOADERS['tei'].extract_units(path, units, feats)
- return out
-def max_sampling(myTexts, max_samples=10):
- """
- Select a random number of samples, equal to max_samples, for authors or classes that have more than max_samples
- :param myTexts: the input myTexts object
- :param max_samples: the maximum number of samples for any class
- :return: a myTexts object, with the resulting selection of samples
- """
- autsCounts = dict()
- for text in myTexts:
- if text['aut'] not in autsCounts.keys():
- autsCounts[text['aut']] = 1
+def txm_to_units(path: str, units: str = "verses", feats: str = "words") -> List[str]:
+ """Legacy function for TXM unit extraction."""
+ return LOADERS['txm'].extract_units(path, units, feats)
- else:
- autsCounts[text['aut']] += 1
- for autCount in autsCounts.items():
- if autCount[1] > max_samples:
- # get random selection
- toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]]
- toBeSelected = random.sample(toBeSelected, k=max_samples)
- # Great, now remove all texts from this author from our samples
- myTexts = [text for text in myTexts if text['aut'] != autCount[0]]
- # and now concat
- myTexts = myTexts + toBeSelected
+def specialXML_to_text(path: str, format: str = "tei", feats: str = "words") -> Tuple[str, str]:
+ """Legacy function for special XML loading."""
+ return LOADERS[format].load(path, feats=feats)
- return myTexts
+# ============================================================================
+# Sampling Functions
+# ============================================================================
-def load_texts(paths, identify_lang=False, feats="words", format="txt", keep_punct=False, keep_sym=False, no_ascii=False,
- max_samples=None):
+class Sampler:
"""
- Loads a collection of documents into a 'myTexts' object for further processing.
- TODO: a proper class
- :param paths: path to docs
- :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'.
- :param identify_lang: whether or not try to identify lang (default: False)
- :param format: format of the source files (implemented values: txt [default], xml)
- :param keep_punct: whether or not to keep punctuation and caps.
- :param keep_sym: whether or not to keep punctuation, caps, letter variants and numbers (no unidecode).
- :param no_ascii: disables conversion to ascii
- :param max_samples: the maximum number of samples for any class
- :return: a myTexts object
+ Handles text sampling operations.
"""
-
- myTexts = []
-
- for path in paths:
- name = path.split('/')[-1]
-
- if format=='xml':
- aut, text = XML_to_text(path)
-
- if format in ('tei', 'txm'):
- aut, text = specialXML_to_text(path, format=format, feats=feats)
-
+
+ @staticmethod
+ def extract_tokens(path: str, config: Config=Config()) -> List[str]:
+ """
+ Extract tokens from a document based on format and units.
+ """
+ feats=config.features[0].type
+
+ if config.sampling.units == "words" and config.corpus.format == "txt":
+ author, text = LOADERS['txt'].load(path)
+ text = normalise(text, config.normalization)
+ return nltk.tokenize.wordpunct_tokenize(text)
+
+ elif config.corpus.format == "tei":
+ return LOADERS['tei'].extract_units(path, config.corpus.units, feats)
+
+ elif config.sampling.units == "verses" and config.corpus.format == "txm":
+ return LOADERS['txm'].extract_units(path, config.sampling.units, feats)
+
else:
- aut, text = TXT_to_text(path)
-
- if identify_lang:
- lang = detect_lang(text)
+ raise ValueError(f"Unsupported combination: units={config.sampling.units}, format={config.corpus.format}")
+
+ @staticmethod
+ def create_samples(tokens: List[str], sampling_config: SamplingConfig=SamplingConfig()) -> List[Dict]:
+ """
+ Create samples from tokens.
+ """
+ step = sampling_config.step if sampling_config.step is not None else sampling_config.size
+
+ samples = []
+
+ if sampling_config.random:
+ for k in range(sampling_config.max_samples):
+ samples.append({
+ "start": f"{k}s",
+ "end": f"{k}e",
+ "text": list(random.choices(tokens, k=sampling_config.size))
+ })
else:
- lang = "NA"
-
- # Normalise text once and for all
- text = normalise(text, keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii)
-
- myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})
-
- if max_samples is not None:
- myTexts = max_sampling(myTexts, max_samples=max_samples)
-
- return myTexts
-
-
-# Load and split in samples of length -n- a collection of files
-def get_samples(path, size, step=None, samples_random=False, max_samples=10,
- units="words", format="txt", feats="words", keep_punct=False, keep_sym=False, no_ascii=False):
+ current = 0
+ while current + sampling_config.size <= len(tokens):
+ samples.append({
+ "start": current,
+ "end": current + sampling_config.size,
+ "text": list(tokens[current:current + sampling_config.size])
+ })
+ current += step
+
+ return samples
+
+ @classmethod
+ def get_samples(cls, path: str, config: Config=Config()) -> List[Dict]:
+ """
+ Extract samples from a document.
+
+ Args:
+ path: Path to document
+            config: Configuration object controlling extraction and sampling
+
+ Returns:
+ List of sample dictionaries
+ """
+ max_samples = config.sampling.max_samples or 10
+ config.sampling.validate()
+
+ tokens = cls.extract_tokens(path, config)
+ return cls.create_samples(tokens, config.sampling)
+
+
+def max_sampling(documents: List[Dict], max_samples: int = 10) -> List[Dict]:
"""
- Take samples of n words or verses from a document, and then parse it.
- TODO: ONLY IMPLEMENTED FOR NOW: XML/TEI, TXT and verses or words as units
- :param path : path to file
- :param size: sample size
- :param step: size of the step when sampling successively (determines overlap) default is the same
- as sample size (i.e. no overlap)
- :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false)
- :param max_samples: maximum number of samples per author/clas
- :param units: the units to use, one of "words" or "verses"
- :param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED)
- :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'.
+ Randomly select up to max_samples per author/class.
+
+ Args:
+        documents: List of document dicts (each with an 'aut' key)
+ max_samples: Maximum samples per author
+
+ Returns:
+ Filtered list of documents
"""
+ # Count documents per author
+ author_counts = {}
+ for doc in documents:
+ author_counts[doc['aut']] = author_counts.get(doc['aut'], 0) + 1
+
+ # Filter authors with too many samples
+ result = []
+ for author, count in author_counts.items():
+ author_docs = [d for d in documents if d['aut'] == author]
- if samples_random and step is not None:
- raise ValueError("random sampling is not compatible with continuous sampling (remove either the step or the samples_random argument")
-
- if samples_random and not max_samples:
- raise ValueError("random sampling needs a fixed number of samples (use the max_samples argument)")
-
- if step is None:
- step = size
-
- if units == "words" and format == "txt":
- my_doc = TXT_to_text(path)
- text = normalise(my_doc[1], keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii)
- units_tokens = nltk.tokenize.wordpunct_tokenize(text)
-
- #Kept only for retrocompatibility with Psysché
- if units == "verses" and format == "txm":
- units_tokens = txm_to_units(path, units=units)
-
- if format == "tei":
- units_tokens = tei_to_units(path, units=units, feats=feats)
+ if count > max_samples:
+ result.extend(random.sample(author_docs, k=max_samples))
+ else:
+ result.extend(author_docs)
+
+ return result
- # and now generating output
- samples = []
- if samples_random:
- for k in range(max_samples):
- samples.append({"start": str(k)+'s', "end": str(k)+'e', "text": list(random.choices(units_tokens, k=size))})
+# ============================================================================
+# Main Loading Functions
+# ============================================================================
- else:
- current = 0
- while current + size <= len(units_tokens):
- samples.append({"start": current, "end": current + size, "text": list(units_tokens[current:(current + size)])})
- current = current + step
+def load_texts(paths: List[str], config: Config=Config()) -> List[Dict]:
+ """
+ Load a collection of documents.
+
+ Args:
+ paths: List of file paths
+        config: Configuration object controlling loading and normalization
+
+ Returns:
+ List of document dictionaries
+ """
+ loader = LOADERS.get(config.corpus.format)
+ if not loader:
+ raise ValueError(f"Unsupported format: {config.corpus.format}")
+
+ documents = []
+ feats=config.features[0].type
- return samples
+ for path in paths:
+ name = path.split('/')[-1]
+ author, text = loader.load(path, feats=feats)
+
+ lang = detect_lang(text) if config.corpus.identify_lang else "NA"
+
+ # Normalize text
+ text = normalise(text, config.normalization)
+
+ documents.append({
+ "name": name,
+ "aut": author,
+ "text": text,
+ "lang": lang
+ })
+
+ if config.sampling.max_samples is not None:
+ documents = max_sampling(documents, config.sampling.max_samples)
+
+ return documents
-def docs_to_samples(paths, size, step=None, units="words", samples_random=False, format="txt", feats="words", keep_punct=False,
- keep_sym=False, no_ascii=False, max_samples=None, identify_lang=False):
+def docs_to_samples(paths: List[str], config: Config=Config()) -> List[Dict]:
"""
- Loads a collection of documents into a 'myTexts' object for further processing BUT with samples !
- :param paths: path to docs
- :param size: sample size
- :param step: size of the step when sampling successively (determines overlap) default is the same
- as sample size (i.e. no overlap)
- :param units: the units to use, one of "words" or "verses"
- :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false)
- :param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED)
- :param keep_punct: whether to keep punctuation and caps.
- :param max_samples: maximum number of samples per author/class.
- :param identify_lang: whether to try to identify lang (default: False)
- :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'.
- :return: a myTexts object
+ Load documents with sampling.
+
+ Args:
+ paths: List of file paths
+        config: Configuration object controlling sampling and normalization
+
+ Returns:
+ List of sample dictionaries
"""
- myTexts = []
- for path in paths:
- aut = path.split('/')[-1].split('_')[0]
- if identify_lang:
- if format == 'xml':
- aut, text = XML_to_text(path)
-
- else:
- aut, text = TXT_to_text(path)
+ loader = LOADERS.get(config.corpus.format)
+ if not loader:
+ raise ValueError(f"Unsupported format: {config.corpus.format}")
+
+ all_samples = []
+ feats=config.features[0].type
+ for path in paths:
+ author = extract_author_from_path(path)
+
+ # Detect language if needed
+ if config.corpus.identify_lang:
+ _, text = loader.load(path, feats=feats)
lang = detect_lang(text)
-
else:
lang = 'NA'
-
- samples = get_samples(path, size=size, step=step, samples_random=samples_random, max_samples=max_samples,
- units=units, format=format, feats=feats,
- keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii)
-
+
+ # Get samples
+ samples = Sampler.get_samples(path, config)
+
+ # Create sample documents
for sample in samples:
- name = path.split('/')[-1] + '_' + str(sample["start"]) + "-" + str(sample["end"])
- text = normalise(' '.join(sample["text"]), keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii)
- myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})
-
- if max_samples is not None:
- myTexts = max_sampling(myTexts, max_samples=max_samples)
-
- return myTexts
+ name = f"{path.split('/')[-1]}_{sample['start']}-{sample['end']}"
+ text = normalise(' '.join(sample['text']), config.normalization)
+
+ all_samples.append({
+ "name": name,
+ "aut": author,
+ "text": text,
+ "lang": lang
+ })
+
+ if config.sampling.max_samples is not None:
+ all_samples = max_sampling(all_samples, config.sampling.max_samples)
+
+ return all_samples
\ No newline at end of file
diff --git a/superstyl/preproc/select.py b/superstyl/preproc/select.py
index d21efb47..fb887b5a 100755
--- a/superstyl/preproc/select.py
+++ b/superstyl/preproc/select.py
@@ -2,151 +2,135 @@
import csv
import random
import json
+from typing import Optional, List, Tuple
-# TODO: make same modifs for the no split
-def read_clean_split(path, metadata_path=None, excludes_path=None, savesplit=None, lang=None):
+def _load_metadata(path: str, metadata_path: Optional[str],
+ excludes_path: Optional[str], lang: Optional[str]) -> Tuple[Optional[pandas.DataFrame], Optional[List[str]]]:
"""
- Function to read a csv, clean it, and then split it in train and dev,
- either randomly or according to a preexisting selection
- :param path: path to csv file
- :param metadata_path: path to metadata file
- :param excludes_path: path to file with list of excludes
- :param presplit: path to file with preexisting split (optional)
- :param savesplit: path to save split (optional)
- :return: saves to disk
+ Load metadata and exclusion list if needed.
"""
-
- trainf = open(path.split(".")[0] + "_train.csv", 'w')
- validf = open(path.split(".")[0] + "_valid.csv", 'w')
-
- selection = {'train': [], 'valid': [], 'elim': []}
-
- # Do we need to create metadata ?
+ metadata = None
+ excludes = None
+
if metadata_path is None and (excludes_path is not None or lang is not None):
- metadata = pandas.read_csv(path)
- metadata = pandas.DataFrame(index=metadata.loc[:, "Unnamed: 0"], columns=['lang'], data=list(metadata.loc[:, "lang"]))
-
- if metadata_path is not None:
- metadata = pandas.read_csv(metadata_path)
- metadata = pandas.DataFrame(index=metadata.loc[:, "id"], columns=['lang'], data=list(metadata.loc[:, "true"]))
-
+ data = pandas.read_csv(path)
+ metadata = pandas.DataFrame(
+ index=data.loc[:, "Unnamed: 0"],
+ columns=['lang'],
+ data=list(data.loc[:, "lang"])
+ )
+ elif metadata_path is not None:
+ data = pandas.read_csv(metadata_path)
+ metadata = pandas.DataFrame(
+ index=data.loc[:, "id"],
+ columns=['lang'],
+ data=list(data.loc[:, "true"])
+ )
+
if excludes_path is not None:
- excludes = pandas.read_csv(excludes_path)
- excludes = list(excludes.iloc[:, 0])
-
- with open(path, "r") as f:
- head = f.readline()
- trainf.write(head)
- validf.write(head)
-
- # and prepare to write csv lines to them
- train = csv.writer(trainf)
- valid = csv.writer(validf)
-
- print("....evaluating each text.....")
-
- reader = csv.reader(f, delimiter=",")
-
- for line in reader:
-
- # checks
- if lang is not None:
- # First check if good language
- if not metadata.loc[line[0], "lang"] == lang:
- selection['elim'].append(line[0])
- print("not in: " + lang + " " + line[0])
- # if not, eliminate it, and go to next line
- continue
-
- if excludes_path is not None:
- # then check if to exclude
- if line[0] in excludes:
- selection['elim'].append(line[0])
- print("Is a Wilhelmus instance! : " + line[0])
- # then eliminate it, and go to next line
- continue
-
- # Now that we have only the good lines, proceed to split
-
- # 10% for dev
- if random.randint(1, 10) == 1:
- selection['valid'].append(line[0])
- valid.writerow(line)
-
- # 90% for train
- else:
- selection['train'].append(line[0])
- train.writerow(line)
+ excludes_data = pandas.read_csv(excludes_path)
+ excludes = list(excludes_data.iloc[:, 0])
+
+ return metadata, excludes
- trainf.close()
- validf.close()
- with open(savesplit, "w") as out:
- out.write(json.dumps(selection))
-# TODO: merge this one and the previous ?
-def read_clean(path, metadata_path=None, excludes_path=None, savesplit=None, lang=None):
+def _should_exclude(line_id: str, metadata: Optional[pandas.DataFrame],
+ excludes: Optional[List[str]], lang: Optional[str]) -> Tuple[bool, Optional[str]]:
"""
- Function to read a csv, clean it.
- :param path: path to csv file
- :param metadata_path: path to metadata file
- :param excludes_path: path to file with list of excludes
+ Determine if a line should be excluded.
+ """
+ if lang is not None and metadata is not None:
+ try:
+ if metadata.loc[line_id, "lang"] != lang:
+ return True, f"not in: {lang} {line_id}"
+ except KeyError:
+ pass
+
+ if excludes is not None and line_id in excludes:
+ return True, f"Is a Wilhelmus instance! : {line_id}"
+
+ return False, None
+
+
+def read_clean(path: str, metadata_path: Optional[str] = None,
+ excludes_path: Optional[str] = None,
+ savesplit: Optional[str] = None,
+ lang: Optional[str] = None,
+ split: bool = False,
+ split_ratio: float = 0.1) -> None:
+ """
+ Read a CSV, clean it, and optionally split it into train and validation sets.
+
+ :param path: path to CSV file
+ :param metadata_path: path to metadata file (optional)
+ :param excludes_path: path to file with list of excludes (optional)
+ :param savesplit: path to save selection JSON (optional)
+ :param lang: only include texts in this language (optional)
+ :param split: if True, split into train/valid sets (default False)
+ :param split_ratio: ratio for validation set when split=True (default 0.1 = 10%)
:return: saves to disk
"""
-
- trainf = open(path.split(".")[0] + "_selected.csv", 'w')
-
+ metadata, excludes = _load_metadata(path, metadata_path, excludes_path, lang)
+
+ base_path = path.rsplit(".", 1)[0]
+
+ # Initialize selection tracking
selection = {'train': [], 'elim': []}
-
- # Do we need to create metadata ?
- if metadata_path is None and (excludes_path is not None or lang is not None):
- metadata = pandas.read_csv(path)
- metadata = pandas.DataFrame(index=metadata.loc[:, "Unnamed: 0"], columns=['lang'], data=list(metadata.loc[:, "lang"]))
-
- if metadata_path is not None:
- metadata = pandas.read_csv(metadata_path)
- metadata = pandas.DataFrame(index=metadata.loc[:, "id"], columns=['lang'], data=list(metadata.loc[:, "true"]))
-
- if excludes_path is not None:
- excludes = pandas.read_csv(excludes_path)
- excludes = list(excludes.iloc[:, 0])
-
- with open(path, "r") as f:
- head = f.readline()
- trainf.write(head)
-
- # and prepare to write csv lines to them
- train = csv.writer(trainf)
-
- print("....evaluating each text.....")
-
- reader = csv.reader(f, delimiter=",")
-
- for line in reader:
-
- # checks
- if lang is not None:
- # First check if good language
- if not metadata.loc[line[0], "lang"] == lang:
- selection['elim'].append(line[0])
- print("not in: " + lang + " " + line[0])
- # if not, eliminate it, and go to next line
- continue
-
- if excludes_path is not None:
- # then check if to exclude
- if line[0] in excludes:
- selection['elim'].append(line[0])
- print("Is a Wilhelmus instance! : " + line[0])
- # then eliminate it, and go to next line
+ if split:
+ selection['valid'] = []
+
+ # Open output files
+ if split:
+ train_path = f"{base_path}_train.csv"
+ valid_path = f"{base_path}_valid.csv"
+ trainf = open(train_path, 'w')
+ validf = open(valid_path, 'w')
+ else:
+ train_path = f"{base_path}_selected.csv"
+ trainf = open(train_path, 'w')
+ validf = None
+
+ try:
+ with open(path, "r") as f:
+ header = f.readline()
+ trainf.write(header)
+ if validf:
+ validf.write(header)
+
+ train_writer = csv.writer(trainf)
+ valid_writer = csv.writer(validf) if validf else None
+
+ print("....evaluating each text.....")
+ reader = csv.reader(f, delimiter=",")
+
+ for line in reader:
+ line_id = line[0]
+
+ # Check exclusion
+ is_excluded, reason = _should_exclude(line_id, metadata, excludes, lang)
+ if is_excluded:
+ selection['elim'].append(line_id)
+ if reason:
+ print(reason)
continue
-
- # Now that we have only the good lines, proceed to write
- selection['train'].append(line[0])
- train.writerow(line)
-
- trainf.close()
- with open(savesplit, "w") as out:
- out.write(json.dumps(selection))
+
+ # Route to train or valid
+ if split and random.random() < split_ratio:
+ selection['valid'].append(line_id)
+ valid_writer.writerow(line)
+ else:
+ selection['train'].append(line_id)
+ train_writer.writerow(line)
+
+ finally:
+ trainf.close()
+ if validf:
+ validf.close()
+
+ # Save selection
+ if savesplit:
+ with open(savesplit, "w") as out:
+ out.write(json.dumps(selection))
def apply_selection(path, presplit_path):
"""
diff --git a/superstyl/preproc/utils.py b/superstyl/preproc/utils.py
new file mode 100644
index 00000000..5de411d1
--- /dev/null
+++ b/superstyl/preproc/utils.py
@@ -0,0 +1,50 @@
+import unidecode
+import langdetect
+import unicodedata
+import regex as re
+
+from superstyl.config import NormalizationConfig
+
+
+
+def extract_author_from_path(path: str) -> str:
+ """Extract author from file path (before first underscore)."""
+ return path.split('/')[-1].split("_")[0]
+
+
+def normalize_whitespace(text: str) -> str:
+ """Normalize whitespace in text."""
+ return re.sub(r"\s+", " ", text).strip()
+
+
+def detect_lang(text: str) -> str:
+ """Detect language of text using langdetect."""
+ return langdetect.detect(text)
+
+
+def normalise(text: str, norm_config : NormalizationConfig=NormalizationConfig()) -> str:
+ """
+ Normalize input text according to specified options.
+
+ Args:
+ text: Input text to normalize
+        norm_config: NormalizationConfig whose flags control the output:
+            keep_punct (keep punctuation and case), keep_sym (also keep
+            numbers, symbols and marks), no_ascii (skip ASCII conversion)
+
+ Returns:
+ Normalized text
+ """
+ if norm_config.keep_sym:
+ out = re.sub(r"[^\p{L}\p{P}\p{N}\p{S}\p{M}\p{Co}]+", " ", text)
+ else:
+ if norm_config.keep_punct:
+ out = re.sub(r"[^\p{L}\p{P}\p{M}]+", " ", text)
+ else:
+ out = re.sub(r"[^\p{L}\p{M}]+", " ", text.lower())
+
+ if not norm_config.no_ascii:
+ out = unidecode.unidecode(out)
+
+ out = unicodedata.normalize("NFC", out)
+ return normalize_whitespace(out)
\ No newline at end of file
diff --git a/superstyl/svm.py b/superstyl/svm.py
index abceb823..74ce13d5 100755
--- a/superstyl/svm.py
+++ b/superstyl/svm.py
@@ -12,45 +12,53 @@
import imblearn.combine as comb
import imblearn.pipeline as imbp
from collections import Counter
+from typing import Optional, Dict, Any
+from superstyl.config import Config
-def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, balance=None, class_weights=False,
- kernel="LinearSVC",
- final_pred=False, get_coefs=False):
+
+def train_svm(
+ train: pandas.DataFrame,
+ test: Optional[pandas.DataFrame] = None,
+ config: Optional[Config] = None,
+ **kwargs
+) -> Dict[str, Any]:
"""
- Function to train svm
- :param train: train data... (in panda dataframe)
- :param test: test data (itou)
- :param cross_validate: whether to perform cross validation (possible values: leave-one-out, k-fold
- and group-k-fold) if group_k-fold is chosen, each source file will be considered a group, so this is only relevant
- if sampling was performed and more than one file per class was provided
- :param k: k parameter for k-fold cross validation
- :param dim_reduc: dimensionality reduction of input data. Implemented values are pca and som.
- :param norms: perform normalisations, i.e. z-scores and L2 (default True)
- :param balance: up/downsampling strategy to use in imbalanced datasets
- :param class_weights: adjust class weights to balance imbalanced datasets, with weights inversely proportional to class
- frequencies in the input data as n_samples / (n_classes * np.bincount(y))
- :param kernel: kernel for SVM
- :param final_pred: do the final predictions?
- :param get_coefs, if true, writes to disk (coefficients.csv) and plots the most important coefficients for each class
- :return: prints the scores, and then returns a dictionary containing the pipeline with a fitted svm model,
- and, if computed, the classification_report, confusion_matrix, list of misattributions, and final_predictions.
+ Train SVM model for stylometric analysis.
+
+ Can be called with:
+ 1. A Config object: train_svm(train, test, config=my_config)
+ 2. Individual parameters (backward compatible):
+ train_svm(train, test, cross_validate="k-fold", k=10)
+
+ Args:
+ train: Training data (pandas DataFrame)
+ test: Test data (optional)
+ config: Configuration object. If None, built from kwargs.
+ **kwargs: Individual parameters for backward compatibility.
+ Supported: cross_validate, k, dim_reduc, norms, balance,
+ class_weights, kernel, final_pred, get_coefs
+
+ Returns:
+ Dictionary containing: pipeline, and optionally confusion_matrix,
+ classification_report, misattributions, final_predictions, coefficients
"""
- results = {}
+ # Build config from kwargs if not provided
+ if config is None:
+ config = Config.from_kwargs(**kwargs)
+
+ # Extract SVM parameters from config
+ cross_validate = config.svm.cross_validate
+ k = config.svm.k
+ dim_reduc = config.svm.dim_reduc
+ norms = config.svm.norms
+ balance = config.svm.balance
+ class_weights = config.svm.class_weights
+ kernel = config.svm.kernel
+ final_pred = config.svm.final_pred
+ get_coefs = config.svm.get_coefs
- valid_cross_validate_options = {None, "leave-one-out", "k-fold", 'group-k-fold'}
- valid_dim_reduc_options = {None, 'pca'}
- valid_balance_options = {None, 'downsampling', 'upsampling', 'Tomek', 'SMOTE', 'SMOTETomek'}
- # Validate parameters
- if cross_validate not in valid_cross_validate_options:
- raise ValueError(
- f"Invalid cross-validation option: '{cross_validate}'. Valid options are {valid_cross_validate_options}.")
- if dim_reduc not in valid_dim_reduc_options:
- raise ValueError(
- f"Invalid dimensionality reduction option: '{dim_reduc}'. Valid options are {valid_dim_reduc_options}.")
- # Validate 'balance' parameter
- if balance not in valid_balance_options:
- raise ValueError(f"Invalid balance option: '{balance}'. Valid options are {valid_balance_options}.")
+ results = {}
print(".......... Formatting data ........")
# Save the classes
@@ -72,28 +80,19 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
if dim_reduc == 'pca':
print(".......... using PCA ........")
- estimators.append(('dim_reduc', decomp.PCA())) # chosen with default
- # which is: n_components = min(n_samples, n_features)
+ estimators.append(('dim_reduc', decomp.PCA()))
+
if norms:
- # Z-scores
print(".......... using normalisations ........")
estimators.append(('scaler', preproc.StandardScaler()))
- # NB: j'utilise le built-in
- # normalisation L2
- # cf. https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Normalizer.html#sklearn.preprocessing.Normalizer
estimators.append(('normalizer', preproc.Normalizer()))
if balance is not None:
-
print(".......... implementing strategy to solve imbalance in data ........")
if balance == 'downsampling':
estimators.append(('sampling', under.RandomUnderSampler(random_state=42, replacement=False)))
- # if balance == 'ENN':
- # enn = under.EditedNearestNeighbours()
- # train, classes = enn.fit_resample(train, classes)
-
if balance == 'Tomek':
estimators.append(('sampling', under.TomekLinks()))
@@ -109,42 +108,35 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
# In case we have to temper with the n_neighbors, we print a warning message to the user
# (might be written more clearly, but we want a short message, right?)
if 0 < n_neighbors >= min_class_size:
- print(
- f"Warning: Adjusting n_neighbors for SMOTE to {n_neighbors} due to small class size.")
+ print(f"Warning: Adjusting n_neighbors for SMOTE to {n_neighbors} due to small class size.")
if n_neighbors == 0:
- print(
- f"Warning: at least one class only has a single individual; cannot apply SMOTE(Tomek) due to small class size.")
+ print("Warning: at least one class only has a single individual; cannot apply SMOTE(Tomek).")
else:
-
if balance == 'SMOTE':
estimators.append(('sampling', over.SMOTE(k_neighbors=n_neighbors, random_state=42)))
-
elif balance == 'SMOTETomek':
- estimators.append(('sampling', comb.SMOTETomek(random_state=42, smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42))))
+ estimators.append(('sampling', comb.SMOTETomek(
+ random_state=42,
+ smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42)
+ )))
print(".......... choosing SVM ........")
if kernel == "LinearSVC":
- # try a faster one
estimators.append(('model', sk.LinearSVC(class_weight=cw, dual="auto")))
- # classif = sk.LinearSVC()
-
else:
estimators.append(('model', sk.SVC(kernel=kernel, class_weight=cw)))
- # classif = sk.SVC(kernel=kernel)
print(".......... Creating pipeline with steps ........")
print(estimators)
if 'sampling' in [k[0] for k in estimators]:
pipe = imbp.Pipeline(estimators)
-
else:
pipe = skp.Pipeline(estimators)
- # Now, doing leave one out validation or training single SVM with train / test split
-
+ # Cross validation or train/test split
if cross_validate is not None:
works = None
if cross_validate == 'leave-one-out':
@@ -152,34 +144,28 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
if cross_validate == 'k-fold':
if k == 0:
- k = 10 # set default
-
+ k = 10
myCV = skmodel.KFold(n_splits=k)
if cross_validate == 'group-k-fold':
- # Get the groups as the different source texts
works = ["_".join(t.split("_")[:-1]) for t in train.index.values]
if k == 0:
k = len(set(works))
-
myCV = skmodel.GroupKFold(n_splits=k)
- print(".......... " + cross_validate + " cross validation will be performed ........")
- print(".......... using " + str(myCV.get_n_splits(train)) + " samples or groups........")
-
- # Will need to
- # 1. train a model
- # 2. get prediction
- # 3. compute score: precision, recall, F1 for all categories
+ print(f".......... {cross_validate} cross validation will be performed ........")
+ print(f".......... using {myCV.get_n_splits(train)} samples or groups........")
preds = skmodel.cross_val_predict(pipe, train, classes, cv=myCV, verbose=1, n_jobs=-1, groups=works)
# and now, leave one out evaluation (very small redundancy here, one line that could be stored elsewhere)
unique_labels = list(set(classes))
- results["confusion_matrix"] = pandas.DataFrame(metrics.confusion_matrix(classes, preds, labels=unique_labels),
- index=['true:{:}'.format(x) for x in unique_labels],
- columns=['pred:{:}'.format(x) for x in unique_labels])
+ results["confusion_matrix"] = pandas.DataFrame(
+ metrics.confusion_matrix(classes, preds, labels=unique_labels),
+ index=[f'true:{x}' for x in unique_labels],
+ columns=[f'pred:{x}' for x in unique_labels]
+ )
report = metrics.classification_report(classes, preds)
print(report)
@@ -203,14 +189,16 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
else:
pipe.fit(train, classes)
preds = pipe.predict(test)
+
if not final_pred:
# and evaluate
unique_labels = list(set(classes + classes_test))
results["confusion_matrix"] = pandas.DataFrame(
metrics.confusion_matrix(classes_test, preds, labels=unique_labels),
- index=['true:{:}'.format(x) for x in unique_labels],
- columns=['pred:{:}'.format(x) for x in unique_labels])
+ index=[f'true:{x}' for x in unique_labels],
+ columns=[f'pred:{x}' for x in unique_labels]
+ )
report = metrics.classification_report(classes, preds)
print(report)
@@ -222,45 +210,50 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
columns=["id", "True", "Pred"]
).set_index('id')
- # AND NOW, we need to evaluate or create the final predictions
+ # Final predictions
if final_pred:
# Get the decision function too
myclasses = pipe.classes_
decs = pipe.decision_function(test)
dists = {}
+
if len(pipe.classes_) == 2:
results["final_predictions"] = pandas.DataFrame(
- data={**{'filename': preds_index, 'author': list(preds)}, 'Decision function': decs})
-
+ data={**{'filename': preds_index, 'author': list(preds)}, 'Decision function': decs}
+ )
else:
for myclass in enumerate(myclasses):
dists[myclass[1]] = [d[myclass[0]] for d in decs]
- results["final_predictions"] = pandas.DataFrame(
- data={**{'filename': preds_index, 'author': list(preds)}, **dists})
+ results["final_predictions"] = pandas.DataFrame(
+ data={**{'filename': preds_index, 'author': list(preds)}, **dists}
+ )
if get_coefs:
if kernel != "LinearSVC":
print(".......... COEFS ARE ONLY IMPLEMENTED FOR linearSVC ........")
-
else:
# For “one-vs-rest” LinearSVC the attributes coef_ and intercept_ have the shape (n_classes, n_features) and
# (n_classes,) respectively.
# Each row of the coefficients corresponds to one of the n_classes “one-vs-rest” classifiers and similar for the
# intercepts, in the order of the “one” class.
if len(pipe.classes_) == 2:
- results["coefficients"] = pandas.DataFrame(pipe.named_steps['model'].coef_,
- index=[pipe.classes_[0]],
- columns=train.columns)
-
- plot_coefficients(pipe.named_steps['model'].coef_[0], train.columns,
- pipe.classes_[0] + " versus " + pipe.classes_[1])
-
+ results["coefficients"] = pandas.DataFrame(
+ pipe.named_steps['model'].coef_,
+ index=[pipe.classes_[0]],
+ columns=train.columns
+ )
+ plot_coefficients(
+ pipe.named_steps['model'].coef_[0],
+ train.columns,
+ f"{pipe.classes_[0]} versus {pipe.classes_[1]}"
+ )
else:
- results["coefficients"] = pandas.DataFrame(pipe.named_steps['model'].coef_,
- index=pipe.classes_,
- columns=train.columns)
-
+ results["coefficients"] = pandas.DataFrame(
+ pipe.named_steps['model'].coef_,
+ index=pipe.classes_,
+ columns=train.columns
+ )
for i in range(len(pipe.classes_)):
plot_coefficients(pipe.named_steps['model'].coef_[i], train.columns, pipe.classes_[i])
@@ -272,23 +265,28 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True,
# https://aneesha.medium.com/visualising-top-features-in-linear-svm-with-scikit-learn-and-matplotlib-3454ab18a14d
def plot_coefficients(coefs, feature_names, current_class, top_features=10):
- plt.rcParams.update({'font.size': 30}) # increase font size
+ """Plot the most important coefficients for a class."""
+ plt.rcParams.update({'font.size': 30})
top_positive_coefficients = np.argsort(coefs)[-top_features:]
top_negative_coefficients = np.argsort(coefs)[:top_features]
top_coefficients = np.hstack([top_negative_coefficients, top_positive_coefficients])
- # create plot
+
plt.figure(figsize=(15, 8))
colors = ['red' if c < 0 else 'blue' for c in coefs[top_coefficients]]
plt.bar(np.arange(2 * top_features), coefs[top_coefficients], color=colors)
feature_names = np.array(feature_names)
- plt.xticks(np.arange(0, 2 * top_features), feature_names[top_coefficients], rotation=60, ha='right',
- rotation_mode='anchor')
- plt.title("Coefficients for " + current_class)
- plt.savefig('coefs_' + current_class + '.png', bbox_inches='tight')
-
-
-
-def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)"):
+ plt.xticks(
+ np.arange(0, 2 * top_features),
+ feature_names[top_coefficients],
+ rotation=60,
+ ha='right',
+ rotation_mode='anchor'
+ )
+ plt.title(f"Coefficients for {current_class}")
+ plt.savefig(f'coefs_{current_class}.png', bbox_inches='tight')
+
+
+def plot_rolling(final_predictions, smoothing=3, xlab="Index (segment center)"):
"""
Plots the rolling stylometry results as lines of decision function values over the text.
@@ -306,6 +304,7 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)"
# Extract the segment center from the filename
my_final_predictions = final_predictions.copy() # to avoid modifying in place
segment_centers = []
+
for fname in my_final_predictions['filename']:
parts = fname.split('_')[-1].split('-')
start = int(parts[0])
@@ -314,7 +313,6 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)"
segment_centers.append(center)
my_final_predictions['segment_center'] = segment_centers
-
my_final_predictions['filename'] = [fname.split('_')[1] for fname in my_final_predictions['filename']]
# Identify candidate columns
@@ -329,21 +327,23 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)"
# Apply smoothing if requested
if smoothing and smoothing > 0:
for col in candidate_cols:
- fpreds_work[col] = fpreds_work[col].rolling(window=smoothing, center=True, min_periods=1).mean()
+ fpreds_work[col] = fpreds_work[col].rolling(
+ window=smoothing, center=True, min_periods=1
+ ).mean()
# Plotting
plt.figure(figsize=(24, 12))
for col in candidate_cols:
plt.plot(fpreds_work['segment_center'], fpreds_work[col], label=col, linewidth=2)
- plt.title('Rolling Stylometry Decision Functions Over ' + work)
+ plt.title(f'Rolling Stylometry Decision Functions Over {work}')
plt.xlabel(xlab)
plt.ylabel('Decision Function Value')
- plt.ylim(min(-2, min(fpreds_work[candidate_cols].min()) - 0.2),
- max(1, max(fpreds_work[candidate_cols].max())) + 0.2)
+ plt.ylim(
+ min(-2, min(fpreds_work[candidate_cols].min()) - 0.2),
+ max(1, max(fpreds_work[candidate_cols].max())) + 0.2
+ )
plt.legend(title='Candidate Authors', fontsize="small")
plt.grid(True)
plt.tight_layout()
- #plt.show()
- plt.savefig('rolling_'+ work + '.png', bbox_inches='tight')
-
+ plt.savefig(f'rolling_{work}.png', bbox_inches='tight')
\ No newline at end of file
diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py
index cc37116a..1756c0b3 100644
--- a/tests/test_error_handling.py
+++ b/tests/test_error_handling.py
@@ -1,7 +1,8 @@
import unittest
import superstyl.load
import superstyl.preproc.features_extract
-from superstyl.load_from_config import load_corpus_from_config
+from superstyl.load import load_corpus
+from superstyl.config import Config
import os
import tempfile
import json
@@ -33,7 +34,7 @@ def test_load_corpus_lemma_requires_tei(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- self.test_paths,
+ data_paths=self.test_paths,
feats="lemma",
format="txt"
)
@@ -48,7 +49,7 @@ def test_load_corpus_pos_requires_tei(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- self.test_paths,
+ data_paths=self.test_paths,
feats="pos",
format="txt"
)
@@ -63,7 +64,7 @@ def test_load_corpus_met_line_requires_tei(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- self.test_paths,
+ data_paths=self.test_paths,
feats="met_line",
format="txt"
)
@@ -78,7 +79,7 @@ def test_load_corpus_met_syll_requires_tei(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- self.test_paths,
+ data_paths=self.test_paths,
feats="met_syll",
format="txt"
)
@@ -86,8 +87,8 @@ def test_load_corpus_met_syll_requires_tei(self):
self.assertIn("met_syll", str(context.exception))
self.assertIn("tei", str(context.exception).lower())
- def test_load_corpus_met_line_requires_lines_unit(self):
- # SCENARIO: met_line requires units='lines'
+ def test_load_corpus_met_line_requires_verses_unit(self):
+ # SCENARIO: met_line requires units='verses'
# GIVEN: Attempting to use met_line with units='words'
# Create a dummy TEI file for this test
@@ -98,17 +99,18 @@ def test_load_corpus_met_line_requires_lines_unit(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- [tei_path],
+ data_paths=[tei_path],
feats="met_line",
format="tei",
+ sampling=True,
units="words" # Wrong unit type
)
self.assertIn("met_line", str(context.exception))
- self.assertIn("lines", str(context.exception))
+ self.assertIn("verses", str(context.exception))
- def test_load_corpus_met_syll_requires_lines_unit(self):
- # SCENARIO: met_syll requires units='lines'
+ def test_load_corpus_met_syll_requires_verses_unit(self):
+ # SCENARIO: met_syll requires units='verses'
# GIVEN: Attempting to use met_syll with units='words'
# Create a dummy TEI file for this test
@@ -119,14 +121,15 @@ def test_load_corpus_met_syll_requires_lines_unit(self):
# WHEN/THEN: Should raise ValueError
with self.assertRaises(ValueError) as context:
superstyl.load.load_corpus(
- [tei_path],
+ data_paths=[tei_path],
feats="met_syll",
format="tei",
+ sampling=True,
units="words" # Wrong unit type
)
self.assertIn("met_syll", str(context.exception))
- self.assertIn("lines", str(context.exception))
+ self.assertIn("verses", str(context.exception))
# =========================================================================
# Tests pour features_extract.py - ValueError pour paramètres invalides
@@ -248,14 +251,16 @@ def test_load_from_config_with_json_feature_list(self):
# Create config
config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt"
+ },
"features": [
{
"name": "test_feature",
"type": "words",
"n": 1,
- "feat_list": feature_list_path # JSON feature list
+ "feat_list_path": feature_list_path # JSON feature list path
}
]
}
@@ -265,7 +270,8 @@ def test_load_from_config_with_json_feature_list(self):
json.dump(config, f)
# WHEN: Loading corpus from config
- corpus, features = load_corpus_from_config(config_path)
+ config = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config)
# THEN: Should load successfully with JSON feature list
self.assertIsNotNone(corpus)
@@ -283,20 +289,22 @@ def test_load_from_config_test_mode_uses_feat_list(self):
# Create config with multiple features (triggers is_test logic)
config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt"
+ },
"features": [
{
"name": "feat1",
"type": "words",
"n": 1,
- "feat_list": feature_list_path
+ "feat_list_path": feature_list_path
},
{
"name": "feat2",
"type": "chars",
"n": 2,
- "feat_list": feature_list_path
+ "feat_list_path": feature_list_path
}
]
}
@@ -306,7 +314,8 @@ def test_load_from_config_test_mode_uses_feat_list(self):
json.dump(config, f)
# WHEN: Loading corpus from config
- corpus, features = load_corpus_from_config(config_path, is_test=True)
+ config_obj = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config_obj, use_provided_feat_list=True)
# THEN: Should use the provided feature list
self.assertIsNotNone(corpus)
diff --git a/tests/test_load_corpus.py b/tests/test_load_corpus.py
index 076aa199..bea27e35 100644
--- a/tests/test_load_corpus.py
+++ b/tests/test_load_corpus.py
@@ -6,6 +6,7 @@
import superstyl.preproc.embedding
import superstyl.preproc.select
import superstyl.preproc.text_count
+from superstyl.config import NormalizationConfig, Config
import os
import glob
@@ -19,7 +20,7 @@ class Main(unittest.TestCase):
def test_load_corpus(self):
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths)
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths)
# THEN
expected_feats = [('this', 2/12), ('is', 2/12), ('the', 2/12), ('text', 2/12), ('voici', 1/12),
('le', 1/12), ('texte', 1/12), ('also', 1/12)]
@@ -37,7 +38,7 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, culling=50)
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, culling=50)
# THEN
expected_feats = [('this', 2 / 12), ('is', 2 / 12), ('the', 2 / 12), ('text', 2 / 12)]
expected_corpus = {
@@ -51,7 +52,7 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feat_list=[('the', 0)], feats="chars", n=3, k=5000, freqsType="absolute",
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feat_list=[('the', 0)], feats="chars", n=3, k=5000, freqsType="absolute",
format="txt", keep_punct=False, keep_sym=False, identify_lang=True)
# THEN
@@ -65,7 +66,7 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feats="words", n=1,
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="words", n=1,
sampling=True, units="words", size=2, step=None,
keep_punct=True, keep_sym=False)
@@ -127,13 +128,13 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, k=4)
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, k=4)
# THEN
expected_feats = [('this', 2 / 12), ('is', 2 / 12), ('the', 2 / 12), ('text', 2 / 12)]
self.assertEqual(feats, expected_feats)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feats="chars", n=3, format="txt", keep_punct=True,
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="chars", n=3, format="txt", keep_punct=True,
freqsType="absolute")
# THEN
@@ -186,7 +187,7 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feats="chars", n=3, format="txt", keep_punct=True,
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="chars", n=3, format="txt", keep_punct=True,
freqsType="binary")
# THEN
@@ -242,7 +243,7 @@ def test_load_corpus(self):
self.assertEqual(corpus.to_dict(), expected_corpus)
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feats="affixes", n=3, format="txt", keep_punct=True)
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="affixes", n=3, format="txt", keep_punct=True)
# THEN
expected_feats = [('_te', 3/51), ('tex', 3/51), ('ext', 2/51), ('is_', 3/51), ('Thi', 2/51), ('his', 2/51),
@@ -290,7 +291,7 @@ def test_load_corpus(self):
# Now, test embedding
# WHEN
- corpus, feats = superstyl.load.load_corpus(self.paths, feats="words", n=1, format="txt",
+ corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="words", n=1, format="txt",
embedding=THIS_DIR+"/embed/test_embedding.wv.txt",
neighbouring_size=1)
# THEN
@@ -312,8 +313,9 @@ def test_load_texts_txt(self):
# SCENARIO: from paths to txt, get myTexts object, i.e., a list of dictionaries
# # for each text or samples, with metadata and the text itself
# WHEN
- results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", keep_punct=False,
- keep_sym=False, max_samples=None)
+ config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=False,
+ keep_sym=False, max_samples=None)
+ results = superstyl.preproc.pipe.load_texts(self.paths, config)
# THEN
expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'voici le texte', 'lang': 'NA'},
{'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'this is the text', 'lang': 'NA'},
@@ -323,14 +325,16 @@ def test_load_texts_txt(self):
self.assertEqual(results, expected)
# WHEN
- results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", keep_punct=False,
- keep_sym=False, max_samples=1)
+ config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=False,
+ keep_sym=False, max_samples=1)
+ results = superstyl.preproc.pipe.load_texts(self.paths, config)
# THEN
self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 1)
# WHEN
- results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", keep_punct=True,
- keep_sym=False, max_samples=None)
+ config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=True,
+ keep_sym=False, max_samples=None)
+ results = superstyl.preproc.pipe.load_texts(self.paths, config)
# THEN
expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'Voici le texte!', 'lang': 'NA'},
{'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'This is the text!', 'lang': 'NA'},
@@ -339,8 +343,9 @@ def test_load_texts_txt(self):
self.assertEqual(results, expected)
# WHEN
- results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt",
- keep_sym=True, max_samples=None)
+ config = Config.from_kwargs(identify_lang=False, format="txt",
+ keep_sym=True, max_samples=None)
+ results = superstyl.preproc.pipe.load_texts(self.paths, config)
# THEN
expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'Voici le texte!', 'lang': 'NA'},
{'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'This is the text!', 'lang': 'NA'},
@@ -349,8 +354,9 @@ def test_load_texts_txt(self):
self.assertEqual(results, expected)
# WHEN
- results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=True, format="txt", keep_punct=True,
- keep_sym=False, max_samples=None)
+ config = Config.from_kwargs(identify_lang=True, format="txt", keep_punct=True,
+ keep_sym=False, max_samples=None)
+ results = superstyl.preproc.pipe.load_texts(self.paths, config)
# THEN
# Just testing that a lang is predicted, not if it is ok or not
self.assertEqual(len([text for text in results if text["lang"] != 'NA']), 3)
@@ -359,8 +365,9 @@ def test_load_texts_txt(self):
def test_docs_to_samples(self):
# WHEN
- results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None, units="words",
- format="txt", keep_punct=False, keep_sym=False, max_samples=None)
+ config = Config.from_kwargs(identify_lang=False, size=2, step=None, units="words",
+ format="txt", keep_punct=False, keep_sym=False, max_samples=None)
+ results = superstyl.preproc.pipe.docs_to_samples(self.paths, config)
# THEN
expected = [{'name': 'Dupont_Letter1.txt_0-2', 'aut': 'Dupont', 'text': 'voici le', 'lang': 'NA'},
{'name': 'Smith_Letter1.txt_0-2', 'aut': 'Smith', 'text': 'this is', 'lang': 'NA'},
@@ -370,10 +377,11 @@ def test_docs_to_samples(self):
self.assertEqual(results, expected)
# WHEN
- results = superstyl.preproc.pipe.docs_to_samples(sorted(self.paths), identify_lang=False, size=2, step=1,
- units="words", format="txt", keep_punct=True,
- keep_sym=True,
- max_samples=None)
+ config = Config.from_kwargs(identify_lang=False, size=2, step=1,
+ units="words", format="txt", keep_punct=True,
+ keep_sym=True,
+ max_samples=None)
+ results = superstyl.preproc.pipe.docs_to_samples(sorted(self.paths), config)
# THEN
expected = [{'name': 'Dupont_Letter1.txt_0-2', 'aut': 'Dupont', 'text': 'Voici le', 'lang': 'NA'},
@@ -396,37 +404,43 @@ def test_docs_to_samples(self):
self.assertEqual(results, expected)
# WHEN
- results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=True, size=2, step=None,
- units="words", format="txt", keep_punct=False,
- keep_sym=False,
- max_samples=None)
+ config = Config.from_kwargs(identify_lang=True, size=2, step=None,
+ units="words", format="txt", keep_punct=False,
+ keep_sym=False,
+ max_samples=None)
+ results = superstyl.preproc.pipe.docs_to_samples(self.paths, config)
# THEN
self.assertEqual(len([text for text in results if text["lang"] != 'NA']), 5)
# WHEN
- results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None,
- units="words", format="txt", keep_punct=False,
- keep_sym=False,
- max_samples=1)
+ config = Config.from_kwargs(identify_lang=False, size=2, step=None,
+ units="words", format="txt", keep_punct=False,
+ keep_sym=False,
+ max_samples=1)
+ results = superstyl.preproc.pipe.docs_to_samples(self.paths, config)
# THEN
self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 1)
# TODO: this is just minimal testing for random sampling
# WHEN
- results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None,
- units="words",
- format="txt", keep_punct=False, keep_sym=False,
- max_samples=5, samples_random=True)
+ config = Config.from_kwargs(identify_lang=False, size=2, step=None,
+ units="words",
+ format="txt", keep_punct=False, keep_sym=False,
+ max_samples=5, samples_random=True)
+ results = superstyl.preproc.pipe.docs_to_samples(self.paths, config)
# THEN
self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 5)
# and now tests that error are raised when parameters combinations are not consistent
# WHEN/THEN
- self.assertRaises(ValueError, superstyl.preproc.pipe.docs_to_samples, self.paths, size=2, step=1, units="words",
- format="txt", max_samples=5, samples_random=True)
- self.assertRaises(ValueError, superstyl.preproc.pipe.docs_to_samples, self.paths, size=2, units="words",
- format="txt", max_samples=None,
- samples_random=True)
+ with self.assertRaises(ValueError):
+ Config.from_kwargs(size=2, step=1, units="words",
+ format="txt", max_samples=5, samples_random=True)
+
+ with self.assertRaises(ValueError):
+ Config.from_kwargs(size=2, units="words",
+ format="txt", max_samples=None,
+ samples_random=True)
# TODO: test other loading formats with sampling, that are not txt (and decide on their implementation)
@@ -553,11 +567,6 @@ def test_get_counts(self):
self.assertEqual(results, expected)
- # TODO: test count_process
-
- # TODO: test features_select
- # TODO: test select
-
class DataLoading(unittest.TestCase):
@@ -569,27 +578,32 @@ def test_normalise(self):
# SCENARIO
# GIVEN
text = " Hello, Mr. 𓀁, how are §§ you; doing? ſõ ❡"
+ norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = False)
# WHEN
- results = superstyl.preproc.pipe.normalise(text)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_default = "hello mr how are you doing s o"
self.assertEqual(results, expected_default)
# WHEN
- results = superstyl.preproc.pipe.normalise(text, no_ascii=True)
+ norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = True)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_default = "hello mr 𓀁 how are you doing ſ õ"
self.assertEqual(results, expected_default)
# WHEN
- results = superstyl.preproc.pipe.normalise(text, keep_punct=True)
+ norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = False)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_keeppunct = "Hello, Mr. , how are SSSS you; doing? s o"
# WHEN
- results = superstyl.preproc.pipe.normalise(text, keep_punct=True, no_ascii=True)
+ norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = True)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_keeppunct = "Hello, Mr. 𓀁, how are §§ you; doing? ſ õ"
self.assertEqual(results, expected_keeppunct)
# WHEN
- results = superstyl.preproc.pipe.normalise(text, keep_sym=True)
+ norm_conf = NormalizationConfig(keep_punct = False, keep_sym = True, no_ascii = False)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_keepsym = "Hello, Mr. 𓀁, how are §§ you; doing? ſ\uf217õ ❡"
self.assertEqual(results, expected_keepsym)
@@ -598,7 +612,8 @@ def test_normalise(self):
# GIVEN
text = 'Coucou 😅'
# WHEN
- results = superstyl.preproc.pipe.normalise(text, keep_sym=True)
+ norm_conf = NormalizationConfig(keep_punct = False, keep_sym = True, no_ascii = False)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_keepsym = 'Coucou 😅'
self.assertEqual(results, expected_keepsym)
@@ -606,16 +621,18 @@ def test_normalise(self):
# gives: 'Coucou 😵 💫'
# because of the way NFC normalisation is handled probably
- # Test for Armenian
+ # Test for Armenian
# GIVEN
text = " քան զսակաւս ։ Ահա նշանագրեցի"
# WHEN
- results = superstyl.preproc.pipe.normalise(text, no_ascii=True)
+ norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = True)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_default = "քան զսակաւս ահա նշանագրեցի"
self.assertEqual(results, expected_default)
# WHEN
- results = superstyl.preproc.pipe.normalise(text, keep_punct=True, no_ascii=True)
+ norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = True)
+ results = superstyl.preproc.pipe.normalise(text, norm_conf)
# THEN
expected_keeppunct = "քան զսակաւս ։ Ահա նշանագրեցի"
self.assertEqual(results, expected_keeppunct)
diff --git a/tests/test_load_from_config.py b/tests/test_load_from_config.py
index 678230ca..633c7688 100644
--- a/tests/test_load_from_config.py
+++ b/tests/test_load_from_config.py
@@ -6,8 +6,8 @@
import sys
import glob
-from superstyl.load_from_config import load_corpus_from_config
-
+from superstyl.load import load_corpus
+from superstyl.config import Config
# Add parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -25,8 +25,11 @@ def setUp(self):
# Create a test configuration
self.config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ "identify_lang": False
+ },
"sampling": {
"enabled": False
},
@@ -48,8 +51,10 @@ def setUp(self):
# Create a single feature configuration for testing
self.single_config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ },
"features": [
{
"name": "words",
@@ -77,9 +82,8 @@ def test_load_corpus_from_json_config_multiple_features(self):
# GIVEN: Config file with multiple feature specifications
# WHEN: Loading corpus from config
- corpus, features = load_corpus_from_config(
- config_path=self.json_config_path
- )
+ json_config = Config.from_json(self.json_config_path)
+ corpus, features = load_corpus(config=json_config)
# THEN: Corpus and features are loaded correctly
self.assertIsInstance(corpus, pd.DataFrame)
@@ -97,9 +101,8 @@ def test_load_corpus_single_feature(self):
# GIVEN: Config file with a single feature specification
# WHEN: Loading corpus from config
- corpus, features = load_corpus_from_config(
- config_path=self.single_config_path
- )
+ config = Config.from_json(self.single_config_path)
+ corpus, features = load_corpus(config=config)
# THEN: Corpus and features are loaded correctly without prefix
self.assertIsInstance(corpus, pd.DataFrame)
@@ -115,16 +118,15 @@ def test_load_corpus_with_sampling(self):
# GIVEN: Config with sampling enabled and sample size defined
sampling_config = self.config.copy()
sampling_config["sampling"]["enabled"] = True
- sampling_config["sampling"]["sample_size"] = 2
+ sampling_config["sampling"]["size"] = 2
sampling_config_path = os.path.join(self.temp_dir.name, "sampling_config.json")
with open(sampling_config_path, 'w') as f:
json.dump(sampling_config, f)
# WHEN: Loading corpus from config with sampling
- corpus, features = load_corpus_from_config(
- config_path=sampling_config_path
- )
+ config = Config.from_json(sampling_config_path)
+ corpus, features = load_corpus(config=config)
# THEN: Samples are created and file names contain segment info
first_corpus_index = corpus.index[0]
@@ -141,16 +143,15 @@ def test_load_corpus_with_feature_list(self):
# Update config to use feature list
feature_list_config = self.single_config.copy()
- feature_list_config["features"][0]["feat_list"] = feature_list_path
+ feature_list_config["features"][0]["feat_list_path"] = feature_list_path
config_path = os.path.join(self.temp_dir.name, "feature_list_config.json")
with open(config_path, 'w') as f:
json.dump(feature_list_config, f)
# WHEN: Loading corpus from config with feature list
- corpus, features = load_corpus_from_config(
- config_path=config_path
- )
+ config = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config)
# THEN: Only features from the predefined list are used
feature_words = [f[0] for f in features]
@@ -169,8 +170,8 @@ def test_invalid_config_format(self):
f.write("invalid: format")
# WHEN/THEN: Loading corpus from invalid config raises ValueError
- with self.assertRaises(ValueError):
- load_corpus_from_config(invalid_path)
+ with self.assertRaises(Exception): # Will be json.JSONDecodeError
+ Config.from_json(invalid_path)
def test_feature_list_path_not_found(self):
"""Test handling of a non-existent feature list path"""
@@ -178,7 +179,7 @@ def test_feature_list_path_not_found(self):
nonexistent_path = os.path.join(self.temp_dir.name, "nonexistent.json")
feature_list_config = self.single_config.copy()
- feature_list_config["features"][0]["feat_list"] = nonexistent_path
+ feature_list_config["features"][0]["feat_list_path"] = nonexistent_path
config_path = os.path.join(self.temp_dir.name, "missing_feature_list_config.json")
with open(config_path, 'w') as f:
@@ -186,30 +187,38 @@ def test_feature_list_path_not_found(self):
# Should raise FileNotFoundError when trying to load non-existent feature list
with self.assertRaises(FileNotFoundError):
- load_corpus_from_config(config_path)
+ Config.from_json(config_path)
def test_missing_features_in_config(self):
"""Test handling of config with no features specified"""
# Create config without features key
no_features_config = {
- "paths": self.test_paths,
- "format": "txt"
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt"
+ }
}
config_path = os.path.join(self.temp_dir.name, "no_features_config.json")
with open(config_path, 'w') as f:
json.dump(no_features_config, f)
- # Should raise ValueError when no features are specified
- with self.assertRaises(ValueError):
- load_corpus_from_config(config_path)
+ # Should use default feature when no features are specified
+ config = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config)
+
+ # Should work with default features
+ self.assertIsNotNone(corpus)
+ self.assertIsNotNone(features)
def test_empty_features_list(self):
"""Test handling of config with empty features list"""
# Create config with empty features list
empty_features_config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ },
"features": []
}
@@ -219,13 +228,16 @@ def test_empty_features_list(self):
# Should raise ValueError when features list is empty
with self.assertRaises(ValueError):
- load_corpus_from_config(config_path)
+ config = Config.from_json(config_path)
+ config.validate()
def test_missing_paths_in_config(self):
"""Test handling of config with missing paths"""
# Create config without paths key
no_paths_config = {
- "format": "txt",
+ "corpus": {
+ "format": "txt",
+ },
"features": [
{
"type": "words",
@@ -239,24 +251,9 @@ def test_missing_paths_in_config(self):
with open(config_path, 'w') as f:
json.dump(no_paths_config, f)
- # Should raise ValueError when no paths are specified
- with self.assertRaises(ValueError):
- load_corpus_from_config(config_path)
-
- def test_string_path_in_config(self):
- """Test handling of config with a string path instead of list"""
- # Create config with a string path
- string_path_config = self.single_config.copy()
- string_path_config["paths"] = self.test_paths[0] # Single string path
-
- config_path = os.path.join(self.temp_dir.name, "string_path_config.json")
- with open(config_path, 'w') as f:
- json.dump(string_path_config, f)
-
- # Should handle string path
- corpus, features = load_corpus_from_config(config_path)
- self.assertIsInstance(corpus, pd.DataFrame)
- self.assertGreater(len(features), 0)
+ # Config should load but paths will be empty list
+ config = Config.from_json(config_path)
+ self.assertEqual(config.corpus.paths, [])
def test_feature_with_txt_list(self):
"""Test loading a feature list from a txt file"""
@@ -267,14 +264,15 @@ def test_feature_with_txt_list(self):
# Create config that uses txt feature list
txt_list_config = self.single_config.copy()
- txt_list_config["features"][0]["feat_list"] = feature_list_path
+ txt_list_config["features"][0]["feat_list_path"] = feature_list_path
config_path = os.path.join(self.temp_dir.name, "txt_list_config.json")
with open(config_path, 'w') as f:
json.dump(txt_list_config, f)
# Should load corpus with feature list from txt
- corpus, features = load_corpus_from_config(config_path)
+ config = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config)
self.assertIsInstance(corpus, pd.DataFrame)
self.assertGreater(len(features), 0)
@@ -282,19 +280,23 @@ def test_all_optional_parameters(self):
"""Test with a config that includes all optional parameters"""
# Create config with all optional parameters
full_config = {
- "paths": self.test_paths,
- "format": "txt",
- "keep_punct": True,
- "keep_sym": True,
- "no_ascii": True,
- "identify_lang": True,
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ "identify_lang": True,
+ },
+ "normalization": {
+ "keep_punct": True,
+ "keep_sym": True,
+ "no_ascii": True,
+ },
"sampling": {
"enabled": True,
"units": "words",
- "sample_size": 2,
- "sample_step": None, # Set to None when sample_random is True
+ "size": 2,
+ "step": None,
"max_samples": 3,
- "sample_random": True
+ "random": True
},
"features": [
{
@@ -303,10 +305,6 @@ def test_all_optional_parameters(self):
"n": 1,
"k": 100,
"freq_type": "relative",
- "keep_punct": True,
- "keep_sym": True,
- "no_ascii": True,
- "identify_lang": True,
"embedding": None,
"neighbouring_size": 5,
"culling": 0
@@ -319,7 +317,8 @@ def test_all_optional_parameters(self):
json.dump(full_config, f)
# Should load corpus with all parameters set
- corpus, features = load_corpus_from_config(config_path)
+ config = Config.from_json(config_path)
+ corpus, features = load_corpus(config=config)
self.assertIsInstance(corpus, pd.DataFrame)
self.assertGreater(len(features), 0)
@@ -340,13 +339,15 @@ def test_feature_list_loading(self):
# Test the JSON feature list loading
json_config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ },
"features": [
{
"type": "words",
"n": 1,
- "feat_list": json_feature_path
+ "feat_list_path": json_feature_path
}
]
}
@@ -355,21 +356,21 @@ def test_feature_list_loading(self):
with open(json_config_path, 'w') as f:
json.dump(json_config, f)
- # Use direct debugging output to confirm code path is being exercised
- print(f"\nTesting JSON feature list: {json_feature_path}")
- print(f"JSON feature list content: {json_feature_list}")
-
- corpus_json, features_json = load_corpus_from_config(json_config_path)
+ # Load and test JSON config
+ json_config_obj = Config.from_json(json_config_path)
+ corpus_json, features_json = load_corpus(config=json_config_obj)
# Now test the TXT feature list loading
txt_config = {
- "paths": self.test_paths,
- "format": "txt",
+ "corpus": {
+ "paths": self.test_paths,
+ "format": "txt",
+ },
"features": [
{
"type": "words",
"n": 1,
- "feat_list": txt_feature_path
+ "feat_list_path": txt_feature_path
}
]
}
@@ -378,12 +379,9 @@ def test_feature_list_loading(self):
with open(txt_config_path, 'w') as f:
json.dump(txt_config, f)
- # Use direct debugging output to confirm code path is being exercised
- print(f"\nTesting TXT feature list: {txt_feature_path}")
- with open(txt_feature_path, 'r') as f:
- print(f"TXT feature list content: {f.read()}")
-
- corpus_txt, features_txt = load_corpus_from_config(txt_config_path)
+ # Load and test TXT config
+ txt_config_obj = Config.from_json(txt_config_path)
+ corpus_txt, features_txt = load_corpus(config=txt_config_obj)
# Basic verification
self.assertIsInstance(corpus_json, pd.DataFrame)
@@ -391,20 +389,29 @@ def test_feature_list_loading(self):
def test_invalid_paths_type(self):
"""Test handling of config with invalid paths type (neither list nor string)"""
- # Create config with invalid paths type (integer)
- invalid_paths_config = self.single_config.copy()
- invalid_paths_config["paths"] = 123 # Not a list or string
-
- config_path = os.path.join(self.temp_dir.name, "invalid_paths_config.json")
- with open(config_path, 'w') as f:
- json.dump(invalid_paths_config, f)
+ # For this test, we need to test at the load_corpus level since
+ # Config.from_dict will accept any type for paths
+
+ # Create a config dict with integer paths (will pass Config creation)
+ invalid_config_dict = {
+ "corpus": {
+ "paths": 123, # Invalid type
+ "format": "txt"
+ },
+ "features": [
+ {
+ "type": "words",
+ "n": 1
+ }
+ ]
+ }
- # Should raise ValueError for invalid paths type
- with self.assertRaises(ValueError) as context:
- load_corpus_from_config(config_path)
+ # Create Config object (this will work)
+ config = Config.from_dict(invalid_config_dict)
- # Verify the error message
- self.assertIn("Paths in config must be either a list or a glob pattern string", str(context.exception))
+ # But load_corpus should fail when trying to iterate paths
+ with self.assertRaises(TypeError):
+ load_corpus(config=config)
if __name__ == "__main__":
diff --git a/tests/test_select.py b/tests/test_select.py
new file mode 100644
index 00000000..b6dcf916
--- /dev/null
+++ b/tests/test_select.py
@@ -0,0 +1,473 @@
+import unittest
+import tempfile
+import os
+import csv
+import json
+import pandas as pd
+from superstyl.preproc.select import (
+ _load_metadata,
+ _should_exclude,
+ read_clean,
+ apply_selection
+)
+
+
+class TestSelectRefactored(unittest.TestCase):
+ """Tests for the refactored select.py module"""
+
+ def setUp(self):
+ """Create temporary directory and test files"""
+ self.temp_dir = tempfile.TemporaryDirectory()
+ self.temp_path = self.temp_dir.name
+
+ # Create a sample CSV file
+ self.csv_path = os.path.join(self.temp_path, "test_data.csv")
+ with open(self.csv_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['Unnamed: 0', 'author', 'lang', 'feature1', 'feature2'])
+ writer.writerow(['doc1', 'Smith', 'en', 0.5, 0.3])
+ writer.writerow(['doc2', 'Dupont', 'fr', 0.4, 0.6])
+ writer.writerow(['doc3', 'Garcia', 'es', 0.7, 0.2])
+ writer.writerow(['doc4', 'Smith', 'en', 0.6, 0.4])
+ writer.writerow(['doc5', 'Dupont', 'fr', 0.3, 0.7])
+
+ def tearDown(self):
+ """Clean up temporary directory"""
+ self.temp_dir.cleanup()
+
+ # =========================================================================
+ # Tests for _load_metadata() helper function
+ # =========================================================================
+
+ def test_load_metadata_with_no_params(self):
+ """Test _load_metadata when no metadata or excludes needed"""
+ # WHEN
+ metadata, excludes = _load_metadata(
+ self.csv_path,
+ metadata_path=None,
+ excludes_path=None,
+ lang=None
+ )
+
+ # THEN
+ self.assertIsNone(metadata)
+ self.assertIsNone(excludes)
+
+ def test_load_metadata_from_main_csv_for_lang(self):
+ """Test _load_metadata creating metadata from main CSV when lang specified"""
+ # WHEN
+ metadata, excludes = _load_metadata(
+ self.csv_path,
+ metadata_path=None,
+ excludes_path=None,
+ lang='en' # Lang specified, so metadata needed
+ )
+
+ # THEN
+ self.assertIsNotNone(metadata)
+ self.assertIsNone(excludes)
+ self.assertIn('doc1', metadata.index)
+ self.assertEqual(metadata.loc['doc1', 'lang'], 'en')
+
+ def test_load_metadata_from_separate_file(self):
+ """Test _load_metadata loading from separate metadata file"""
+ # GIVEN: Create a metadata file
+ metadata_path = os.path.join(self.temp_path, "metadata.csv")
+ with open(metadata_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['id', 'true'])
+ writer.writerow(['doc1', 'en'])
+ writer.writerow(['doc2', 'fr'])
+
+ # WHEN
+ metadata, excludes = _load_metadata(
+ self.csv_path,
+ metadata_path=metadata_path,
+ excludes_path=None,
+ lang=None
+ )
+
+ # THEN
+ self.assertIsNotNone(metadata)
+ self.assertIsNone(excludes)
+ self.assertEqual(metadata.loc['doc1', 'lang'], 'en')
+ self.assertEqual(metadata.loc['doc2', 'lang'], 'fr')
+
+ def test_load_metadata_with_excludes(self):
+ """Test _load_metadata loading excludes list"""
+ # GIVEN: Create an excludes file
+ excludes_path = os.path.join(self.temp_path, "excludes.csv")
+ with open(excludes_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['id'])
+ writer.writerow(['doc3'])
+ writer.writerow(['doc5'])
+
+ # WHEN
+ metadata, excludes = _load_metadata(
+ self.csv_path,
+ metadata_path=None,
+ excludes_path=excludes_path,
+ lang=None
+ )
+
+ # THEN
+ # When excludes_path is not None, metadata is loaded from CSV (for potential filtering)
+ self.assertIsNotNone(metadata)
+ self.assertIsNotNone(excludes)
+ self.assertIn('doc3', excludes)
+ self.assertIn('doc5', excludes)
+ self.assertEqual(len(excludes), 2)
+
+ # =========================================================================
+ # Tests for _should_exclude() helper function
+ # =========================================================================
+
+ def test_should_exclude_no_exclusions(self):
+ """Test _should_exclude when no exclusions apply"""
+ # WHEN
+ is_excluded, reason = _should_exclude(
+ 'doc1',
+ metadata=None,
+ excludes=None,
+ lang=None
+ )
+
+ # THEN
+ self.assertFalse(is_excluded)
+ self.assertIsNone(reason)
+
+ def test_should_exclude_by_language(self):
+ """Test _should_exclude excluding by language"""
+ # GIVEN: Create metadata
+ metadata = pd.DataFrame({
+ 'lang': ['en', 'fr', 'es']
+ }, index=['doc1', 'doc2', 'doc3'])
+
+ # WHEN: Check doc with wrong language
+ is_excluded, reason = _should_exclude(
+ 'doc2', # fr
+ metadata=metadata,
+ excludes=None,
+ lang='en' # Only want en
+ )
+
+ # THEN
+ self.assertTrue(is_excluded)
+ self.assertIn('not in: en', reason)
+ self.assertIn('doc2', reason)
+
+ def test_should_exclude_by_language_passes(self):
+ """Test _should_exclude NOT excluding when language matches"""
+ # GIVEN: Create metadata
+ metadata = pd.DataFrame({
+ 'lang': ['en', 'fr', 'es']
+ }, index=['doc1', 'doc2', 'doc3'])
+
+ # WHEN: Check doc with correct language
+ is_excluded, reason = _should_exclude(
+ 'doc1', # en
+ metadata=metadata,
+ excludes=None,
+ lang='en' # Want en
+ )
+
+ # THEN
+ self.assertFalse(is_excluded)
+ self.assertIsNone(reason)
+
+ def test_should_exclude_by_excludes_list(self):
+ """Test _should_exclude excluding by excludes list"""
+ # WHEN
+ is_excluded, reason = _should_exclude(
+ 'doc3',
+ metadata=None,
+ excludes=['doc3', 'doc5'],
+ lang=None
+ )
+
+ # THEN
+ self.assertTrue(is_excluded)
+ self.assertIn('Wilhelmus', reason)
+ self.assertIn('doc3', reason)
+
+ def test_should_exclude_not_in_excludes_list(self):
+ """Test _should_exclude NOT excluding when not in list"""
+ # WHEN
+ is_excluded, reason = _should_exclude(
+ 'doc1',
+ metadata=None,
+ excludes=['doc3', 'doc5'],
+ lang=None
+ )
+
+ # THEN
+ self.assertFalse(is_excluded)
+ self.assertIsNone(reason)
+
+ def test_should_exclude_keyerror_handling(self):
+ """Test _should_exclude handling missing keys gracefully"""
+ # GIVEN: Metadata without doc99
+ metadata = pd.DataFrame({
+ 'lang': ['en', 'fr']
+ }, index=['doc1', 'doc2'])
+
+ # WHEN: Check a doc not in metadata
+ is_excluded, reason = _should_exclude(
+ 'doc99', # Not in metadata
+ metadata=metadata,
+ excludes=None,
+ lang='en'
+ )
+
+ # THEN: Should not crash, should not exclude
+ self.assertFalse(is_excluded)
+ self.assertIsNone(reason)
+
+ # =========================================================================
+ # Tests for read_clean() function
+ # =========================================================================
+
+ def test_read_clean_no_split(self):
+ """Test read_clean without splitting"""
+ # GIVEN
+ output_json = os.path.join(self.temp_path, "selection.json")
+
+ # WHEN
+ read_clean(
+ self.csv_path,
+ savesplit=output_json,
+ split=False
+ )
+
+ # THEN: Should create _selected.csv
+ selected_path = self.csv_path.replace('.csv', '_selected.csv')
+ self.assertTrue(os.path.exists(selected_path))
+
+ # Check content
+ with open(selected_path, 'r') as f:
+ lines = f.readlines()
+ self.assertEqual(len(lines), 6) # Header + 5 data rows
+
+ # Check selection JSON
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ self.assertIn('train', selection)
+ self.assertIn('elim', selection)
+ self.assertNotIn('valid', selection) # No split
+ self.assertEqual(len(selection['train']), 5)
+ self.assertEqual(len(selection['elim']), 0)
+
+ def test_read_clean_with_split(self):
+ """Test read_clean with splitting into train/valid"""
+ # GIVEN
+ output_json = os.path.join(self.temp_path, "selection_split.json")
+
+ # WHEN
+ read_clean(
+ self.csv_path,
+ savesplit=output_json,
+ split=True,
+ split_ratio=0.5 # 50% for easier testing
+ )
+
+ # THEN: Should create _train.csv and _valid.csv
+ train_path = self.csv_path.replace('.csv', '_train.csv')
+ valid_path = self.csv_path.replace('.csv', '_valid.csv')
+ self.assertTrue(os.path.exists(train_path))
+ self.assertTrue(os.path.exists(valid_path))
+
+ # Check selection JSON
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ self.assertIn('train', selection)
+ self.assertIn('valid', selection)
+ self.assertIn('elim', selection)
+ # With 50% split, should have some in train and some in valid
+ total = len(selection['train']) + len(selection['valid'])
+ self.assertEqual(total, 5) # All 5 docs processed
+
+ def test_read_clean_with_language_filter(self):
+ """Test read_clean filtering by language"""
+ # GIVEN
+ output_json = os.path.join(self.temp_path, "selection_lang.json")
+
+ # WHEN: Only keep English docs
+ read_clean(
+ self.csv_path,
+ savesplit=output_json,
+ lang='en',
+ split=False
+ )
+
+ # THEN
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ # Should keep 2 English docs (doc1, doc4)
+ self.assertEqual(len(selection['train']), 2)
+ # Should eliminate 3 non-English docs (doc2, doc3, doc5)
+ self.assertEqual(len(selection['elim']), 3)
+
+ def test_read_clean_with_excludes(self):
+ """Test read_clean with excludes list"""
+ # GIVEN: Create an excludes file
+ excludes_path = os.path.join(self.temp_path, "excludes.csv")
+ with open(excludes_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['id'])
+ writer.writerow(['doc2'])
+ writer.writerow(['doc4'])
+
+ output_json = os.path.join(self.temp_path, "selection_excl.json")
+
+ # WHEN
+ read_clean(
+ self.csv_path,
+ excludes_path=excludes_path,
+ savesplit=output_json,
+ split=False
+ )
+
+ # THEN
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ # Should keep 3 docs (doc1, doc3, doc5)
+ self.assertEqual(len(selection['train']), 3)
+ # Should eliminate 2 docs (doc2, doc4)
+ self.assertEqual(len(selection['elim']), 2)
+ self.assertIn('doc2', selection['elim'])
+ self.assertIn('doc4', selection['elim'])
+
+ def test_read_clean_with_metadata_file(self):
+ """Test read_clean with separate metadata file"""
+ # GIVEN: Create a metadata file
+ metadata_path = os.path.join(self.temp_path, "metadata.csv")
+ with open(metadata_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['id', 'true'])
+ writer.writerow(['doc1', 'en'])
+ writer.writerow(['doc2', 'en'])
+ writer.writerow(['doc3', 'fr'])
+ writer.writerow(['doc4', 'en'])
+ writer.writerow(['doc5', 'fr'])
+
+ output_json = os.path.join(self.temp_path, "selection_meta.json")
+
+ # WHEN: Filter by English using metadata file
+ read_clean(
+ self.csv_path,
+ metadata_path=metadata_path,
+ savesplit=output_json,
+ lang='en',
+ split=False
+ )
+
+ # THEN
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ # Should keep 3 English docs according to metadata
+ self.assertEqual(len(selection['train']), 3)
+ self.assertIn('doc1', selection['train'])
+ self.assertIn('doc2', selection['train'])
+ self.assertIn('doc4', selection['train'])
+
+ def test_read_clean_combined_filters(self):
+ """Test read_clean with both language filter and excludes"""
+ # GIVEN: Create an excludes file
+ excludes_path = os.path.join(self.temp_path, "excludes_combined.csv")
+ with open(excludes_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(['id'])
+ writer.writerow(['doc4']) # Exclude one English doc
+
+ output_json = os.path.join(self.temp_path, "selection_combined.json")
+
+ # WHEN: Filter by English AND exclude doc4
+ read_clean(
+ self.csv_path,
+ excludes_path=excludes_path,
+ savesplit=output_json,
+ lang='en',
+ split=False
+ )
+
+ # THEN
+ with open(output_json, 'r') as f:
+ selection = json.load(f)
+ # Should keep only doc1 (doc4 excluded, others wrong lang)
+ self.assertEqual(len(selection['train']), 1)
+ self.assertIn('doc1', selection['train'])
+ # Should eliminate 4 docs
+ self.assertEqual(len(selection['elim']), 4)
+
+ # =========================================================================
+ # Tests for apply_selection() function
+ # =========================================================================
+
+ def test_apply_selection(self):
+ """Test apply_selection applying a pre-existing selection"""
+ # GIVEN: Create a selection JSON
+ selection_path = os.path.join(self.temp_path, "presplit.json")
+ selection = {
+ 'train': ['doc1', 'doc3'],
+ 'valid': ['doc2', 'doc4'],
+ 'elim': ['doc5']
+ }
+ with open(selection_path, 'w') as f:
+ json.dump(selection, f)
+
+ # WHEN
+ apply_selection(self.csv_path, selection_path)
+
+ # THEN: Should create _train.csv and _valid.csv
+ train_path = self.csv_path.replace('.csv', '_train.csv')
+ valid_path = self.csv_path.replace('.csv', '_valid.csv')
+ self.assertTrue(os.path.exists(train_path))
+ self.assertTrue(os.path.exists(valid_path))
+
+ # Check train file content
+ with open(train_path, 'r') as f:
+ reader = csv.reader(f)
+ next(reader) # Skip header
+ train_ids = [row[0] for row in reader]
+ self.assertEqual(set(train_ids), {'doc1', 'doc3'})
+
+ # Check valid file content
+ with open(valid_path, 'r') as f:
+ reader = csv.reader(f)
+ next(reader) # Skip header
+ valid_ids = [row[0] for row in reader]
+ self.assertEqual(set(valid_ids), {'doc2', 'doc4'})
+
+ def test_apply_selection_eliminates_correctly(self):
+ """Test apply_selection correctly eliminates specified docs"""
+ # GIVEN: Create a selection JSON with eliminations
+ selection_path = os.path.join(self.temp_path, "presplit_elim.json")
+ selection = {
+ 'train': ['doc1', 'doc2'],
+ 'valid': ['doc3'],
+ 'elim': ['doc4', 'doc5']
+ }
+ with open(selection_path, 'w') as f:
+ json.dump(selection, f)
+
+ # WHEN
+ apply_selection(self.csv_path, selection_path)
+
+ # THEN: Eliminated docs should not appear in either file
+ train_path = self.csv_path.replace('.csv', '_train.csv')
+ valid_path = self.csv_path.replace('.csv', '_valid.csv')
+
+ with open(train_path, 'r') as f:
+ content = f.read()
+ self.assertNotIn('doc4', content)
+ self.assertNotIn('doc5', content)
+
+ with open(valid_path, 'r') as f:
+ content = f.read()
+ self.assertNotIn('doc4', content)
+ self.assertNotIn('doc5', content)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/train_svm.py b/train_svm.py
index c1a30333..747ff587 100755
--- a/train_svm.py
+++ b/train_svm.py
@@ -1,27 +1,48 @@
import superstyl.svm
+from superstyl.config import Config
import pandas
import joblib
if __name__ == "__main__":
-
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(
+ description="Train SVM models for stylometric analysis.")
+
+ # Configuration file
+ parser.add_argument('--config', action='store', type=str, default=None,
+ help="Path to a JSON configuration file"
+ )
+
+ # Data paths
parser.add_argument('train_path', action='store', help="Path to train file", type=str)
parser.add_argument('--test_path', action='store', help="Path to test file", type=str, required=False,
default=None)
+
+ # Output options
parser.add_argument('-o', action='store', help="optional prefix for output files", type=str,
default=None)
+
+
+ # Cross-validation options
parser.add_argument('--cross_validate', action='store',
help="perform cross validation (test_path will be used only for final prediction)."
"If group_k-fold is chosen, each source file will be considered a group "
"(only relevant if sampling was performed and more than one file per class was provided)",
default=None, choices=['leave-one-out', 'k-fold', 'group-k-fold'], type=str)
- parser.add_argument('--k', action='store', help="k for k-fold (default: 10 folds for k-fold; k=n groups for group-k-fold)", default=0, type=int)
- parser.add_argument('--dim_reduc', action='store', choices=['pca'], help="optional dimensionality "
- "reduction of input data", default=None)
+
+ parser.add_argument('--k', action='store',
+ help="k for k-fold (default: 10 folds for k-fold; k=n groups for group-k-fold)",
+ default=0, type=int)
+
+ # Preprocessing options
+ parser.add_argument('--dim_reduc', action='store', choices=['pca'],
+ help="Dimensionality reduction of input data", default=None)
+
parser.add_argument('--norms', action='store_true', help="perform normalisations? (default: True)",
default=True)
+
+ # Balancing options
parser.add_argument('--balance', action='store',
choices=["downsampling", "Tomek", "upsampling", "SMOTE", "SMOTETomek"],
help="which "
@@ -32,31 +53,43 @@
"SMOTE (upsampling with SMOTE), "
"SMOTETomek (over+undersampling with SMOTE+Tomek)",
default=None)
+
parser.add_argument('--class_weights', action='store_true',
help="whether to use class weights in imbalanced datasets "
- "(inversely proportional to total class samples)",
- default=False
- )
+ "(inversely proportional to total class samples)", default=False)
+
parser.add_argument('--kernel', action='store',
help="type of kernel to use (default and recommended choice is LinearSVC; "
"possible alternatives are linear, sigmoid, rbf and poly, as per sklearn.svm.SVC)",
default="LinearSVC", choices=['LinearSVC', 'linear', 'sigmoid', 'rbf', 'poly'], type=str)
+
+    # Final analysis options
parser.add_argument('--final', action='store_true',
help="final analysis on unknown dataset (no evaluation)?",
default=False)
+
parser.add_argument('--get_coefs', action='store_true',
help="switch to write to disk and plot the most important coefficients"
" for the training feats for each class",
default=False)
- # New arguments for rolling stylometry plotting
+
+ # Rolling stylometry plotting
parser.add_argument('--plot_rolling', action='store_true',
help="If final predictions are produced, also plot rolling stylometry.")
+
parser.add_argument('--plot_smoothing', action='store', type=int, default=3,
help="Smoothing window size for rolling stylometry plot (default:3)."
"Set to 0 or None to disable smoothing.")
+
+ # Save configuration
+ parser.add_argument(
+ '--save_config', action='store', type=str, default=None,
+ help="Save the configuration to a JSON file"
+ )
args = parser.parse_args()
+ # Load data
print(".......... loading data ........")
train = pandas.read_csv(args.train_path, index_col=0)
@@ -65,31 +98,49 @@
else:
test = None
- svm = superstyl.svm.train_svm(train, test, cross_validate=args.cross_validate, k=args.k, dim_reduc=args.dim_reduc,
- norms=args.norms, balance=args.balance, class_weights=args.class_weights,
- kernel=args.kernel, final_pred=args.final, get_coefs=args.get_coefs)
+ # Determine how to run the SVM
+ if args.config:
+ config = Config.from_json(args.config)
+        # NOTE: when --config is given, CLI arguments are currently ignored (no override)
- if args.o is not None:
- args.o = args.o + "_"
else:
- args.o = ''
+ config = Config.from_kwargs(
+ cross_validate=args.cross_validate,
+ dim_reduc=args.dim_reduc,
+ norms=args.norms,
+ balance=args.balance,
+ class_weights=args.class_weights,
+ kernel=args.kernel,
+ final_pred=args.final,
+ get_coefs=args.get_coefs
+ )
+
+ # Save config if requested
+ if args.save_config:
+ config.save(args.save_config)
+ print(f"Configuration saved to {args.save_config}")
+
+ # Train SVM
+ svm = superstyl.svm.train_svm(train, test, config=config)
+
+ # Output prefix
+ prefix = f"{args.o}_" if args.o else ""
- if args.cross_validate is not None or (args.test_path is not None and not args.final):
- svm["confusion_matrix"].to_csv(args.o+"confusion_matrix.csv")
- svm["misattributions"].to_csv(args.o+"misattributions.csv")
+ # Save results
+ if args.cross_validate or (args.test_path and not args.final):
+ svm["confusion_matrix"].to_csv(f"{prefix}confusion_matrix.csv")
+ svm["misattributions"].to_csv(f"{prefix}misattributions.csv")
- joblib.dump(svm["pipeline"], args.o+'mySVM.joblib')
+ joblib.dump(svm["pipeline"], f"{prefix}mySVM.joblib")
if args.final:
- print(".......... Writing final predictions to " + args.o + "FINAL_PREDICTIONS.csv ........")
- svm["final_predictions"].to_csv(args.o+"FINAL_PREDICTIONS.csv")
+ print(f".......... Writing final predictions to {prefix}FINAL_PREDICTIONS.csv ........")
+ svm["final_predictions"].to_csv(f"{prefix}FINAL_PREDICTIONS.csv")
- # If user requested rolling stylometry plot
if args.plot_rolling:
print(".......... Plotting rolling stylometry ........")
- smoothing = args.plot_smoothing if args.plot_smoothing is not None else 0
- superstyl.svm.plot_rolling(svm["final_predictions"], smoothing=smoothing)
+ superstyl.svm.plot_rolling(svm["final_predictions"], smoothing=args.plot_smoothing)
if args.get_coefs:
print(".......... Writing coefficients to disk ........")
- svm["coefficients"].to_csv(args.o+"coefficients.csv")
+ svm["coefficients"].to_csv(f"{prefix}coefficients.csv")
\ No newline at end of file