diff --git a/.coverage b/.coverage index cb8c5ebe..8380a2ac 100644 Binary files a/.coverage and b/.coverage differ diff --git a/load_corpus.py b/load_corpus.py index 668e7524..3936a8ac 100755 --- a/load_corpus.py +++ b/load_corpus.py @@ -1,117 +1,219 @@ from superstyl.load import load_corpus -from superstyl.load_from_config import load_corpus_from_config +from superstyl.config import Config import json # TODO: eliminate features that occur only n times ? # Do the Moisl Selection ? -# TODO: document the new 'lemma' feat for TEI loading if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser() - parser.add_argument('-s', nargs='+', help="paths to files or to json config file", required=True) - parser.add_argument('--json', action='store_true', help="indicates that the path provided with -s is a JSON config file, " - "containing all the options to load the corpus/features") - parser.add_argument('-o', action='store', help="optional base name of output files", type=str, default=False) - parser.add_argument('-f', action="store", help="optional list of features, either in json (generated by" - " Superstyl) or simple txt (one word per line)", default=False) - parser.add_argument('-t', action='store', help="types of features (words, chars, affixes - " - "as per Sapkota et al. 2015 -, as well as lemma or pos, met_line, " - "met_syll (those four last only for TEI files with proper annotation)" - , type=str, - default="words", choices=["words", "chars", "affixes", "pos", "lemma", "met_line", "met_syll"]) - parser.add_argument('-n', action='store', help="n grams lengths (default 1)", default=1, type=int) - parser.add_argument('-k', action='store', help="How many most frequent?", default=5000, type=int) - parser.add_argument('--freqs', action='store', help="relative, absolute or binarised freqs", + parser = argparse.ArgumentParser( + description="Load a corpus and extract features for stylometric analysis." 
+ ) + parser.add_argument('-s', + nargs='+', + help="paths to files or to json config file", + required=True + ) + parser.add_argument('--json', + action='store_true', + help="indicates that the path provided with -s is a JSON config file, " + "containing all the options to load the corpus/features" + ) + parser.add_argument('-o', + action='store', + help="optional base name of output files", + type=str, + default=False + ) + # Feature list + parser.add_argument('-f', + action="store", + help="optional list of features, either in json (generated by" + " Superstyl) or simple txt (one word per line)", + default=False + ) + parser.add_argument('-t', + action='store', + help="types of features (words, chars, affixes - " + "as per Sapkota et al. 2015 -, as well as lemma or pos, met_line, " + "met_syll (those four last only for TEI files with proper annotation)", + type=str, default="words", + choices=["words", "chars", "affixes", "pos", "lemma", "met_line", "met_syll"] + ) + parser.add_argument('-n', + action='store', + help="n grams lengths (default 1)", + default=1, + type=int) + parser.add_argument('-k', + action='store', + help="How many most frequent features?", + default=5000, + type=int + ) + parser.add_argument('--freqs', + action='store', + help="relative, absolute or binarised freqs", default="relative", choices=["relative", "absolute", "binary"] ) - parser.add_argument('-x', action='store', help="format (txt, xml, tei, or txm) WARNING: only txt is fully implemented", - default="txt", + parser.add_argument('-x', + action='store', + help="format (txt, xml, tei, or txm) WARNING: only txt is fully implemented", + default="txt", choices=["txt", "xml", "tei", 'txm'] ) - parser.add_argument('--sampling', action='store_true', help="Sample the texts?", default=False) - parser.add_argument('--sample_units', action='store', help="Units of length for sampling " - "(words, verses; default: words)", - choices=["words", "verses"], - default="words", type=str) - 
parser.add_argument('--sample_size', action='store', help="Size for sampling (default: 3000)", default=3000, type=int) - parser.add_argument('--sample_step', action='store', help="Step for sampling with overlap (default is no overlap)", default=None, type=int) - parser.add_argument('--max_samples', action='store', help="Maximum number of (randomly selected) samples per class, e.g. author (default is all)", - default=None, type=int) - parser.add_argument('--samples_random', action='store_true', - help="Should random sampling with replacement be performed instead of continuous sampling (default: false)", - default=False) - parser.add_argument('--keep_punct', action='store_true', help="whether to keep punctuation and caps (default is False)", - default=False) - parser.add_argument('--keep_sym', action='store_true', - help="if true, same as keep_punct, plus no Unidecode, and numbers are kept as well (default is False)", - default=False) - parser.add_argument('--no_ascii', action='store_true', - help="disables the conversion to ascii as per the Unidecode module. 
Useful for non Latin alphabet (default is conversion to ASCII)", - default=False) - parser.add_argument('--identify_lang', action='store_true', - help="if true, should the language of each text be guessed, using langdetect (default is False)", - default=False) - parser.add_argument('--embedding', action="store", help="optional path to a word2vec embedding in txt format to compute frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)", - default=False) - parser.add_argument('--neighbouring_size', action="store", help="size of semantic neighbouring in the embedding (n closest neighbours)", - default=10, type=int) - parser.add_argument('--culling', action="store", - help="percentage value for culling, meaning in what percentage of samples should a feature be present to be retained (default is 0, meaning no culling)", - default=0, type=float) - + parser.add_argument('--sampling', + action='store_true', + help="Sample the texts?", + default=False + ) + parser.add_argument('--sample_units', + action='store', + help="Units of length for sampling (words, verses; default: words)", + choices=["words", "verses"], + default="words", + type=str + ) + parser.add_argument('--sample_size', + action='store', + help="Size for sampling (default: 3000)", + default=3000, + type=int + ) + parser.add_argument('--sample_step', + action='store', + help="Step for sampling with overlap (default is no overlap)", + default=None, + type=int + ) + parser.add_argument('--max_samples', + action='store', + help="Maximum number of (randomly selected) samples per class, e.g. 
author (default is all)", + default=None, + type=int + ) + parser.add_argument('--samples_random', + action='store_true', + help="Should random sampling with replacement be performed " \ + "instead of continuous sampling (default: false)", + default=False + ) + parser.add_argument('--keep_punct', + action='store_true', + help="whether to keep punctuation and caps (default is False)", + default=False + ) + parser.add_argument('--keep_sym', + action='store_true', + help="if true, same as keep_punct, plus no Unidecode, " + "and numbers are kept as well (default is False)", + default=False + ) + parser.add_argument('--no_ascii', + action='store_true', + help="disables the conversion to ascii as per the Unidecode module. " \ + "Useful for non Latin alphabet (default is conversion to ASCII)", + default=False + ) + parser.add_argument('--identify_lang', + action='store_true', + help="if true, should the language of each text be guessed, " \ + "using langdetect (default is False)", + default=False + ) + parser.add_argument('--embedding', + action="store", + help="optional path to a word2vec embedding in txt format to compute " \ + "frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)", + default=False + ) + parser.add_argument('--neighbouring_size', + action="store", + help="size of semantic neighbouring in the embedding (n closest neighbours)", + default=10, + type=int + ) + parser.add_argument('--culling', + action="store", + help="percentage value for culling, meaning in what " \ + "percentage of samples should a feature be present " \ + "to be retained (default is 0, meaning no culling)", + default=0, + type=float) args = parser.parse_args() + # Load feature list if provided + my_feats = None if args.f: with open(args.f, 'r') as f: - if args.f.split(".")[-1] == "json": + if args.f.endswith(".json"): print(".......loading preexisting feature list from json.......") my_feats = json.loads(f.read()) - - elif args.f.split(".")[-1] == "txt": + elif 
args.f.endswith(".txt"): print(".......loading preexisting feature list from txt.......") my_feats = [[feat.rstrip(), 0] for feat in f.readlines()] - else: print(".......unknown feature list format. Ignoring.......") - my_feats = None - - else: - my_feats = None - if args.json: - corpus, my_feats = load_corpus_from_config(args.s) + elif args.config: + # Load from new-style JSON config file + config = Config.from_json(args.config) + # Override paths if provided via CLI + if args.s: + config.corpus.paths = args.s + else: - corpus, my_feats = load_corpus(args.s, feat_list=my_feats, feats=args.t, n=args.n, k=args.k, - freqsType=args.freqs, format=args.x, - sampling=args.sampling, units=args.sample_units, - size=args.sample_size, step=args.sample_step, max_samples=args.max_samples, - samples_random=args.samples_random, - keep_punct=args.keep_punct, keep_sym=args.keep_sym, no_ascii=args.no_ascii, - identify_lang=args.identify_lang, - embedding=args.embedding, neighbouring_size=args.neighbouring_size, - culling=args.culling - ) - - print(".......saving results.......") - + if not args.s: + parser.error("-s (paths) is required when not using --config") + + config = Config.from_kwargs( + data_paths=args.s, + feats=args.t, + n=args.n, + k=args.k, + freqsType=args.freqs, + format=args.x, + sampling=args.sampling, + units=args.sample_units, + size=args.sample_size, + step=args.sample_step, + max_samples=args.max_samples, + samples_random=args.samples_random, + keep_punct=args.keep_punct, + keep_sym=args.keep_sym, + no_ascii=args.no_ascii, + identify_lang=args.identify_lang, + embedding=args.embedding, + neighbouring_size=args.neighbouring_size, + culling=args.culling + ) + + # Inject my_feats if provided + if my_feats and config.features: + config.features[0].feat_list = my_feats + + # Load corpus + corpus, my_feats = load_corpus(config=config) + + # Determine output file names if args.o: feat_file = args.o + "_feats.json" corpus_file = args.o + ".csv" - else: - feat_file = 
"feature_list_{}{}grams{}mf.json".format(args.t, args.n, args.k) - corpus_file = "feats_tests_n{}_k_{}.csv".format(args.n, args.k) + feat_file = f"feature_list_{args.t}{args.n}grams{args.k}mf.json" + corpus_file = f"feats_tests_n{args.n}_k_{args.k}.csv" - #if not args.f and : + # Save results + print(".......saving results.......") + with open(feat_file, "w") as out: out.write(json.dumps(my_feats, ensure_ascii=False, indent=0)) - print("Features list saved to " + feat_file) + print(f"Features list saved to {feat_file}") + # Save corpus corpus.to_csv(corpus_file) - print("Corpus saved to " + corpus_file) - - + print(f"Corpus saved to {corpus_file}") \ No newline at end of file diff --git a/split.py b/split.py index 6ee63c9d..3094fa90 100755 --- a/split.py +++ b/split.py @@ -1,3 +1,7 @@ +""" +Command-line tool for splitting datasets. +""" + import superstyl.preproc.select as sel @@ -12,28 +16,28 @@ default=False) parser.add_argument('-m', action="store", help="path to metadata file", required=False) parser.add_argument('-e', action="store", help="path to excludes file", required=False) - parser.add_argument('--lang', action="store", help="analyse only file in this language (optional, for initial split only)", required=False) - parser.add_argument('--nosplit', action="store_true", help="no split (do not provide split file)", default=False) + parser.add_argument('--lang', action="store", + help="analyse only file in this language (optional, for initial split only)", + required=False) + parser.add_argument('--nosplit', action="store_true", + help="no split (do not provide split file)", + default=False) + parser.add_argument('--split_ratio', action="store", type=float, + help="validation split ratio (default: 0.1 = 10%%)", + default=0.1) args = parser.parse_args() - if args.nosplit: - sel.read_clean(path=args.path, - metadata_path=args.m, - excludes_path=args.e, - savesplit="split_nosplit.json", - lang=args.lang - ) + if args.s: + # Apply existing selection + 
sel.apply_selection(path=args.path, presplit_path=args.s) else: - - if not args.s: - # to create initial selection - sel.read_clean_split(path=args.path, - metadata_path=args.m, - excludes_path=args.e, - savesplit="split.json", - lang=args.lang - ) - - else: - # to load and apply a selection - sel.apply_selection(path=args.path, presplit_path=args.s) + # Create new selection (with or without split) + sel.read_clean( + path=args.path, + metadata_path=args.m, + excludes_path=args.e, + savesplit="split_nosplit.json" if args.nosplit else "split.json", + lang=args.lang, + split=not args.nosplit, + split_ratio=args.split_ratio + ) diff --git a/superstyl/__init__.py b/superstyl/__init__.py index 940f61fc..67ab2010 100755 --- a/superstyl/__init__.py +++ b/superstyl/__init__.py @@ -1,2 +1,36 @@ from superstyl.load import load_corpus -from superstyl.svm import train_svm \ No newline at end of file +from superstyl.svm import train_svm, plot_rolling, plot_coefficients +from superstyl.config import ( + Config, + CorpusConfig, + FeatureConfig, + SamplingConfig, + NormalizationConfig, + SVMConfig, + get_config, + set_config, + reset_config +) + +__all__ = [ + # Main functions + 'load_corpus', + 'load_corpus_with_config', + 'train_svm', + 'train_svm_with_config', + 'plot_rolling', + 'plot_coefficients', + + # Configuration classes + 'Config', + 'CorpusConfig', + 'FeatureConfig', + 'SamplingConfig', + 'NormalizationConfig', + 'SVMConfig', + + # Configuration management + 'get_config', + 'set_config', + 'reset_config', +] \ No newline at end of file diff --git a/superstyl/config.py b/superstyl/config.py new file mode 100644 index 00000000..4a0aa6fa --- /dev/null +++ b/superstyl/config.py @@ -0,0 +1,362 @@ +from dataclasses import dataclass, field, fields +from typing import Optional, List, Any, Dict, Type, TypeVar +import json + + +T = TypeVar('T', bound='BaseConfig') + + +@dataclass +class BaseConfig: + """ + Base configuration class providing common functionality. 
+ + All configuration classes inherit from this to share: + - to_dict() serialization + - from_dict() deserialization + - Validation hooks + """ + + def to_dict(self) -> Dict[str, Any]: + result = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, BaseConfig): + result[f.name] = value.to_dict() + elif isinstance(value, list) and value and isinstance(value[0], BaseConfig): + result[f.name] = [v.to_dict() for v in value] + else: + result[f.name] = value + return result + + @classmethod + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: + return cls(**data) + + def validate(self) -> None: + pass + + +@dataclass +class NormalizationConfig(BaseConfig): + """ + Configuration for text normalization. + """ + keep_punct: bool = False + keep_sym: bool = False + no_ascii: bool = False + + +@dataclass +class SamplingConfig(BaseConfig): + """ + Configuration for text sampling. + """ + enabled: bool = False + units: str = "words" + size: int = 3000 + step: Optional[int] = None + max_samples: Optional[int] = None + random: bool = False + + def __post_init__(self): + self.validate() + + def validate(self) -> None: + if self.units not in ["words", "verses"]: + raise ValueError(f"Invalid sampling units: {self.units}.") + if self.random and self.step is not None: + raise ValueError("Random sampling is not compatible with step.") + if self.random and self.max_samples is None: + raise ValueError("Random sampling needs max_samples.") + + +@dataclass +class FeatureConfig(BaseConfig): + """ + Configuration for feature extraction. 
+ """ + name: Optional[str] = None # For multi-feature identification + type: str = "words" + n: int = 1 + k: int = 5000 + freq_type: str = "relative" + feat_list: Optional[List] = None + feat_list_path: Optional[str] = None # Path to load feat_list from + embedding: Optional[str] = None + neighbouring_size: int = 10 + culling: float = 0 + + VALID_TYPES = ["words", "chars", "affixes", "lemma", "pos", "met_line", "met_syll"] + VALID_FREQ_TYPES = ["relative", "absolute", "binary"] + + def __post_init__(self): + self.validate() + self._load_feat_list_if_needed() + + def validate(self) -> None: + if self.type not in self.VALID_TYPES: + raise ValueError(f"Invalid feature type: {self.type}.") + if self.freq_type not in self.VALID_FREQ_TYPES: + raise ValueError(f"Invalid frequency type: {self.freq_type}.") + if self.n < 1: + raise ValueError("n must be a positive integer.") + + def _load_feat_list_if_needed(self) -> None: + """ + Load feature list from file if path is specified. + """ + if self.feat_list_path and self.feat_list is None: + with open(self.feat_list_path, 'r') as f: + if self.feat_list_path.endswith('.json'): + self.feat_list = json.load(f) + elif self.feat_list_path.endswith('.txt'): + self.feat_list = [[feat.strip(), 0] for feat in f.readlines()] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'FeatureConfig': + # Filter out unknown keys for backward compatibility + valid_keys = {f.name for f in fields(cls)} + filtered_data = {k: v for k, v in data.items() if k in valid_keys} + return cls(**filtered_data) + + +@dataclass +class CorpusConfig(BaseConfig): + """ + Configuration for corpus loading. 
+ """ + paths: List[str] = field(default_factory=list) + format: str = "txt" + identify_lang: bool = False + + VALID_FORMATS = ["txt", "xml", "tei", "txm"] + + def __post_init__(self): + self.validate() + + def validate(self) -> None: + if self.format not in self.VALID_FORMATS: + raise ValueError(f"Invalid format: {self.format}.") + + +@dataclass +class SVMConfig(BaseConfig): + """ + Configuration for SVM training. + """ + cross_validate: Optional[str] = None + k: int = 0 + dim_reduc: Optional[str] = None + norms: bool = True + balance: Optional[str] = None + class_weights: bool = False + kernel: str = "LinearSVC" + final_pred: bool = False + get_coefs: bool = False + plot_rolling: bool = False + plot_smoothing: int = 3 + + VALID_CV = [None, "leave-one-out", "k-fold", "group-k-fold"] + VALID_DIM_REDUC = [None, "pca"] + VALID_BALANCE = [None, "downsampling", "Tomek", "upsampling", "SMOTE", "SMOTETomek"] + VALID_KERNELS = ["LinearSVC", "linear", "sigmoid", "rbf", "poly"] + + def __post_init__(self): + self.validate() + + def validate(self) -> None: + if self.cross_validate not in self.VALID_CV: + raise ValueError(f"Invalid cross_validate: {self.cross_validate}.") + if self.dim_reduc not in self.VALID_DIM_REDUC: + raise ValueError(f"Invalid dim_reduc: {self.dim_reduc}.") + if self.balance not in self.VALID_BALANCE: + raise ValueError(f"Invalid balance: {self.balance}.") + if self.kernel not in self.VALID_KERNELS: + raise ValueError(f"Invalid kernel: {self.kernel}.") + + +@dataclass +class Config(BaseConfig): + """ + Main configuration class for SuperStyl. + + Aggregates all sub-configurations and provides factory methods + to create configurations from various sources (CLI, JSON, dict). 
+ """ + corpus: CorpusConfig = field(default_factory=CorpusConfig) + features: List[FeatureConfig] = field(default_factory=lambda: [FeatureConfig()]) + sampling: SamplingConfig = field(default_factory=SamplingConfig) + normalization: NormalizationConfig = field(default_factory=NormalizationConfig) + svm: SVMConfig = field(default_factory=SVMConfig) + output_prefix: Optional[str] = None + + # Mapping from flat kwargs to nested config structure + # Format: 'kwarg_name': ('section', 'attr', optional_transform) + KWARGS_MAPPING = { + # Corpus + 'data_paths': ('corpus', 'paths', lambda x: x if isinstance(x, list) else [x]), + 'format': ('corpus', 'format', None), + 'identify_lang': ('corpus', 'identify_lang', None), + + # Features (single feature mode) + 'feats': ('features', 'type', None), + 'n': ('features', 'n', None), + 'k': ('features', 'k', None), + 'freqsType': ('features', 'freq_type', None), + 'feat_list': ('features', 'feat_list', None), + 'embedding': ('features', 'embedding', lambda x: x if x else None), + 'neighbouring_size': ('features', 'neighbouring_size', None), + 'culling': ('features', 'culling', None), + + # Sampling + 'sampling': ('sampling', 'enabled', None), + 'units': ('sampling', 'units', None), + 'size': ('sampling', 'size', None), + 'step': ('sampling', 'step', None), + 'max_samples': ('sampling', 'max_samples', None), + 'samples_random': ('sampling', 'random', None), + + # Normalization + 'keep_punct': ('normalization', 'keep_punct', None), + 'keep_sym': ('normalization', 'keep_sym', None), + 'no_ascii': ('normalization', 'no_ascii', None), + + # SVM + 'cross_validate': ('svm', 'cross_validate', None), + 'dim_reduc': ('svm', 'dim_reduc', None), + 'norms': ('svm', 'norms', None), + 'balance': ('svm', 'balance', None), + 'class_weights': ('svm', 'class_weights', None), + 'kernel': ('svm', 'kernel', None), + 'final_pred': ('svm', 'final_pred', None), + 'get_coefs': ('svm', 'get_coefs', None), + } + + def validate(self) -> None: + """ + Validate 
configuration consistency. + """ + if not self.features: + raise ValueError("No features specified for extraction.") + + if not self.corpus.paths: + raise ValueError("No paths specified for corpus loading.") + + # Validate paths type + if not isinstance(self.corpus.paths, list): + raise TypeError("Paths in config must be either a list or a glob pattern string.") + + for feat_config in self.features: + tei_only = ["lemma", "pos", "met_line", "met_syll"] + if feat_config.type in tei_only and self.corpus.format != "tei": + raise ValueError(f"{feat_config.type} requires TEI format.") + if feat_config.type in ["met_line", "met_syll"] and self.sampling.units not in ["verses"]: + raise ValueError(f"{feat_config.type} requires verses units.") + + def save(self, path: str) -> None: + with open(path, 'w') as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + + @classmethod + def from_kwargs(cls, **kwargs) -> 'Config': + """ + Build Config from flat kwargs (backward compatibility). + + This allows the old-style function calls to work: + load_corpus(paths, feats="chars", n=3, keep_punct=True) + """ + config_data = { + 'corpus': {}, + 'features': {}, + 'sampling': {}, + 'normalization': {}, + 'svm': {} + } + + for kwarg_name, value in kwargs.items(): + if value is None: + continue + + if kwarg_name in cls.KWARGS_MAPPING: + section, attr, transform = cls.KWARGS_MAPPING[kwarg_name] + final_value = transform(value) if transform else value + config_data[section][attr] = final_value + + # Build the config with proper nesting + return cls.from_dict(config_data) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Config': + kwargs = {} + + if 'corpus' in data: + corpus_data = data['corpus'].copy() + # Handle 'paths' key variations + if 'paths' not in corpus_data and 'data_paths' in data: + corpus_data['paths'] = data['data_paths'] + kwargs['corpus'] = CorpusConfig.from_dict(corpus_data) if corpus_data else CorpusConfig() + + if 'features' in data: + 
features_data = data['features'] + if isinstance(features_data, dict): + kwargs['features'] = [FeatureConfig.from_dict(features_data)] + elif isinstance(features_data, list): + if features_data: + kwargs['features'] = [FeatureConfig.from_dict(f) for f in features_data] + else: + kwargs['features'] = [] + + if 'sampling' in data and data['sampling']: + sampling_data = data['sampling'].copy() + # Handle alternative key names + if 'sample_size' in sampling_data: + sampling_data['size'] = sampling_data.pop('sample_size') + if 'samples_random' in sampling_data: + sampling_data['random'] = sampling_data.pop('samples_random') + if 'sample_step' in sampling_data: + sampling_data['step'] = sampling_data.pop('sample_step') + if 'sample_units' in sampling_data: + sampling_data['units'] = sampling_data.pop('sample_units') + kwargs['sampling'] = SamplingConfig.from_dict(sampling_data) + + if 'normalization' in data and data['normalization']: + kwargs['normalization'] = NormalizationConfig.from_dict(data['normalization']) + + if 'svm' in data and data['svm']: + kwargs['svm'] = SVMConfig.from_dict(data['svm']) + + if 'output_prefix' in data: + kwargs['output_prefix'] = data['output_prefix'] + + return cls(**kwargs) + + @classmethod + def from_json(cls, path: str) -> 'Config': + with open(path, 'r') as f: + data = json.load(f) + + # Handle flat JSON format (paths at root level) + if 'paths' in data and 'corpus' not in data: + data['corpus'] = {'paths': data.pop('paths')} + if 'format' in data: + data['corpus']['format'] = data.pop('format') + if 'identify_lang' in data: + data['corpus']['identify_lang'] = data.pop('identify_lang') + + return cls.from_dict(data) + + +# Global configuration (optional singleton pattern) +_current_config: Optional[Config] = None + +def get_config() -> Optional[Config]: + return _current_config + +def set_config(config: Config) -> None: + global _current_config + _current_config = config + +def reset_config() -> None: + global _current_config + 
_current_config = None \ No newline at end of file diff --git a/superstyl/load.py b/superstyl/load.py index fe52683e..395b62ae 100644 --- a/superstyl/load.py +++ b/superstyl/load.py @@ -4,133 +4,208 @@ import superstyl.preproc.embedding as embed import tqdm import pandas +from typing import Optional, List, Tuple, Union +from superstyl.config import Config, FeatureConfig, NormalizationConfig -def load_corpus(data_paths, feat_list=None, feats="words", n=1, k=5000, freqsType="relative", format="txt", sampling=False, - units="words", size=3000, step=None, max_samples=None, samples_random=False, keep_punct=False, keep_sym=False, - no_ascii=False, - identify_lang=False, embedding=False, neighbouring_size=10, culling=0): + +def _load_single_feature( + myTexts: List[dict], + feat_config: FeatureConfig, + norm_config: NormalizationConfig, + use_provided_feat_list: bool = False, +) -> Tuple[pandas.DataFrame, List]: """ - Main function to load a corpus from a collection of file, and an optional list of features to extract. - :param data_paths: paths to the source files - :param feat_list: an optional list of features (as created by load_corpus), default None - :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'. - Affixes are inspired by Sapkota et al. 2015, and include space_prefix, space_suffix, prefix, suffix, and, - if keep_punct, punctuation n-grams. From TEI, pos, lemma, met_line or met_syll can - be extracted; met_line is the prosodic (stress) annotation of a full verse; met_syll is a char n-gram of prosodic - annotation - :param n: n grams lengths (default 1) - :param k: How many most frequent? The function takes the rank of k (if k is smaller than the total number of features), - gets its frequencies, and only include features of superior or equal total frequencies. - :param freqsType: return relative, absolute or binarised frequencies (default: relative) - :param format: one of txt, xml or tei. 
/!\ only txt is fully implemented. - :param sampling: whether to sample the texts, by cutting it into slices of a given length, until the last possible - slice of this length, which means that often the end of the text will be eliminated (default False) - :param units: units of length for sampling, one of 'words', 'verses' (default: words). 'verses' is only implemented - for the 'tei' format - :param size: the size of the samples (in units) - :param step: step for sampling with overlap (default is step = size, which means no overlap). - Reduce for overlapping slices - :param max_samples: Maximum number of (randomly selected) samples per author/class (default is all) - :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false) - :param keep_punct: whether to keep punctuation and caps (default is False) - :param keep_sym: same as keep_punct, and numbers are kept as well (default is False). /!\ does not - actually keep symbols - :param no_ascii: disables conversion to ASCII (default is conversion) - :param identify_lang: if true, the language of each text will be guessed, using langdetect (default is False) - :param embedding: optional path to a word2vec embedding in txt format to compute frequencies among a set of - semantic neighbourgs (i.e., pseudo-paronyms) - :param neighbouring_size: size of semantic neighbouring in the embedding (as per gensim most_similar, - with topn=neighbouring_size) - :param culling percentage value for culling, meaning in what percentage of samples should a feature be present to be retained (default is 0, meaning no culling) - :return a pandas dataFrame of text metadata and feature frequencies; a global list of features with their frequencies + Extract features for a single FeatureConfig. + Internal function used by load_corpus. + + Args: + use_provided_feat_list: If True and feat_config.feat_list is provided, + return that list instead of the computed one. 
Used for test sets + to ensure same features as training set. """ - - if feats in ('lemma', 'pos', 'met_line', 'met_syll') and format != 'tei': - raise ValueError("lemma, pos, met_line or met_syll are only possible with adequate tei format (@lemma, @pos, @met)") - - if feats in ('met_line', 'met_syll') and units != 'lines': - raise ValueError("met_line or met_syll are only possible with tei format that includes lines and @met") + feats = feat_config.type + n = feat_config.n + k = feat_config.k + freqsType = feat_config.freq_type + provided_feat_list = feat_config.feat_list + embedding = feat_config.embedding + neighbouring_size = feat_config.neighbouring_size + culling = feat_config.culling embeddedFreqs = False if embedding: print(".......loading embedding.......") - relFreqs = False # we need absolute freqs as a basis for embedded frequencies model = embed.load_embeddings(embedding) embeddedFreqs = True - freqsType = "absolute" #absolute freqs are required for embedding + freqsType = "absolute" - print(".......loading texts.......") - - if sampling: - myTexts = pipe.docs_to_samples(data_paths, feats=feats, format=format, units=units, size=size, step=step, - max_samples=max_samples, samples_random=samples_random, - keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii, - identify_lang = identify_lang) + print(f".......getting features ({feats}, n={n}).......") - else: - myTexts = pipe.load_texts(data_paths, feats=feats, format=format, max_samples=max_samples, keep_punct=keep_punct, - keep_sym=keep_sym, no_ascii=no_ascii, identify_lang=identify_lang) - - print(".......getting features.......") - - if feat_list is None: + if provided_feat_list is None: feat_list = fex.get_feature_list(myTexts, feats=feats, n=n, freqsType=freqsType) if k > len(feat_list): - print("K Limit ignored because the size of the list is lower ({} < {})".format(len(feat_list), k)) + print(f"K limit ignored ({len(feat_list)} < {k})") else: - # and now, cut at around rank k val = 
feat_list[k-1][1] feat_list = [m for m in feat_list if m[1] >= val] - + else: + feat_list = provided_feat_list print(".......getting counts.......") - my_feats = [m[0] for m in feat_list] # keeping only the features without the frequencies - myTexts = fex.get_counts(myTexts, feat_list=my_feats, feats=feats, n=n, freqsType=freqsType) + my_feats = [m[0] for m in feat_list] + # Copy myTexts to avoid mutating original for multi-feature + texts_copy = [dict(t) for t in myTexts] + texts_copy = fex.get_counts(texts_copy, feat_list=my_feats, feats=feats, n=n, freqsType=freqsType) + if embedding: print(".......embedding counts.......") - myTexts, my_feats = embed.get_embedded_counts(myTexts, my_feats, model, topn=neighbouring_size) + texts_copy, my_feats = embed.get_embedded_counts(texts_copy, my_feats, model, topn=neighbouring_size) feat_list = [f for f in feat_list if f[0] in my_feats] - unique_texts = [text["name"] for text in myTexts] - if culling > 0: - print(".......Culling at " + str(culling) + "%.......") - # Counting in how many sample the feat appear - feats_doc_freq = fex.get_doc_frequency(myTexts) - # Now selecting - my_feats = [f for f in my_feats if (feats_doc_freq[f] / len(myTexts) * 100) > culling] + print(f".......Culling at {culling}%.......") + feats_doc_freq = fex.get_doc_frequency(texts_copy) + my_feats = [f for f in my_feats if (feats_doc_freq[f] / len(texts_copy) * 100) > culling] feat_list = [f for f in feat_list if f[0] in my_feats] print(".......feeding data frame.......") loc = {} - - for t in tqdm.tqdm(myTexts): + for t in tqdm.tqdm(texts_copy): text, local_freqs = count_process((t, my_feats), embeddedFreqs=embeddedFreqs) loc[text["name"]] = local_freqs - # Saving metadata for later - metadata = pandas.DataFrame(columns=['author', 'lang'], index=unique_texts, data= - [[t["aut"], t["lang"]] for t in myTexts]) + feats_df = pandas.DataFrame.from_dict(loc, columns=list(my_feats), orient="index") - # Free some space before doing this... 
- del myTexts + # For test sets: return the provided feat_list unchanged + if use_provided_feat_list and provided_feat_list is not None: + return feats_df, provided_feat_list + + return feats_df, feat_list - # frequence based selection - # WOW, pandas is a great tool, almost as good as using R - # But confusing as well: boolean selection works on rows by default - # were elsewhere it works on columns - # take only rows where the number of values above 0 is superior to two - # (i.e. appears in at least two texts) - #feats = feats.loc[:, feats[feats > 0].count() > 2] - feats = pandas.DataFrame.from_dict(loc, columns=list(my_feats), orient="index") +def load_corpus( + config: Optional[Config] = None, + use_provided_feat_list: bool = False, + **kwargs +) -> Tuple[pandas.DataFrame, Union[List, List[List]]]: + """ + Load a corpus and extract features. + + Can be called with: + 1. A Config object: load_corpus(config=my_config) + 2. Individual parameters (backward compatible): + load_corpus(data_paths=paths, feats="chars", n=3) + + Args: + config: Configuration object. If None, built from kwargs. + use_provided_feat_list: If True and feat_list provided, return it unchanged. + Use for test sets to match training features. + **kwargs: Individual parameters for backward compatibility. 
+        Supported: data_paths, feat_list, feats, n, k, freqsType, + format, sampling, units, size, step, max_samples, samples_random, + keep_punct, keep_sym, no_ascii, identify_lang, embedding, + neighbouring_size, culling + + Returns: + - If single feature: (DataFrame, feat_list) + - If multiple features: (DataFrame with prefixed columns, list of feat_lists) + """ + # Build config from kwargs if not provided + if config is None: + config = Config.from_kwargs(**kwargs) + + # Validate configuration + config.validate() + data_paths = config.corpus.paths + + # Handle string paths (single file or glob pattern) + if isinstance(data_paths, str): + import glob + # If it's a glob pattern, expand it + if '*' in data_paths or '?' in data_paths: + data_paths = sorted(glob.glob(data_paths)) + else: + # Single file path - wrap in list + data_paths = [data_paths] + + # Validate + for feat_config in config.features: + if feat_config.type in ('lemma', 'pos', 'met_line', 'met_syll') and config.corpus.format != 'tei': + raise ValueError(f"{feat_config.type} requires TEI format.") + if feat_config.type in ('met_line', 'met_syll') and config.sampling.units != 'verses': + raise ValueError(f"{feat_config.type} requires verses units.") + data_paths = config.corpus.paths + + # Handle string paths (single file or glob pattern) + if isinstance(data_paths, str): + import glob + # If it's a glob pattern, expand it + if '*' in data_paths or '?'
in data_paths: + data_paths = sorted(glob.glob(data_paths)) + else: + # Single file path - wrap in list + data_paths = [data_paths] - # Free some more - del loc + # Validate + for feat_config in config.features: + if feat_config.type in ('lemma', 'pos', 'met_line', 'met_syll') and config.corpus.format != 'tei': + raise ValueError(f"{feat_config.type} requires TEI format.") + if feat_config.type in ('met_line', 'met_syll') and config.sampling.units != 'verses': + raise ValueError(f"{feat_config.type} requires verses units.") - corpus = pandas.concat([metadata, feats], axis=1) + # Load texts once + print(".......loading texts.......") - return corpus, feat_list \ No newline at end of file + if config.sampling.enabled: + myTexts = pipe.docs_to_samples( + data_paths, + config=config + ) + else: + myTexts = pipe.load_texts( + data_paths, + config=config + ) + + unique_texts = [text["name"] for text in myTexts] + + # Build metadata + metadata = pandas.DataFrame( + columns=['author', 'lang'], + index=unique_texts, + data=[[t["aut"], t["lang"]] for t in myTexts] + ) + + # Single feature case + if len(config.features) == 1: + feat_config = config.features[0] + feats_df, feat_list = _load_single_feature( + myTexts, feat_config, config.normalization, use_provided_feat_list + ) + corpus = pandas.concat([metadata, feats_df], axis=1) + return corpus, feat_list + + # Multiple features case + print(f".......extracting {len(config.features)} feature sets.......") + + all_feat_lists = [] + merged_feats = metadata.copy() + + for i, feat_config in enumerate(config.features): + prefix = feat_config.name or f"f{i+1}" + print(f".......processing {prefix}.......") + + feats_df, feat_list = _load_single_feature( + myTexts, feat_config, config.normalization, use_provided_feat_list + ) + + # Prefix columns to avoid collisions + feats_df = feats_df.rename(columns={col: f"{prefix}_{col}" for col in feats_df.columns}) + + merged_feats = pandas.concat([merged_feats, feats_df], axis=1) + 
all_feat_lists.append(feat_list) + + return merged_feats, all_feat_lists \ No newline at end of file diff --git a/superstyl/load_from_config.py b/superstyl/load_from_config.py deleted file mode 100644 index 4bc7cca4..00000000 --- a/superstyl/load_from_config.py +++ /dev/null @@ -1,197 +0,0 @@ -import json -import pandas as pd -import os -import glob - -from superstyl.load import load_corpus - -def load_corpus_from_config(config_path, is_test=False): - """ - Load a corpus based on a JSON configuration file. - - Parameters: - ----------- - config_path : str - Path to the JSON configuration file - - Returns: - -------- - tuple: (corpus, feat_list) - Same format as load_corpus function - If multiple features are defined, returns the merged corpus and the combined feature list - If only one feature is defined, returns that corpus and its feature list - """ - # Load configuration - if not config_path.endswith('.json'): - raise ValueError(f"Unsupported configuration file format: {config_path}. Only JSON format is supported.") - - with open(config_path, 'r') as f: - config = json.load(f) - - # Get corpus paths - - if 'paths' in config: - if isinstance(config['paths'], list): - paths = [] - for path in config['paths']: - if '*' in path or '?' in path or '[' in path: - expanded_paths = glob.glob(path) - if not expanded_paths: - print(f"Warning: No files found for pattern '{path}'") - paths.extend(expanded_paths) - else: - paths.append(path) - elif isinstance(config['paths'], str): - if '*' in config['paths'] or '?' 
in config['paths'] or '[' in config['paths']: - paths = glob.glob(config['paths']) - if not paths: - raise ValueError(f"No files found for glob pattern '{config['paths']}'") - else: - paths = [config['paths']] - else: - raise ValueError("Paths in config must be either a list or a glob pattern string") - else: - raise ValueError("No paths provided and no paths found in config") - - # Get sampling parameters - sampling_params = config.get('sampling', {}) - - # Use the first feature to create the base corpus with sampling - feature_configs = config.get('features', []) - if not feature_configs: - raise ValueError("No features specified in the configuration") - - # If there's only one feature, we can simply return the result of load_corpus - if len(feature_configs) == 1: - feature_config = feature_configs[0] - feature_name = feature_config.get('name', "f1") - - # Check for feature list file - feat_list = None - feat_list_path = feature_config.get('feat_list') - if feat_list_path: - if feat_list_path.endswith('.json'): - with open(feat_list_path, 'r') as f: - feat_list = json.load(f) - elif feat_list_path.endswith('.txt'): - with open(feat_list_path, 'r') as f: - feat_list = [[feat.strip(), 0] for feat in f.readlines()] - - # Set up other parameters - params = { - 'feats': feature_config.get('type', 'words'), - 'n': feature_config.get('n', 1), - 'k': feature_config.get('k', 5000), - 'freqsType': feature_config.get('freq_type', 'relative'), - 'format': config.get('format', 'txt'), - 'sampling': sampling_params.get('enabled', False), - 'units': sampling_params.get('units', 'words'), - 'size': sampling_params.get('sample_size', 3000), - 'step': sampling_params.get('step', None), - 'max_samples': sampling_params.get('max_samples', None), - 'samples_random': sampling_params.get('samples_random', False), - 'keep_punct': feature_config.get('keep_punct', False), - 'keep_sym': feature_config.get('keep_sym', False), - 'no_ascii': feature_config.get('no_ascii', False), - 
'identify_lang': feature_config.get('identify_lang', False), - 'embedding': feature_config.get('embedding', None), - 'neighbouring_size': feature_config.get('neighbouring_size', 10), - 'culling': feature_config.get('culling', 0) - } - - print(f"Loading corpus with {feature_name}...") - corpus, features = load_corpus(paths, feat_list=feat_list, **params) - - return corpus, features - - # For multiple features, we need to process each one and merge the results - corpora = {} - feature_lists = {} - - # Process each feature configuration - for i, feature_config in enumerate(feature_configs): - feature_name = feature_config.get('name', f"f{i+1}") - - # Check for feature list file - feat_list = None - feat_list_path = feature_config.get('feat_list') - print(feat_list_path) - if feat_list_path: - if feat_list_path.endswith('.json'): - with open(feat_list_path, 'r') as f: - feat_list = json.load(f) - elif feat_list_path.endswith('.txt'): - with open(feat_list_path, 'r') as f: - feat_list = [[feat.strip(), 0] for feat in f.readlines()] - - # Set up other parameters - params = { - 'feats': feature_config.get('type', 'words'), - 'n': feature_config.get('n', 1), - 'k': feature_config.get('k', 5000), - 'freqsType': feature_config.get('freq_type', 'relative'), - 'format': config.get('format', 'txt'), - 'sampling': sampling_params.get('enabled', False), - 'units': sampling_params.get('units', 'words'), - 'size': sampling_params.get('sample_size', 3000), - 'step': sampling_params.get('step', None), - 'max_samples': sampling_params.get('max_samples', None), - 'samples_random': sampling_params.get('samples_random', False), - 'keep_punct': config.get('keep_punct', False), - 'keep_sym': config.get('keep_sym', False), - 'no_ascii': config.get('no_ascii', False), - 'identify_lang': config.get('identify_lang', False), - 'embedding': feature_config.get('embedding', None), - 'neighbouring_size': feature_config.get('neighbouring_size', 10), - 'culling': feature_config.get('culling', 0) - } 
- - print(f"Loading {feature_name}...") - - corpus, features = load_corpus(paths, feat_list=feat_list, **params) - - # Store corpus and features - corpora[feature_name] = corpus - - if feat_list is not None and is_test: - feature_lists[feature_name] = feat_list - else: - feature_lists[feature_name] = features - - - # Create a merged dataset - print("Creating merged dataset...") - first_corpus_name = next(iter(corpora)) - - # Start with metadata from the first corpus - metadata = corpora[first_corpus_name][['author', 'lang']] - - # Create an empty DataFrame for the merged corpus - merged = pd.DataFrame(index=metadata.index) - - # Add metadata - merged = pd.concat([metadata, merged], axis=1) - - # Combine all features with prefixes to avoid name collisions - all_features = [] - - # Add features from each corpus - for name, corpus in corpora.items(): - single_feature = [] - - feature_cols = [col for col in corpus.columns if col not in ['author', 'lang']] - - # Rename columns to avoid duplicates - renamed_cols = {col: f"{name}_{col}" for col in feature_cols} - feature_df = corpus[feature_cols].rename(columns=renamed_cols) - - # Merge with the main DataFrame - merged = pd.concat([merged, feature_df], axis=1) - - # Add features to the combined list with prefixes - for feature in feature_lists[name]: - single_feature.append((feature[0], feature[1])) - - all_features.append(single_feature) - # Return the merged corpus and combined feature list - return merged, all_features - diff --git a/superstyl/preproc/pipe.py b/superstyl/preproc/pipe.py index 7c675b68..3812a43e 100755 --- a/superstyl/preproc/pipe.py +++ b/superstyl/preproc/pipe.py @@ -1,145 +1,53 @@ from lxml import etree -import regex as re -import unidecode import nltk.tokenize import random -import langdetect -import unicodedata - -def XML_to_text(path): - """ - Get main text from xml file - :param path: path to the file to transform - :return: a tuple with auts, and string (the text). 
- """ - - myxsl = etree.XML(''' - +from typing import List, Dict, Optional, Tuple +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from superstyl.config import Config, NormalizationConfig, SamplingConfig +from superstyl.preproc.utils import * + + +# ============================================================================ +# Constants and Configuration +# ============================================================================ + +XSLT_TEMPLATES = { + 'xml_text': ''' + + + + + + ''', - - - - - - -''') - myxsl = etree.XSLT(myxsl) - - with open(path, 'r') as f: - my_doc = etree.parse(f) - - auts = my_doc.findall("//author") - auts = [a.text for a in auts] - - if not len(auts) == 1: - print("Error: more or less than one author in" + path) - - if len(auts) == 0: - auts = [None] - - if auts == [None]: - aut = "unknown" - - else: - aut = auts[0] - - return aut, re.sub(r"\s+", " ", str(myxsl(my_doc))) - - -def txm_to_units(path, units="lines", feats="words"): - """ - Extract units from TXM file - :param path: path to TXM file - :param units: units to extract ("lines"/"verses" or "words") - :param feats: features to extract ("words", "lemma", or "pos") - :return: list of extracted units - """ - myxsl = etree.XML(''' - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -''') - myxsl = etree.XSLT(myxsl) - - with open(path, 'r') as f: - my_doc = etree.parse(f) - - units_tokens = str(myxsl(my_doc, units=etree.XSLT.strparam(units), feats=etree.XSLT.strparam(feats))).splitlines() - return units_tokens - -def tei_to_units(path, feats="words", units="lines"): - - if feats in ["met_syll", "met_line"]: - feats = "met" - myxsl = etree.XML(''' - - - - - - - - - - - - - - - - - - - - - + 'tei_units': ''' + + + + + + + + + + + + + + + + + + + + @@ -148,286 +56,419 @@ def tei_to_units(path, feats="words", units="lines"): - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + - - - - ''') - myxsl = etree.XSLT(myxsl) - - 
with open(path, 'r') as f: - my_doc = etree.parse(f) - - units_tokens = str(myxsl(my_doc, units=etree.XSLT.strparam(units), feats=etree.XSLT.strparam(feats))).splitlines() - return units_tokens + + + + + + + + + + + + + + + + + + + + + + + ''', + + 'txm_units': ''' + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ''' +} + + + + +class FileLoader(ABC): + """Abstract base class for file loaders.""" + + @abstractmethod + def load(self, path: str, **kwargs) -> Tuple[str, str]: + """Load file and return (author, text) tuple.""" + pass -def specialXML_to_text(path, format="tei", feats="words"): - aut = path.split('/')[-1].split("_")[0] - if format=="tei": - units_tokens = tei_to_units(path, feats=feats, units="words") - if format=="txm": - units_tokens = txm_to_units(path, feats=feats, units="words") +class TXTLoader(FileLoader): + """Loader for plain text files.""" + + def load(self, path: str, **kwargs) -> Tuple[str, str]: + with open(path, 'r') as f: + text = ' '.join(f.readlines()) + + author = extract_author_from_path(path) + return author, normalize_whitespace(text) - return aut, re.sub(r"\s+", " ", str(' '.join(units_tokens))) -def TXT_to_text(path): - """ - Get main text from xml file - :param path: path to the file to transform - :return: a tuple with auts, and string (the text). 
- """ +class XMLLoader(FileLoader): + """Loader for XML files.""" + + def __init__(self): + self.xslt = etree.XSLT(etree.XML(XSLT_TEMPLATES['xml_text'])) + + def load(self, path: str, **kwargs) -> Tuple[str, str]: + with open(path, 'r') as f: + doc = etree.parse(f) + + authors = doc.findall("//author") + author_texts = [a.text for a in authors] + + if len(author_texts) != 1: + print(f"Warning: Expected 1 author in {path}, found {len(author_texts)}") + author = author_texts[0] if author_texts else "unknown" + else: + author = author_texts[0] + + text = str(self.xslt(doc)) + return author, normalize_whitespace(text) - with open(path, 'r') as f: - #txt = [line.rstrip() for line in f if line.rstrip() != ''] - txt = f.readlines() - # get author from filename (string before first _) - aut = path.split('/')[-1].split("_")[0] +class XMLUnitLoader(ABC): + """Base class for XML loaders that extract units.""" + + def __init__(self, template_name: str): + self.xslt = etree.XSLT(etree.XML(XSLT_TEMPLATES[template_name])) + + def extract_units(self, path: str, units: str = "verses", + feats: str = "words") -> List[str]: + """Extract units from XML file.""" + with open(path, 'r') as f: + doc = etree.parse(f) + + params = self._get_xslt_params(units, feats) + result = str(self.xslt(doc, **params)) + return result.splitlines() + + @abstractmethod + def _get_xslt_params(self, units: str, feats: str) -> Dict: + """Get XSLT parameters for transformation.""" + pass - return aut, re.sub(r"\s+", " ", str(' '.join(txt))) +class TEIUnitLoader(XMLUnitLoader): + """Loader for TEI files with unit extraction.""" + + def __init__(self): + super().__init__('tei_units') + + def _get_xslt_params(self, units: str, feats: str) -> Dict: + feats_param = "met" if feats in ["met_syll", "met_line"] else feats + return { + 'units': etree.XSLT.strparam(units), + 'feats': etree.XSLT.strparam(feats_param) + } + + def load(self, path: str, feats: str = "words", **kwargs) -> Tuple[str, str]: + """Load TEI file 
and return (author, text) tuple.""" + author = extract_author_from_path(path) + units = self.extract_units(path, units="words", feats=feats) + text = normalize_whitespace(' '.join(units)) + return author, text -def detect_lang(string): - """ - Get the language from a string - :param string: a string, duh - :return: the language - """ - return langdetect.detect(string) # , k = 3) +class TXMUnitLoader(XMLUnitLoader): + """Loader for TXM files with unit extraction.""" + + def __init__(self): + super().__init__('txm_units') + + def _get_xslt_params(self, units: str, feats: str) -> Dict: + return { + 'units': etree.XSLT.strparam(units), + 'feats': etree.XSLT.strparam(feats) + } + + def load(self, path: str, feats: str = "words", **kwargs) -> Tuple[str, str]: + """Load TXM file and return (author, text) tuple.""" + author = extract_author_from_path(path) + units = self.extract_units(path, units="words", feats=feats) + text = normalize_whitespace(' '.join(units)) + return author, text -def normalise(text, keep_punct=False, keep_sym=False, no_ascii=False): - """ - Function to normalise an input string. By defaults, it removes all but word chars, remove accents, - and normalise space, and then normalise unicode. 
- :param keep_punct: if true, in addition, also keeps Punctuation and case distinction - :param keep_sym: if true, same as keep_punct, but keeps also N?umbers, Symbols, Marks, such as combining diacritics, - as well as Private use characters, and no Unidecode is applied - :param no_ascii: disables conversion to ascii - """ - # Remove all but word chars, remove accents, and normalise space - # and then normalise unicode +# Loader factory +LOADERS = { + 'txt': TXTLoader(), + 'xml': XMLLoader(), + 'tei': TEIUnitLoader(), + 'txm': TXMUnitLoader() +} - if keep_sym: - out = re.sub(r"[^\p{L}\p{P}\p{N}\p{S}\p{M}\p{Co}]+", " ", text) - else: - if keep_punct: - # Keep punctuation (and diacritics for now) - out = re.sub(r"[^\p{L}\p{P}\p{M}]+", " ", text) +def XML_to_text(path: str) -> Tuple[str, str]: + """Legacy function for XML loading.""" + return LOADERS['xml'].load(path) - else: - #out = re.sub(r"[\W0-9]+", " ", text.lower()) - out = re.sub(r"[^\p{L}\p{M}]+", " ", text.lower()) - if no_ascii is not True: - out = unidecode.unidecode(out) +def TXT_to_text(path: str) -> Tuple[str, str]: + """Legacy function for TXT loading.""" + return LOADERS['txt'].load(path) - # Normalise unicode - out = unicodedata.normalize("NFC", out) - out = re.sub(r"\s+", " ", out).strip() +def tei_to_units(path: str, feats: str = "words", units: str = "verses") -> List[str]: + """Legacy function for TEI unit extraction.""" + return LOADERS['tei'].extract_units(path, units, feats) - return out -def max_sampling(myTexts, max_samples=10): - """ - Select a random number of samples, equal to max_samples, for authors or classes that have more than max_samples - :param myTexts: the input myTexts object - :param max_samples: the maximum number of samples for any class - :return: a myTexts object, with the resulting selection of samples - """ - autsCounts = dict() - for text in myTexts: - if text['aut'] not in autsCounts.keys(): - autsCounts[text['aut']] = 1 +def txm_to_units(path: str, units: str = 
"verses", feats: str = "words") -> List[str]: + """Legacy function for TXM unit extraction.""" + return LOADERS['txm'].extract_units(path, units, feats) - else: - autsCounts[text['aut']] += 1 - for autCount in autsCounts.items(): - if autCount[1] > max_samples: - # get random selection - toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]] - toBeSelected = random.sample(toBeSelected, k=max_samples) - # Great, now remove all texts from this author from our samples - myTexts = [text for text in myTexts if text['aut'] != autCount[0]] - # and now concat - myTexts = myTexts + toBeSelected +def specialXML_to_text(path: str, format: str = "tei", feats: str = "words") -> Tuple[str, str]: + """Legacy function for special XML loading.""" + return LOADERS[format].load(path, feats=feats) - return myTexts +# ============================================================================ +# Sampling Functions +# ============================================================================ -def load_texts(paths, identify_lang=False, feats="words", format="txt", keep_punct=False, keep_sym=False, no_ascii=False, - max_samples=None): +class Sampler: """ - Loads a collection of documents into a 'myTexts' object for further processing. - TODO: a proper class - :param paths: path to docs - :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'. - :param identify_lang: whether or not try to identify lang (default: False) - :param format: format of the source files (implemented values: txt [default], xml) - :param keep_punct: whether or not to keep punctuation and caps. - :param keep_sym: whether or not to keep punctuation, caps, letter variants and numbers (no unidecode). - :param no_ascii: disables conversion to ascii - :param max_samples: the maximum number of samples for any class - :return: a myTexts object + Handles text sampling operations. 
""" - - myTexts = [] - - for path in paths: - name = path.split('/')[-1] - - if format=='xml': - aut, text = XML_to_text(path) - - if format in ('tei', 'txm'): - aut, text = specialXML_to_text(path, format=format, feats=feats) - + + @staticmethod + def extract_tokens(path: str, config: Config=Config()) -> List[str]: + """ + Extract tokens from a document based on format and units. + """ + feats=config.features[0].type + + if config.sampling.units == "words" and config.corpus.format == "txt": + author, text = LOADERS['txt'].load(path) + text = normalise(text, config.normalization) + return nltk.tokenize.wordpunct_tokenize(text) + + elif config.corpus.format == "tei": + return LOADERS['tei'].extract_units(path, config.corpus.units, feats) + + elif config.sampling.units == "verses" and config.corpus.format == "txm": + return LOADERS['txm'].extract_units(path, config.sampling.units, feats) + else: - aut, text = TXT_to_text(path) - - if identify_lang: - lang = detect_lang(text) + raise ValueError(f"Unsupported combination: units={config.sampling.units}, format={config.corpus.format}") + + @staticmethod + def create_samples(tokens: List[str], sampling_config: SamplingConfig=SamplingConfig()) -> List[Dict]: + """ + Create samples from tokens. 
+ """ + step = sampling_config.step if sampling_config.step is not None else sampling_config.size + + samples = [] + + if sampling_config.random: + for k in range(sampling_config.max_samples): + samples.append({ + "start": f"{k}s", + "end": f"{k}e", + "text": list(random.choices(tokens, k=sampling_config.size)) + }) else: - lang = "NA" - - # Normalise text once and for all - text = normalise(text, keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii) - - myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang}) - - if max_samples is not None: - myTexts = max_sampling(myTexts, max_samples=max_samples) - - return myTexts - - -# Load and split in samples of length -n- a collection of files -def get_samples(path, size, step=None, samples_random=False, max_samples=10, - units="words", format="txt", feats="words", keep_punct=False, keep_sym=False, no_ascii=False): + current = 0 + while current + sampling_config.size <= len(tokens): + samples.append({ + "start": current, + "end": current + sampling_config.size, + "text": list(tokens[current:current + sampling_config.size]) + }) + current += step + + return samples + + @classmethod + def get_samples(cls, path: str, config: Config=Config()) -> List[Dict]: + """ + Extract samples from a document. + + Args: + path: Path to document + config: Config file + + Returns: + List of sample dictionaries + """ + max_samples = config.sampling.max_samples or 10 + config.sampling.validate() + + tokens = cls.extract_tokens(path, config) + return cls.create_samples(tokens, config.sampling) + + +def max_sampling(documents: List[Dict], max_samples: int = 10) -> List[Dict]: """ - Take samples of n words or verses from a document, and then parse it. - TODO: ONLY IMPLEMENTED FOR NOW: XML/TEI, TXT and verses or words as units - :param path : path to file - :param size: sample size - :param step: size of the step when sampling successively (determines overlap) default is the same - as sample size (i.e. 
no overlap) - :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false) - :param max_samples: maximum number of samples per author/clas - :param units: the units to use, one of "words" or "verses" - :param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED) - :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'. + Randomly select up to max_samples per author/class. + + Args: + documents: List of text dict + max_samples: Maximum samples per author + + Returns: + Filtered list of documents """ + # Count documents per author + author_counts = {} + for doc in documents: + author_counts[doc['aut']] = author_counts.get(doc['aut'], 0) + 1 + + # Filter authors with too many samples + result = [] + for author, count in author_counts.items(): + author_docs = [d for d in documents if d['aut'] == author] - if samples_random and step is not None: - raise ValueError("random sampling is not compatible with continuous sampling (remove either the step or the samples_random argument") - - if samples_random and not max_samples: - raise ValueError("random sampling needs a fixed number of samples (use the max_samples argument)") - - if step is None: - step = size - - if units == "words" and format == "txt": - my_doc = TXT_to_text(path) - text = normalise(my_doc[1], keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii) - units_tokens = nltk.tokenize.wordpunct_tokenize(text) - - #Kept only for retrocompatibility with Psysché - if units == "verses" and format == "txm": - units_tokens = txm_to_units(path, units=units) - - if format == "tei": - units_tokens = tei_to_units(path, units=units, feats=feats) + if count > max_samples: + result.extend(random.sample(author_docs, k=max_samples)) + else: + result.extend(author_docs) + + return result - # and now generating output - samples = [] - if samples_random: - for k in 
range(max_samples): - samples.append({"start": str(k)+'s', "end": str(k)+'e', "text": list(random.choices(units_tokens, k=size))}) +# ============================================================================ +# Main Loading Functions +# ============================================================================ - else: - current = 0 - while current + size <= len(units_tokens): - samples.append({"start": current, "end": current + size, "text": list(units_tokens[current:(current + size)])}) - current = current + step +def load_texts(paths: List[str], config: Config=Config()) -> List[Dict]: + """ + Load a collection of documents. + + Args: + paths: List of file paths + config: Config file + + Returns: + List of document dictionaries + """ + loader = LOADERS.get(config.corpus.format) + if not loader: + raise ValueError(f"Unsupported format: {config.corpus.format}") + + documents = [] + feats=config.features[0].type - return samples + for path in paths: + name = path.split('/')[-1] + author, text = loader.load(path, feats=feats) + + lang = detect_lang(text) if config.corpus.identify_lang else "NA" + + # Normalize text + text = normalise(text, config.normalization) + + documents.append({ + "name": name, + "aut": author, + "text": text, + "lang": lang + }) + + if config.sampling.max_samples is not None: + documents = max_sampling(documents, config.sampling.max_samples) + + return documents -def docs_to_samples(paths, size, step=None, units="words", samples_random=False, format="txt", feats="words", keep_punct=False, - keep_sym=False, no_ascii=False, max_samples=None, identify_lang=False): +def docs_to_samples(paths: List[str], config: Config=Config()) -> List[Dict]: """ - Loads a collection of documents into a 'myTexts' object for further processing BUT with samples ! - :param paths: path to docs - :param size: sample size - :param step: size of the step when sampling successively (determines overlap) default is the same - as sample size (i.e. 
no overlap) - :param units: the units to use, one of "words" or "verses" - :param samples_random: Should random sampling with replacement be performed instead of continuous sampling (default: false) - :param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED) - :param keep_punct: whether to keep punctuation and caps. - :param max_samples: maximum number of samples per author/class. - :param identify_lang: whether to try to identify lang (default: False) - :param feats: the type of features, one of 'words', 'chars', 'affixes, 'lemma', 'pos', 'met_line' and 'met_syll'. - :return: a myTexts object + Load documents with sampling. + + Args: + paths: List of file paths + config: Config file + + Returns: + List of sample dictionaries """ - myTexts = [] - for path in paths: - aut = path.split('/')[-1].split('_')[0] - if identify_lang: - if format == 'xml': - aut, text = XML_to_text(path) - - else: - aut, text = TXT_to_text(path) + loader = LOADERS.get(config.corpus.format) + if not loader: + raise ValueError(f"Unsupported format: {config.corpus.format}") + + all_samples = [] + feats=config.features[0].type + for path in paths: + author = extract_author_from_path(path) + + # Detect language if needed + if config.corpus.identify_lang: + _, text = loader.load(path, feats=feats) lang = detect_lang(text) - else: lang = 'NA' - - samples = get_samples(path, size=size, step=step, samples_random=samples_random, max_samples=max_samples, - units=units, format=format, feats=feats, - keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii) - + + # Get samples + samples = Sampler.get_samples(path, config) + + # Create sample documents for sample in samples: - name = path.split('/')[-1] + '_' + str(sample["start"]) + "-" + str(sample["end"]) - text = normalise(' '.join(sample["text"]), keep_punct=keep_punct, keep_sym=keep_sym, no_ascii=no_ascii) - myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang}) - - if max_samples is not 
None: - myTexts = max_sampling(myTexts, max_samples=max_samples) - - return myTexts + name = f"{path.split('/')[-1]}_{sample['start']}-{sample['end']}" + text = normalise(' '.join(sample['text']), config.normalization) + + all_samples.append({ + "name": name, + "aut": author, + "text": text, + "lang": lang + }) + + if config.sampling.max_samples is not None: + all_samples = max_sampling(all_samples, config.sampling.max_samples) + + return all_samples \ No newline at end of file diff --git a/superstyl/preproc/select.py b/superstyl/preproc/select.py index d21efb47..fb887b5a 100755 --- a/superstyl/preproc/select.py +++ b/superstyl/preproc/select.py @@ -2,151 +2,135 @@ import csv import random import json +from typing import Optional, List, Tuple -# TODO: make same modifs for the no split -def read_clean_split(path, metadata_path=None, excludes_path=None, savesplit=None, lang=None): +def _load_metadata(path: str, metadata_path: Optional[str], + excludes_path: Optional[str], lang: Optional[str]) -> Tuple[Optional[pandas.DataFrame], Optional[List[str]]]: """ - Function to read a csv, clean it, and then split it in train and dev, - either randomly or according to a preexisting selection - :param path: path to csv file - :param metadata_path: path to metadata file - :param excludes_path: path to file with list of excludes - :param presplit: path to file with preexisting split (optional) - :param savesplit: path to save split (optional) - :return: saves to disk + Load metadata and exclusion list if needed. """ - - trainf = open(path.split(".")[0] + "_train.csv", 'w') - validf = open(path.split(".")[0] + "_valid.csv", 'w') - - selection = {'train': [], 'valid': [], 'elim': []} - - # Do we need to create metadata ? 
+ metadata = None + excludes = None + if metadata_path is None and (excludes_path is not None or lang is not None): - metadata = pandas.read_csv(path) - metadata = pandas.DataFrame(index=metadata.loc[:, "Unnamed: 0"], columns=['lang'], data=list(metadata.loc[:, "lang"])) - - if metadata_path is not None: - metadata = pandas.read_csv(metadata_path) - metadata = pandas.DataFrame(index=metadata.loc[:, "id"], columns=['lang'], data=list(metadata.loc[:, "true"])) - + data = pandas.read_csv(path) + metadata = pandas.DataFrame( + index=data.loc[:, "Unnamed: 0"], + columns=['lang'], + data=list(data.loc[:, "lang"]) + ) + elif metadata_path is not None: + data = pandas.read_csv(metadata_path) + metadata = pandas.DataFrame( + index=data.loc[:, "id"], + columns=['lang'], + data=list(data.loc[:, "true"]) + ) + if excludes_path is not None: - excludes = pandas.read_csv(excludes_path) - excludes = list(excludes.iloc[:, 0]) - - with open(path, "r") as f: - head = f.readline() - trainf.write(head) - validf.write(head) - - # and prepare to write csv lines to them - train = csv.writer(trainf) - valid = csv.writer(validf) - - print("....evaluating each text.....") - - reader = csv.reader(f, delimiter=",") - - for line in reader: - - # checks - if lang is not None: - # First check if good language - if not metadata.loc[line[0], "lang"] == lang: - selection['elim'].append(line[0]) - print("not in: " + lang + " " + line[0]) - # if not, eliminate it, and go to next line - continue - - if excludes_path is not None: - # then check if to exclude - if line[0] in excludes: - selection['elim'].append(line[0]) - print("Is a Wilhelmus instance! 
: " + line[0]) - # then eliminate it, and go to next line - continue - - # Now that we have only the good lines, proceed to split - - # 10% for dev - if random.randint(1, 10) == 1: - selection['valid'].append(line[0]) - valid.writerow(line) - - # 90% for train - else: - selection['train'].append(line[0]) - train.writerow(line) + excludes_data = pandas.read_csv(excludes_path) + excludes = list(excludes_data.iloc[:, 0]) + + return metadata, excludes - trainf.close() - validf.close() - with open(savesplit, "w") as out: - out.write(json.dumps(selection)) -# TODO: merge this one and the previous ? -def read_clean(path, metadata_path=None, excludes_path=None, savesplit=None, lang=None): +def _should_exclude(line_id: str, metadata: Optional[pandas.DataFrame], + excludes: Optional[List[str]], lang: Optional[str]) -> Tuple[bool, Optional[str]]: """ - Function to read a csv, clean it. - :param path: path to csv file - :param metadata_path: path to metadata file - :param excludes_path: path to file with list of excludes + Determine if a line should be excluded. + """ + if lang is not None and metadata is not None: + try: + if metadata.loc[line_id, "lang"] != lang: + return True, f"not in: {lang} {line_id}" + except KeyError: + pass + + if excludes is not None and line_id in excludes: + return True, f"Is a Wilhelmus instance! : {line_id}" + + return False, None + + +def read_clean(path: str, metadata_path: Optional[str] = None, + excludes_path: Optional[str] = None, + savesplit: Optional[str] = None, + lang: Optional[str] = None, + split: bool = False, + split_ratio: float = 0.1) -> None: + """ + Read a CSV, clean it, and optionally split it into train and validation sets. 
+ + :param path: path to CSV file + :param metadata_path: path to metadata file (optional) + :param excludes_path: path to file with list of excludes (optional) + :param savesplit: path to save selection JSON (optional) + :param lang: only include texts in this language (optional) + :param split: if True, split into train/valid sets (default False) + :param split_ratio: ratio for validation set when split=True (default 0.1 = 10%) :return: saves to disk """ - - trainf = open(path.split(".")[0] + "_selected.csv", 'w') - + metadata, excludes = _load_metadata(path, metadata_path, excludes_path, lang) + + base_path = path.rsplit(".", 1)[0] + + # Initialize selection tracking selection = {'train': [], 'elim': []} - - # Do we need to create metadata ? - if metadata_path is None and (excludes_path is not None or lang is not None): - metadata = pandas.read_csv(path) - metadata = pandas.DataFrame(index=metadata.loc[:, "Unnamed: 0"], columns=['lang'], data=list(metadata.loc[:, "lang"])) - - if metadata_path is not None: - metadata = pandas.read_csv(metadata_path) - metadata = pandas.DataFrame(index=metadata.loc[:, "id"], columns=['lang'], data=list(metadata.loc[:, "true"])) - - if excludes_path is not None: - excludes = pandas.read_csv(excludes_path) - excludes = list(excludes.iloc[:, 0]) - - with open(path, "r") as f: - head = f.readline() - trainf.write(head) - - # and prepare to write csv lines to them - train = csv.writer(trainf) - - print("....evaluating each text.....") - - reader = csv.reader(f, delimiter=",") - - for line in reader: - - # checks - if lang is not None: - # First check if good language - if not metadata.loc[line[0], "lang"] == lang: - selection['elim'].append(line[0]) - print("not in: " + lang + " " + line[0]) - # if not, eliminate it, and go to next line - continue - - if excludes_path is not None: - # then check if to exclude - if line[0] in excludes: - selection['elim'].append(line[0]) - print("Is a Wilhelmus instance! 
: " + line[0]) - # then eliminate it, and go to next line + if split: + selection['valid'] = [] + + # Open output files + if split: + train_path = f"{base_path}_train.csv" + valid_path = f"{base_path}_valid.csv" + trainf = open(train_path, 'w') + validf = open(valid_path, 'w') + else: + train_path = f"{base_path}_selected.csv" + trainf = open(train_path, 'w') + validf = None + + try: + with open(path, "r") as f: + header = f.readline() + trainf.write(header) + if validf: + validf.write(header) + + train_writer = csv.writer(trainf) + valid_writer = csv.writer(validf) if validf else None + + print("....evaluating each text.....") + reader = csv.reader(f, delimiter=",") + + for line in reader: + line_id = line[0] + + # Check exclusion + is_excluded, reason = _should_exclude(line_id, metadata, excludes, lang) + if is_excluded: + selection['elim'].append(line_id) + if reason: + print(reason) continue - - # Now that we have only the good lines, proceed to write - selection['train'].append(line[0]) - train.writerow(line) - - trainf.close() - with open(savesplit, "w") as out: - out.write(json.dumps(selection)) + + # Route to train or valid + if split and random.random() < split_ratio: + selection['valid'].append(line_id) + valid_writer.writerow(line) + else: + selection['train'].append(line_id) + train_writer.writerow(line) + + finally: + trainf.close() + if validf: + validf.close() + + # Save selection + if savesplit: + with open(savesplit, "w") as out: + out.write(json.dumps(selection)) def apply_selection(path, presplit_path): """ diff --git a/superstyl/preproc/utils.py b/superstyl/preproc/utils.py new file mode 100644 index 00000000..5de411d1 --- /dev/null +++ b/superstyl/preproc/utils.py @@ -0,0 +1,50 @@ +import unidecode +import langdetect +import unicodedata +import regex as re + +from superstyl.config import NormalizationConfig + + + +def extract_author_from_path(path: str) -> str: + """Extract author from file path (before first underscore).""" + return 
path.split('/')[-1].split("_")[0] + + +def normalize_whitespace(text: str) -> str: + """Normalize whitespace in text.""" + return re.sub(r"\s+", " ", text).strip() + + +def detect_lang(text: str) -> str: + """Detect language of text using langdetect.""" + return langdetect.detect(text) + + +def normalise(text: str, norm_config : NormalizationConfig=NormalizationConfig()) -> str: + """ + Normalize input text according to specified options. + + Args: + text: Input text to normalize + keep_punct: Keep punctuation and case distinction + keep_sym: Keep punctuation, case, numbers, symbols, marks + no_ascii: Disable conversion to ASCII + + Returns: + Normalized text + """ + if norm_config.keep_sym: + out = re.sub(r"[^\p{L}\p{P}\p{N}\p{S}\p{M}\p{Co}]+", " ", text) + else: + if norm_config.keep_punct: + out = re.sub(r"[^\p{L}\p{P}\p{M}]+", " ", text) + else: + out = re.sub(r"[^\p{L}\p{M}]+", " ", text.lower()) + + if not norm_config.no_ascii: + out = unidecode.unidecode(out) + + out = unicodedata.normalize("NFC", out) + return normalize_whitespace(out) \ No newline at end of file diff --git a/superstyl/svm.py b/superstyl/svm.py index abceb823..74ce13d5 100755 --- a/superstyl/svm.py +++ b/superstyl/svm.py @@ -12,45 +12,53 @@ import imblearn.combine as comb import imblearn.pipeline as imbp from collections import Counter +from typing import Optional, Dict, Any +from superstyl.config import Config -def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, balance=None, class_weights=False, - kernel="LinearSVC", - final_pred=False, get_coefs=False): + +def train_svm( + train: pandas.DataFrame, + test: Optional[pandas.DataFrame] = None, + config: Optional[Config] = None, + **kwargs +) -> Dict[str, Any]: """ - Function to train svm - :param train: train data... 
(in panda dataframe) - :param test: test data (itou) - :param cross_validate: whether to perform cross validation (possible values: leave-one-out, k-fold - and group-k-fold) if group_k-fold is chosen, each source file will be considered a group, so this is only relevant - if sampling was performed and more than one file per class was provided - :param k: k parameter for k-fold cross validation - :param dim_reduc: dimensionality reduction of input data. Implemented values are pca and som. - :param norms: perform normalisations, i.e. z-scores and L2 (default True) - :param balance: up/downsampling strategy to use in imbalanced datasets - :param class_weights: adjust class weights to balance imbalanced datasets, with weights inversely proportional to class - frequencies in the input data as n_samples / (n_classes * np.bincount(y)) - :param kernel: kernel for SVM - :param final_pred: do the final predictions? - :param get_coefs, if true, writes to disk (coefficients.csv) and plots the most important coefficients for each class - :return: prints the scores, and then returns a dictionary containing the pipeline with a fitted svm model, - and, if computed, the classification_report, confusion_matrix, list of misattributions, and final_predictions. + Train SVM model for stylometric analysis. + + Can be called with: + 1. A Config object: train_svm(train, test, config=my_config) + 2. Individual parameters (backward compatible): + train_svm(train, test, cross_validate="k-fold", k=10) + + Args: + train: Training data (pandas DataFrame) + test: Test data (optional) + config: Configuration object. If None, built from kwargs. + **kwargs: Individual parameters for backward compatibility. 
+ Supported: cross_validate, k, dim_reduc, norms, balance, + class_weights, kernel, final_pred, get_coefs + + Returns: + Dictionary containing: pipeline, and optionally confusion_matrix, + classification_report, misattributions, final_predictions, coefficients """ - results = {} + # Build config from kwargs if not provided + if config is None: + config = Config.from_kwargs(**kwargs) + + # Extract SVM parameters from config + cross_validate = config.svm.cross_validate + k = config.svm.k + dim_reduc = config.svm.dim_reduc + norms = config.svm.norms + balance = config.svm.balance + class_weights = config.svm.class_weights + kernel = config.svm.kernel + final_pred = config.svm.final_pred + get_coefs = config.svm.get_coefs - valid_cross_validate_options = {None, "leave-one-out", "k-fold", 'group-k-fold'} - valid_dim_reduc_options = {None, 'pca'} - valid_balance_options = {None, 'downsampling', 'upsampling', 'Tomek', 'SMOTE', 'SMOTETomek'} - # Validate parameters - if cross_validate not in valid_cross_validate_options: - raise ValueError( - f"Invalid cross-validation option: '{cross_validate}'. Valid options are {valid_cross_validate_options}.") - if dim_reduc not in valid_dim_reduc_options: - raise ValueError( - f"Invalid dimensionality reduction option: '{dim_reduc}'. Valid options are {valid_dim_reduc_options}.") - # Validate 'balance' parameter - if balance not in valid_balance_options: - raise ValueError(f"Invalid balance option: '{balance}'. Valid options are {valid_balance_options}.") + results = {} print(".......... Formatting data ........") # Save the classes @@ -72,28 +80,19 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, if dim_reduc == 'pca': print(".......... using PCA ........") - estimators.append(('dim_reduc', decomp.PCA())) # chosen with default - # which is: n_components = min(n_samples, n_features) + estimators.append(('dim_reduc', decomp.PCA())) + if norms: - # Z-scores print(".......... 
using normalisations ........") estimators.append(('scaler', preproc.StandardScaler())) - # NB: j'utilise le built-in - # normalisation L2 - # cf. https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Normalizer.html#sklearn.preprocessing.Normalizer estimators.append(('normalizer', preproc.Normalizer())) if balance is not None: - print(".......... implementing strategy to solve imbalance in data ........") if balance == 'downsampling': estimators.append(('sampling', under.RandomUnderSampler(random_state=42, replacement=False))) - # if balance == 'ENN': - # enn = under.EditedNearestNeighbours() - # train, classes = enn.fit_resample(train, classes) - if balance == 'Tomek': estimators.append(('sampling', under.TomekLinks())) @@ -109,42 +108,35 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, # In case we have to temper with the n_neighbors, we print a warning message to the user # (might be written more clearly, but we want a short message, right?) if 0 < n_neighbors >= min_class_size: - print( - f"Warning: Adjusting n_neighbors for SMOTE to {n_neighbors} due to small class size.") + print(f"Warning: Adjusting n_neighbors for SMOTE to {n_neighbors} due to small class size.") if n_neighbors == 0: - print( - f"Warning: at least one class only has a single individual; cannot apply SMOTE(Tomek) due to small class size.") + print("Warning: at least one class only has a single individual; cannot apply SMOTE(Tomek).") else: - if balance == 'SMOTE': estimators.append(('sampling', over.SMOTE(k_neighbors=n_neighbors, random_state=42))) - elif balance == 'SMOTETomek': - estimators.append(('sampling', comb.SMOTETomek(random_state=42, smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42)))) + estimators.append(('sampling', comb.SMOTETomek( + random_state=42, + smote=over.SMOTE(k_neighbors=n_neighbors, random_state=42) + ))) print(".......... 
choosing SVM ........") if kernel == "LinearSVC": - # try a faster one estimators.append(('model', sk.LinearSVC(class_weight=cw, dual="auto"))) - # classif = sk.LinearSVC() - else: estimators.append(('model', sk.SVC(kernel=kernel, class_weight=cw))) - # classif = sk.SVC(kernel=kernel) print(".......... Creating pipeline with steps ........") print(estimators) if 'sampling' in [k[0] for k in estimators]: pipe = imbp.Pipeline(estimators) - else: pipe = skp.Pipeline(estimators) - # Now, doing leave one out validation or training single SVM with train / test split - + # Cross validation or train/test split if cross_validate is not None: works = None if cross_validate == 'leave-one-out': @@ -152,34 +144,28 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, if cross_validate == 'k-fold': if k == 0: - k = 10 # set default - + k = 10 myCV = skmodel.KFold(n_splits=k) if cross_validate == 'group-k-fold': - # Get the groups as the different source texts works = ["_".join(t.split("_")[:-1]) for t in train.index.values] if k == 0: k = len(set(works)) - myCV = skmodel.GroupKFold(n_splits=k) - print(".......... " + cross_validate + " cross validation will be performed ........") - print(".......... using " + str(myCV.get_n_splits(train)) + " samples or groups........") - - # Will need to - # 1. train a model - # 2. get prediction - # 3. compute score: precision, recall, F1 for all categories + print(f".......... {cross_validate} cross validation will be performed ........") + print(f".......... 
using {myCV.get_n_splits(train)} samples or groups........") preds = skmodel.cross_val_predict(pipe, train, classes, cv=myCV, verbose=1, n_jobs=-1, groups=works) # and now, leave one out evaluation (very small redundancy here, one line that could be stored elsewhere) unique_labels = list(set(classes)) - results["confusion_matrix"] = pandas.DataFrame(metrics.confusion_matrix(classes, preds, labels=unique_labels), - index=['true:{:}'.format(x) for x in unique_labels], - columns=['pred:{:}'.format(x) for x in unique_labels]) + results["confusion_matrix"] = pandas.DataFrame( + metrics.confusion_matrix(classes, preds, labels=unique_labels), + index=[f'true:{x}' for x in unique_labels], + columns=[f'pred:{x}' for x in unique_labels] + ) report = metrics.classification_report(classes, preds) print(report) @@ -203,14 +189,16 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, else: pipe.fit(train, classes) preds = pipe.predict(test) + if not final_pred: # and evaluate unique_labels = list(set(classes + classes_test)) results["confusion_matrix"] = pandas.DataFrame( metrics.confusion_matrix(classes_test, preds, labels=unique_labels), - index=['true:{:}'.format(x) for x in unique_labels], - columns=['pred:{:}'.format(x) for x in unique_labels]) + index=[f'true:{x}' for x in unique_labels], + columns=[f'pred:{x}' for x in unique_labels] + ) report = metrics.classification_report(classes, preds) print(report) @@ -222,45 +210,50 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, columns=["id", "True", "Pred"] ).set_index('id') - # AND NOW, we need to evaluate or create the final predictions + # Final predictions if final_pred: # Get the decision function too myclasses = pipe.classes_ decs = pipe.decision_function(test) dists = {} + if len(pipe.classes_) == 2: results["final_predictions"] = pandas.DataFrame( - data={**{'filename': preds_index, 'author': list(preds)}, 'Decision function': decs}) - + 
data={**{'filename': preds_index, 'author': list(preds)}, 'Decision function': decs} + ) else: for myclass in enumerate(myclasses): dists[myclass[1]] = [d[myclass[0]] for d in decs] - results["final_predictions"] = pandas.DataFrame( - data={**{'filename': preds_index, 'author': list(preds)}, **dists}) + results["final_predictions"] = pandas.DataFrame( + data={**{'filename': preds_index, 'author': list(preds)}, **dists} + ) if get_coefs: if kernel != "LinearSVC": print(".......... COEFS ARE ONLY IMPLEMENTED FOR linearSVC ........") - else: # For “one-vs-rest” LinearSVC the attributes coef_ and intercept_ have the shape (n_classes, n_features) and # (n_classes,) respectively. # Each row of the coefficients corresponds to one of the n_classes “one-vs-rest” classifiers and similar for the # intercepts, in the order of the “one” class. if len(pipe.classes_) == 2: - results["coefficients"] = pandas.DataFrame(pipe.named_steps['model'].coef_, - index=[pipe.classes_[0]], - columns=train.columns) - - plot_coefficients(pipe.named_steps['model'].coef_[0], train.columns, - pipe.classes_[0] + " versus " + pipe.classes_[1]) - + results["coefficients"] = pandas.DataFrame( + pipe.named_steps['model'].coef_, + index=[pipe.classes_[0]], + columns=train.columns + ) + plot_coefficients( + pipe.named_steps['model'].coef_[0], + train.columns, + f"{pipe.classes_[0]} versus {pipe.classes_[1]}" + ) else: - results["coefficients"] = pandas.DataFrame(pipe.named_steps['model'].coef_, - index=pipe.classes_, - columns=train.columns) - + results["coefficients"] = pandas.DataFrame( + pipe.named_steps['model'].coef_, + index=pipe.classes_, + columns=train.columns + ) for i in range(len(pipe.classes_)): plot_coefficients(pipe.named_steps['model'].coef_[i], train.columns, pipe.classes_[i]) @@ -272,23 +265,28 @@ def train_svm(train, test, cross_validate=None, k=0, dim_reduc=None, norms=True, # 
https://aneesha.medium.com/visualising-top-features-in-linear-svm-with-scikit-learn-and-matplotlib-3454ab18a14d def plot_coefficients(coefs, feature_names, current_class, top_features=10): - plt.rcParams.update({'font.size': 30}) # increase font size + """Plot the most important coefficients for a class.""" + plt.rcParams.update({'font.size': 30}) top_positive_coefficients = np.argsort(coefs)[-top_features:] top_negative_coefficients = np.argsort(coefs)[:top_features] top_coefficients = np.hstack([top_negative_coefficients, top_positive_coefficients]) - # create plot + plt.figure(figsize=(15, 8)) colors = ['red' if c < 0 else 'blue' for c in coefs[top_coefficients]] plt.bar(np.arange(2 * top_features), coefs[top_coefficients], color=colors) feature_names = np.array(feature_names) - plt.xticks(np.arange(0, 2 * top_features), feature_names[top_coefficients], rotation=60, ha='right', - rotation_mode='anchor') - plt.title("Coefficients for " + current_class) - plt.savefig('coefs_' + current_class + '.png', bbox_inches='tight') - - - -def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)"): + plt.xticks( + np.arange(0, 2 * top_features), + feature_names[top_coefficients], + rotation=60, + ha='right', + rotation_mode='anchor' + ) + plt.title(f"Coefficients for {current_class}") + plt.savefig(f'coefs_{current_class}.png', bbox_inches='tight') + + +def plot_rolling(final_predictions, smoothing=3, xlab="Index (segment center)"): """ Plots the rolling stylometry results as lines of decision function values over the text. 
@@ -306,6 +304,7 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)" # Extract the segment center from the filename my_final_predictions = final_predictions.copy() # to avoid modifying in place segment_centers = [] + for fname in my_final_predictions['filename']: parts = fname.split('_')[-1].split('-') start = int(parts[0]) @@ -314,7 +313,6 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)" segment_centers.append(center) my_final_predictions['segment_center'] = segment_centers - my_final_predictions['filename'] = [fname.split('_')[1] for fname in my_final_predictions['filename']] # Identify candidate columns @@ -329,21 +327,23 @@ def plot_rolling(final_predictions, smoothing=3, xlab = "Index (segment center)" # Apply smoothing if requested if smoothing and smoothing > 0: for col in candidate_cols: - fpreds_work[col] = fpreds_work[col].rolling(window=smoothing, center=True, min_periods=1).mean() + fpreds_work[col] = fpreds_work[col].rolling( + window=smoothing, center=True, min_periods=1 + ).mean() # Plotting plt.figure(figsize=(24, 12)) for col in candidate_cols: plt.plot(fpreds_work['segment_center'], fpreds_work[col], label=col, linewidth=2) - plt.title('Rolling Stylometry Decision Functions Over ' + work) + plt.title(f'Rolling Stylometry Decision Functions Over {work}') plt.xlabel(xlab) plt.ylabel('Decision Function Value') - plt.ylim(min(-2, min(fpreds_work[candidate_cols].min()) - 0.2), - max(1, max(fpreds_work[candidate_cols].max())) + 0.2) + plt.ylim( + min(-2, min(fpreds_work[candidate_cols].min()) - 0.2), + max(1, max(fpreds_work[candidate_cols].max())) + 0.2 + ) plt.legend(title='Candidate Authors', fontsize="small") plt.grid(True) plt.tight_layout() - #plt.show() - plt.savefig('rolling_'+ work + '.png', bbox_inches='tight') - + plt.savefig(f'rolling_{work}.png', bbox_inches='tight') \ No newline at end of file diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 
cc37116a..1756c0b3 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -1,7 +1,8 @@ import unittest import superstyl.load import superstyl.preproc.features_extract -from superstyl.load_from_config import load_corpus_from_config +from superstyl.load import load_corpus +from superstyl.config import Config import os import tempfile import json @@ -33,7 +34,7 @@ def test_load_corpus_lemma_requires_tei(self): # WHEN/THEN: Should raise ValueError with self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - self.test_paths, + data_paths=self.test_paths, feats="lemma", format="txt" ) @@ -48,7 +49,7 @@ def test_load_corpus_pos_requires_tei(self): # WHEN/THEN: Should raise ValueError with self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - self.test_paths, + data_paths=self.test_paths, feats="pos", format="txt" ) @@ -63,7 +64,7 @@ def test_load_corpus_met_line_requires_tei(self): # WHEN/THEN: Should raise ValueError with self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - self.test_paths, + data_paths=self.test_paths, feats="met_line", format="txt" ) @@ -78,7 +79,7 @@ def test_load_corpus_met_syll_requires_tei(self): # WHEN/THEN: Should raise ValueError with self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - self.test_paths, + data_paths=self.test_paths, feats="met_syll", format="txt" ) @@ -86,8 +87,8 @@ def test_load_corpus_met_syll_requires_tei(self): self.assertIn("met_syll", str(context.exception)) self.assertIn("tei", str(context.exception).lower()) - def test_load_corpus_met_line_requires_lines_unit(self): - # SCENARIO: met_line requires units='lines' + def test_load_corpus_met_line_requires_verses_unit(self): + # SCENARIO: met_line requires units='verses' # GIVEN: Attempting to use met_line with units='words' # Create a dummy TEI file for this test @@ -98,17 +99,18 @@ def test_load_corpus_met_line_requires_lines_unit(self): # WHEN/THEN: Should raise ValueError with 
self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - [tei_path], + data_paths=[tei_path], feats="met_line", format="tei", + sampling=True, units="words" # Wrong unit type ) self.assertIn("met_line", str(context.exception)) - self.assertIn("lines", str(context.exception)) + self.assertIn("verses", str(context.exception)) - def test_load_corpus_met_syll_requires_lines_unit(self): - # SCENARIO: met_syll requires units='lines' + def test_load_corpus_met_syll_requires_verses_unit(self): + # SCENARIO: met_syll requires units='verses' # GIVEN: Attempting to use met_syll with units='words' # Create a dummy TEI file for this test @@ -119,14 +121,15 @@ def test_load_corpus_met_syll_requires_lines_unit(self): # WHEN/THEN: Should raise ValueError with self.assertRaises(ValueError) as context: superstyl.load.load_corpus( - [tei_path], + data_paths=[tei_path], feats="met_syll", format="tei", + sampling=True, units="words" # Wrong unit type ) self.assertIn("met_syll", str(context.exception)) - self.assertIn("lines", str(context.exception)) + self.assertIn("verses", str(context.exception)) # ========================================================================= # Tests pour features_extract.py - ValueError pour paramètres invalides @@ -248,14 +251,16 @@ def test_load_from_config_with_json_feature_list(self): # Create config config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt" + }, "features": [ { "name": "test_feature", "type": "words", "n": 1, - "feat_list": feature_list_path # JSON feature list + "feat_list_path": feature_list_path # JSON feature list path } ] } @@ -265,7 +270,8 @@ def test_load_from_config_with_json_feature_list(self): json.dump(config, f) # WHEN: Loading corpus from config - corpus, features = load_corpus_from_config(config_path) + config = Config.from_json(config_path) + corpus, features = load_corpus(config=config) # THEN: Should load successfully with JSON feature list 
self.assertIsNotNone(corpus) @@ -283,20 +289,22 @@ def test_load_from_config_test_mode_uses_feat_list(self): # Create config with multiple features (triggers is_test logic) config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt" + }, "features": [ { "name": "feat1", "type": "words", "n": 1, - "feat_list": feature_list_path + "feat_list_path": feature_list_path }, { "name": "feat2", "type": "chars", "n": 2, - "feat_list": feature_list_path + "feat_list_path": feature_list_path } ] } @@ -306,7 +314,8 @@ def test_load_from_config_test_mode_uses_feat_list(self): json.dump(config, f) # WHEN: Loading corpus from config - corpus, features = load_corpus_from_config(config_path, is_test=True) + config_obj = Config.from_json(config_path) + corpus, features = load_corpus(config=config_obj, use_provided_feat_list=True) # THEN: Should use the provided feature list self.assertIsNotNone(corpus) diff --git a/tests/test_load_corpus.py b/tests/test_load_corpus.py index 076aa199..bea27e35 100644 --- a/tests/test_load_corpus.py +++ b/tests/test_load_corpus.py @@ -6,6 +6,7 @@ import superstyl.preproc.embedding import superstyl.preproc.select import superstyl.preproc.text_count +from superstyl.config import NormalizationConfig, Config import os import glob @@ -19,7 +20,7 @@ class Main(unittest.TestCase): def test_load_corpus(self): # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths) + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths) # THEN expected_feats = [('this', 2/12), ('is', 2/12), ('the', 2/12), ('text', 2/12), ('voici', 1/12), ('le', 1/12), ('texte', 1/12), ('also', 1/12)] @@ -37,7 +38,7 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, culling=50) + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, culling=50) # THEN expected_feats = [('this', 2 / 12), ('is', 2 / 12), ('the', 2 
/ 12), ('text', 2 / 12)] expected_corpus = { @@ -51,7 +52,7 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, feat_list=[('the', 0)], feats="chars", n=3, k=5000, freqsType="absolute", + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feat_list=[('the', 0)], feats="chars", n=3, k=5000, freqsType="absolute", format="txt", keep_punct=False, keep_sym=False, identify_lang=True) # THEN @@ -65,7 +66,7 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, feats="words", n=1, + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="words", n=1, sampling=True, units="words", size=2, step=None, keep_punct=True, keep_sym=False) @@ -127,13 +128,13 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, k=4) + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, k=4) # THEN expected_feats = [('this', 2 / 12), ('is', 2 / 12), ('the', 2 / 12), ('text', 2 / 12)] self.assertEqual(feats, expected_feats) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, feats="chars", n=3, format="txt", keep_punct=True, + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="chars", n=3, format="txt", keep_punct=True, freqsType="absolute") # THEN @@ -186,7 +187,7 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, feats="chars", n=3, format="txt", keep_punct=True, + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="chars", n=3, format="txt", keep_punct=True, freqsType="binary") # THEN @@ -242,7 +243,7 @@ def test_load_corpus(self): self.assertEqual(corpus.to_dict(), expected_corpus) # WHEN - corpus, feats = 
superstyl.load.load_corpus(self.paths, feats="affixes", n=3, format="txt", keep_punct=True) + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="affixes", n=3, format="txt", keep_punct=True) # THEN expected_feats = [('_te', 3/51), ('tex', 3/51), ('ext', 2/51), ('is_', 3/51), ('Thi', 2/51), ('his', 2/51), @@ -290,7 +291,7 @@ def test_load_corpus(self): # Now, test embedding # WHEN - corpus, feats = superstyl.load.load_corpus(self.paths, feats="words", n=1, format="txt", + corpus, feats = superstyl.load.load_corpus(data_paths=self.paths, feats="words", n=1, format="txt", embedding=THIS_DIR+"/embed/test_embedding.wv.txt", neighbouring_size=1) # THEN @@ -312,8 +313,9 @@ def test_load_texts_txt(self): # SCENARIO: from paths to txt, get myTexts object, i.e., a list of dictionaries # # for each text or samples, with metadata and the text itself # WHEN - results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", keep_punct=False, - keep_sym=False, max_samples=None) + config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=False, + keep_sym=False, max_samples=None) + results = superstyl.preproc.pipe.load_texts(self.paths, config) # THEN expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'voici le texte', 'lang': 'NA'}, {'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'this is the text', 'lang': 'NA'}, @@ -323,14 +325,16 @@ def test_load_texts_txt(self): self.assertEqual(results, expected) # WHEN - results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", keep_punct=False, - keep_sym=False, max_samples=1) + config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=False, + keep_sym=False, max_samples=1) + results = superstyl.preproc.pipe.load_texts(self.paths, config) # THEN self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 1) # WHEN - results = superstyl.preproc.pipe.load_texts(self.paths, 
identify_lang=False, format="txt", keep_punct=True, - keep_sym=False, max_samples=None) + config = Config.from_kwargs(identify_lang=False, format="txt", keep_punct=True, + keep_sym=False, max_samples=None) + results = superstyl.preproc.pipe.load_texts(self.paths, config) # THEN expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'Voici le texte!', 'lang': 'NA'}, {'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'This is the text!', 'lang': 'NA'}, @@ -339,8 +343,9 @@ def test_load_texts_txt(self): self.assertEqual(results, expected) # WHEN - results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=False, format="txt", - keep_sym=True, max_samples=None) + config = Config.from_kwargs(identify_lang=False, format="txt", + keep_sym=True, max_samples=None) + results = superstyl.preproc.pipe.load_texts(self.paths, config) # THEN expected = [{'name': 'Dupont_Letter1.txt', 'aut': 'Dupont', 'text': 'Voici le texte!', 'lang': 'NA'}, {'name': 'Smith_Letter1.txt', 'aut': 'Smith', 'text': 'This is the text!', 'lang': 'NA'}, @@ -349,8 +354,9 @@ def test_load_texts_txt(self): self.assertEqual(results, expected) # WHEN - results = superstyl.preproc.pipe.load_texts(self.paths, identify_lang=True, format="txt", keep_punct=True, - keep_sym=False, max_samples=None) + config = Config.from_kwargs(identify_lang=True, format="txt", keep_punct=True, + keep_sym=False, max_samples=None) + results = superstyl.preproc.pipe.load_texts(self.paths, config) # THEN # Just testing that a lang is predicted, not if it is ok or not self.assertEqual(len([text for text in results if text["lang"] != 'NA']), 3) @@ -359,8 +365,9 @@ def test_load_texts_txt(self): def test_docs_to_samples(self): # WHEN - results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None, units="words", - format="txt", keep_punct=False, keep_sym=False, max_samples=None) + config = Config.from_kwargs(identify_lang=False, size=2, step=None, units="words", + 
format="txt", keep_punct=False, keep_sym=False, max_samples=None) + results = superstyl.preproc.pipe.docs_to_samples(self.paths, config) # THEN expected = [{'name': 'Dupont_Letter1.txt_0-2', 'aut': 'Dupont', 'text': 'voici le', 'lang': 'NA'}, {'name': 'Smith_Letter1.txt_0-2', 'aut': 'Smith', 'text': 'this is', 'lang': 'NA'}, @@ -370,10 +377,11 @@ def test_docs_to_samples(self): self.assertEqual(results, expected) # WHEN - results = superstyl.preproc.pipe.docs_to_samples(sorted(self.paths), identify_lang=False, size=2, step=1, - units="words", format="txt", keep_punct=True, - keep_sym=True, - max_samples=None) + config = Config.from_kwargs(identify_lang=False, size=2, step=1, + units="words", format="txt", keep_punct=True, + keep_sym=True, + max_samples=None) + results = superstyl.preproc.pipe.docs_to_samples(sorted(self.paths), config) # THEN expected = [{'name': 'Dupont_Letter1.txt_0-2', 'aut': 'Dupont', 'text': 'Voici le', 'lang': 'NA'}, @@ -396,37 +404,43 @@ def test_docs_to_samples(self): self.assertEqual(results, expected) # WHEN - results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=True, size=2, step=None, - units="words", format="txt", keep_punct=False, - keep_sym=False, - max_samples=None) + config = Config.from_kwargs(identify_lang=True, size=2, step=None, + units="words", format="txt", keep_punct=False, + keep_sym=False, + max_samples=None) + results = superstyl.preproc.pipe.docs_to_samples(self.paths, config) # THEN self.assertEqual(len([text for text in results if text["lang"] != 'NA']), 5) # WHEN - results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None, - units="words", format="txt", keep_punct=False, - keep_sym=False, - max_samples=1) + config = Config.from_kwargs(identify_lang=False, size=2, step=None, + units="words", format="txt", keep_punct=False, + keep_sym=False, + max_samples=1) + results = superstyl.preproc.pipe.docs_to_samples(self.paths, config) # THEN 
self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 1) # TODO: this is just minimal testing for random sampling # WHEN - results = superstyl.preproc.pipe.docs_to_samples(self.paths, identify_lang=False, size=2, step=None, - units="words", - format="txt", keep_punct=False, keep_sym=False, - max_samples=5, samples_random=True) + config = Config.from_kwargs(identify_lang=False, size=2, step=None, + units="words", + format="txt", keep_punct=False, keep_sym=False, + max_samples=5, samples_random=True) + results = superstyl.preproc.pipe.docs_to_samples(self.paths, config) # THEN self.assertEqual(len([text for text in results if text["aut"] == 'Smith']), 5) # and now tests that error are raised when parameters combinations are not consistent # WHEN/THEN - self.assertRaises(ValueError, superstyl.preproc.pipe.docs_to_samples, self.paths, size=2, step=1, units="words", - format="txt", max_samples=5, samples_random=True) - self.assertRaises(ValueError, superstyl.preproc.pipe.docs_to_samples, self.paths, size=2, units="words", - format="txt", max_samples=None, - samples_random=True) + with self.assertRaises(ValueError): + Config.from_kwargs(size=2, step=1, units="words", + format="txt", max_samples=5, samples_random=True) + + with self.assertRaises(ValueError): + Config.from_kwargs(size=2, units="words", + format="txt", max_samples=None, + samples_random=True) # TODO: test other loading formats with sampling, that are not txt (and decide on their implementation) @@ -553,11 +567,6 @@ def test_get_counts(self): self.assertEqual(results, expected) - # TODO: test count_process - - # TODO: test features_select - # TODO: test select - class DataLoading(unittest.TestCase): @@ -569,27 +578,32 @@ def test_normalise(self): # SCENARIO # GIVEN text = " Hello, Mr. 𓀁, how are §§ you; doing? 
ſõ ❡" + norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = False) # WHEN - results = superstyl.preproc.pipe.normalise(text) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_default = "hello mr how are you doing s o" self.assertEqual(results, expected_default) # WHEN - results = superstyl.preproc.pipe.normalise(text, no_ascii=True) + norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = True) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_default = "hello mr 𓀁 how are you doing ſ õ" self.assertEqual(results, expected_default) # WHEN - results = superstyl.preproc.pipe.normalise(text, keep_punct=True) + norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = False) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_keeppunct = "Hello, Mr. , how are SSSS you; doing? s o" # WHEN - results = superstyl.preproc.pipe.normalise(text, keep_punct=True, no_ascii=True) + norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = True) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_keeppunct = "Hello, Mr. 𓀁, how are §§ you; doing? ſ õ" self.assertEqual(results, expected_keeppunct) # WHEN - results = superstyl.preproc.pipe.normalise(text, keep_sym=True) + norm_conf = NormalizationConfig(keep_punct = False, keep_sym = True, no_ascii = False) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_keepsym = "Hello, Mr. 𓀁, how are §§ you; doing? 
ſ\uf217õ ❡" self.assertEqual(results, expected_keepsym) @@ -598,7 +612,8 @@ def test_normalise(self): # GIVEN text = 'Coucou 😅' # WHEN - results = superstyl.preproc.pipe.normalise(text, keep_sym=True) + norm_conf = NormalizationConfig(keep_punct = False, keep_sym = True, no_ascii = False) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_keepsym = 'Coucou 😅' self.assertEqual(results, expected_keepsym) @@ -606,16 +621,18 @@ def test_normalise(self): # gives: 'Coucou 😵 💫' # because of the way NFC normalisation is handled probably - # Test for Armenian + # Test for Armenian # GIVEN text = " քան զսակաւս ։ Ահա նշանագրեցի" # WHEN - results = superstyl.preproc.pipe.normalise(text, no_ascii=True) + norm_conf = NormalizationConfig(keep_punct = False, keep_sym = False, no_ascii = True) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_default = "քան զսակաւս ահա նշանագրեցի" self.assertEqual(results, expected_default) # WHEN - results = superstyl.preproc.pipe.normalise(text, keep_punct=True, no_ascii=True) + norm_conf = NormalizationConfig(keep_punct = True, keep_sym = False, no_ascii = True) + results = superstyl.preproc.pipe.normalise(text, norm_conf) # THEN expected_keeppunct = "քան զսակաւս ։ Ահա նշանագրեցի" self.assertEqual(results, expected_keeppunct) diff --git a/tests/test_load_from_config.py b/tests/test_load_from_config.py index 678230ca..633c7688 100644 --- a/tests/test_load_from_config.py +++ b/tests/test_load_from_config.py @@ -6,8 +6,8 @@ import sys import glob -from superstyl.load_from_config import load_corpus_from_config - +from superstyl.load import load_corpus +from superstyl.config import Config # Add parent directory to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -25,8 +25,11 @@ def setUp(self): # Create a test configuration self.config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt", + 
"identify_lang": False + }, "sampling": { "enabled": False }, @@ -48,8 +51,10 @@ def setUp(self): # Create a single feature configuration for testing self.single_config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt", + }, "features": [ { "name": "words", @@ -77,9 +82,8 @@ def test_load_corpus_from_json_config_multiple_features(self): # GIVEN: Config file with multiple feature specifications # WHEN: Loading corpus from config - corpus, features = load_corpus_from_config( - config_path=self.json_config_path - ) + json_config = Config.from_json(self.json_config_path) + corpus, features = load_corpus(config=json_config) # THEN: Corpus and features are loaded correctly self.assertIsInstance(corpus, pd.DataFrame) @@ -97,9 +101,8 @@ def test_load_corpus_single_feature(self): # GIVEN: Config file with a single feature specification # WHEN: Loading corpus from config - corpus, features = load_corpus_from_config( - config_path=self.single_config_path - ) + config = Config.from_json(self.single_config_path) + corpus, features = load_corpus(config=config) # THEN: Corpus and features are loaded correctly without prefix self.assertIsInstance(corpus, pd.DataFrame) @@ -115,16 +118,15 @@ def test_load_corpus_with_sampling(self): # GIVEN: Config with sampling enabled and sample size defined sampling_config = self.config.copy() sampling_config["sampling"]["enabled"] = True - sampling_config["sampling"]["sample_size"] = 2 + sampling_config["sampling"]["size"] = 2 sampling_config_path = os.path.join(self.temp_dir.name, "sampling_config.json") with open(sampling_config_path, 'w') as f: json.dump(sampling_config, f) # WHEN: Loading corpus from config with sampling - corpus, features = load_corpus_from_config( - config_path=sampling_config_path - ) + config = Config.from_json(sampling_config_path) + corpus, features = load_corpus(config=config) # THEN: Samples are created and file names contain segment info first_corpus_index 
= corpus.index[0] @@ -141,16 +143,15 @@ def test_load_corpus_with_feature_list(self): # Update config to use feature list feature_list_config = self.single_config.copy() - feature_list_config["features"][0]["feat_list"] = feature_list_path + feature_list_config["features"][0]["feat_list_path"] = feature_list_path config_path = os.path.join(self.temp_dir.name, "feature_list_config.json") with open(config_path, 'w') as f: json.dump(feature_list_config, f) # WHEN: Loading corpus from config with feature list - corpus, features = load_corpus_from_config( - config_path=config_path - ) + config = Config.from_json(config_path) + corpus, features = load_corpus(config=config) # THEN: Only features from the predefined list are used feature_words = [f[0] for f in features] @@ -169,8 +170,8 @@ def test_invalid_config_format(self): f.write("invalid: format") # WHEN/THEN: Loading corpus from invalid config raises ValueError - with self.assertRaises(ValueError): - load_corpus_from_config(invalid_path) + with self.assertRaises(Exception): # Will be json.JSONDecodeError + Config.from_json(invalid_path) def test_feature_list_path_not_found(self): """Test handling of a non-existent feature list path""" @@ -178,7 +179,7 @@ def test_feature_list_path_not_found(self): nonexistent_path = os.path.join(self.temp_dir.name, "nonexistent.json") feature_list_config = self.single_config.copy() - feature_list_config["features"][0]["feat_list"] = nonexistent_path + feature_list_config["features"][0]["feat_list_path"] = nonexistent_path config_path = os.path.join(self.temp_dir.name, "missing_feature_list_config.json") with open(config_path, 'w') as f: @@ -186,30 +187,38 @@ def test_feature_list_path_not_found(self): # Should raise FileNotFoundError when trying to load non-existent feature list with self.assertRaises(FileNotFoundError): - load_corpus_from_config(config_path) + Config.from_json(config_path) def test_missing_features_in_config(self): """Test handling of config with no features 
specified""" # Create config without features key no_features_config = { - "paths": self.test_paths, - "format": "txt" + "corpus": { + "paths": self.test_paths, + "format": "txt" + } } config_path = os.path.join(self.temp_dir.name, "no_features_config.json") with open(config_path, 'w') as f: json.dump(no_features_config, f) - # Should raise ValueError when no features are specified - with self.assertRaises(ValueError): - load_corpus_from_config(config_path) + # Should use default feature when no features are specified + config = Config.from_json(config_path) + corpus, features = load_corpus(config=config) + + # Should work with default features + self.assertIsNotNone(corpus) + self.assertIsNotNone(features) def test_empty_features_list(self): """Test handling of config with empty features list""" # Create config with empty features list empty_features_config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt", + }, "features": [] } @@ -219,13 +228,16 @@ def test_empty_features_list(self): # Should raise ValueError when features list is empty with self.assertRaises(ValueError): - load_corpus_from_config(config_path) + config = Config.from_json(config_path) + config.validate() def test_missing_paths_in_config(self): """Test handling of config with missing paths""" # Create config without paths key no_paths_config = { - "format": "txt", + "corpus": { + "format": "txt", + }, "features": [ { "type": "words", @@ -239,24 +251,9 @@ def test_missing_paths_in_config(self): with open(config_path, 'w') as f: json.dump(no_paths_config, f) - # Should raise ValueError when no paths are specified - with self.assertRaises(ValueError): - load_corpus_from_config(config_path) - - def test_string_path_in_config(self): - """Test handling of config with a string path instead of list""" - # Create config with a string path - string_path_config = self.single_config.copy() - string_path_config["paths"] = self.test_paths[0] # Single 
string path - - config_path = os.path.join(self.temp_dir.name, "string_path_config.json") - with open(config_path, 'w') as f: - json.dump(string_path_config, f) - - # Should handle string path - corpus, features = load_corpus_from_config(config_path) - self.assertIsInstance(corpus, pd.DataFrame) - self.assertGreater(len(features), 0) + # Config should load but paths will be empty list + config = Config.from_json(config_path) + self.assertEqual(config.corpus.paths, []) def test_feature_with_txt_list(self): """Test loading a feature list from a txt file""" @@ -267,14 +264,15 @@ def test_feature_with_txt_list(self): # Create config that uses txt feature list txt_list_config = self.single_config.copy() - txt_list_config["features"][0]["feat_list"] = feature_list_path + txt_list_config["features"][0]["feat_list_path"] = feature_list_path config_path = os.path.join(self.temp_dir.name, "txt_list_config.json") with open(config_path, 'w') as f: json.dump(txt_list_config, f) # Should load corpus with feature list from txt - corpus, features = load_corpus_from_config(config_path) + config = Config.from_json(config_path) + corpus, features = load_corpus(config=config) self.assertIsInstance(corpus, pd.DataFrame) self.assertGreater(len(features), 0) @@ -282,19 +280,23 @@ def test_all_optional_parameters(self): """Test with a config that includes all optional parameters""" # Create config with all optional parameters full_config = { - "paths": self.test_paths, - "format": "txt", - "keep_punct": True, - "keep_sym": True, - "no_ascii": True, - "identify_lang": True, + "corpus": { + "paths": self.test_paths, + "format": "txt", + "identify_lang": True, + }, + "normalization": { + "keep_punct": True, + "keep_sym": True, + "no_ascii": True, + }, "sampling": { "enabled": True, "units": "words", - "sample_size": 2, - "sample_step": None, # Set to None when sample_random is True + "size": 2, + "step": None, "max_samples": 3, - "sample_random": True + "random": True }, "features": [ { @@ 
-303,10 +305,6 @@ def test_all_optional_parameters(self): "n": 1, "k": 100, "freq_type": "relative", - "keep_punct": True, - "keep_sym": True, - "no_ascii": True, - "identify_lang": True, "embedding": None, "neighbouring_size": 5, "culling": 0 @@ -319,7 +317,8 @@ def test_all_optional_parameters(self): json.dump(full_config, f) # Should load corpus with all parameters set - corpus, features = load_corpus_from_config(config_path) + config = Config.from_json(config_path) + corpus, features = load_corpus(config=config) self.assertIsInstance(corpus, pd.DataFrame) self.assertGreater(len(features), 0) @@ -340,13 +339,15 @@ def test_feature_list_loading(self): # Test the JSON feature list loading json_config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt", + }, "features": [ { "type": "words", "n": 1, - "feat_list": json_feature_path + "feat_list_path": json_feature_path } ] } @@ -355,21 +356,21 @@ def test_feature_list_loading(self): with open(json_config_path, 'w') as f: json.dump(json_config, f) - # Use direct debugging output to confirm code path is being exercised - print(f"\nTesting JSON feature list: {json_feature_path}") - print(f"JSON feature list content: {json_feature_list}") - - corpus_json, features_json = load_corpus_from_config(json_config_path) + # Load and test JSON config + json_config_obj = Config.from_json(json_config_path) + corpus_json, features_json = load_corpus(config=json_config_obj) # Now test the TXT feature list loading txt_config = { - "paths": self.test_paths, - "format": "txt", + "corpus": { + "paths": self.test_paths, + "format": "txt", + }, "features": [ { "type": "words", "n": 1, - "feat_list": txt_feature_path + "feat_list_path": txt_feature_path } ] } @@ -378,12 +379,9 @@ def test_feature_list_loading(self): with open(txt_config_path, 'w') as f: json.dump(txt_config, f) - # Use direct debugging output to confirm code path is being exercised - print(f"\nTesting TXT feature 
list: {txt_feature_path}") - with open(txt_feature_path, 'r') as f: - print(f"TXT feature list content: {f.read()}") - - corpus_txt, features_txt = load_corpus_from_config(txt_config_path) + # Load and test TXT config + txt_config_obj = Config.from_json(txt_config_path) + corpus_txt, features_txt = load_corpus(config=txt_config_obj) # Basic verification self.assertIsInstance(corpus_json, pd.DataFrame) @@ -391,20 +389,29 @@ def test_feature_list_loading(self): def test_invalid_paths_type(self): """Test handling of config with invalid paths type (neither list nor string)""" - # Create config with invalid paths type (integer) - invalid_paths_config = self.single_config.copy() - invalid_paths_config["paths"] = 123 # Not a list or string - - config_path = os.path.join(self.temp_dir.name, "invalid_paths_config.json") - with open(config_path, 'w') as f: - json.dump(invalid_paths_config, f) + # For this test, we need to test at the load_corpus level since + # Config.from_dict will accept any type for paths + + # Create a config dict with integer paths (will pass Config creation) + invalid_config_dict = { + "corpus": { + "paths": 123, # Invalid type + "format": "txt" + }, + "features": [ + { + "type": "words", + "n": 1 + } + ] + } - # Should raise ValueError for invalid paths type - with self.assertRaises(ValueError) as context: - load_corpus_from_config(config_path) + # Create Config object (this will work) + config = Config.from_dict(invalid_config_dict) - # Verify the error message - self.assertIn("Paths in config must be either a list or a glob pattern string", str(context.exception)) + # But load_corpus should fail when trying to iterate paths + with self.assertRaises(TypeError): + load_corpus(config=config) if __name__ == "__main__": diff --git a/tests/test_select.py b/tests/test_select.py new file mode 100644 index 00000000..b6dcf916 --- /dev/null +++ b/tests/test_select.py @@ -0,0 +1,473 @@ +import unittest +import tempfile +import os +import csv +import json 
+import pandas as pd +from superstyl.preproc.select import ( + _load_metadata, + _should_exclude, + read_clean, + apply_selection +) + + +class TestSelectRefactored(unittest.TestCase): + """Tests for the refactored select.py module""" + + def setUp(self): + """Create temporary directory and test files""" + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_path = self.temp_dir.name + + # Create a sample CSV file + self.csv_path = os.path.join(self.temp_path, "test_data.csv") + with open(self.csv_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Unnamed: 0', 'author', 'lang', 'feature1', 'feature2']) + writer.writerow(['doc1', 'Smith', 'en', 0.5, 0.3]) + writer.writerow(['doc2', 'Dupont', 'fr', 0.4, 0.6]) + writer.writerow(['doc3', 'Garcia', 'es', 0.7, 0.2]) + writer.writerow(['doc4', 'Smith', 'en', 0.6, 0.4]) + writer.writerow(['doc5', 'Dupont', 'fr', 0.3, 0.7]) + + def tearDown(self): + """Clean up temporary directory""" + self.temp_dir.cleanup() + + # ========================================================================= + # Tests for _load_metadata() helper function + # ========================================================================= + + def test_load_metadata_with_no_params(self): + """Test _load_metadata when no metadata or excludes needed""" + # WHEN + metadata, excludes = _load_metadata( + self.csv_path, + metadata_path=None, + excludes_path=None, + lang=None + ) + + # THEN + self.assertIsNone(metadata) + self.assertIsNone(excludes) + + def test_load_metadata_from_main_csv_for_lang(self): + """Test _load_metadata creating metadata from main CSV when lang specified""" + # WHEN + metadata, excludes = _load_metadata( + self.csv_path, + metadata_path=None, + excludes_path=None, + lang='en' # Lang specified, so metadata needed + ) + + # THEN + self.assertIsNotNone(metadata) + self.assertIsNone(excludes) + self.assertIn('doc1', metadata.index) + self.assertEqual(metadata.loc['doc1', 'lang'], 'en') + + def 
test_load_metadata_from_separate_file(self): + """Test _load_metadata loading from separate metadata file""" + # GIVEN: Create a metadata file + metadata_path = os.path.join(self.temp_path, "metadata.csv") + with open(metadata_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['id', 'true']) + writer.writerow(['doc1', 'en']) + writer.writerow(['doc2', 'fr']) + + # WHEN + metadata, excludes = _load_metadata( + self.csv_path, + metadata_path=metadata_path, + excludes_path=None, + lang=None + ) + + # THEN + self.assertIsNotNone(metadata) + self.assertIsNone(excludes) + self.assertEqual(metadata.loc['doc1', 'lang'], 'en') + self.assertEqual(metadata.loc['doc2', 'lang'], 'fr') + + def test_load_metadata_with_excludes(self): + """Test _load_metadata loading excludes list""" + # GIVEN: Create an excludes file + excludes_path = os.path.join(self.temp_path, "excludes.csv") + with open(excludes_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['id']) + writer.writerow(['doc3']) + writer.writerow(['doc5']) + + # WHEN + metadata, excludes = _load_metadata( + self.csv_path, + metadata_path=None, + excludes_path=excludes_path, + lang=None + ) + + # THEN + # When excludes_path is not None, metadata is loaded from CSV (for potential filtering) + self.assertIsNotNone(metadata) + self.assertIsNotNone(excludes) + self.assertIn('doc3', excludes) + self.assertIn('doc5', excludes) + self.assertEqual(len(excludes), 2) + + # ========================================================================= + # Tests for _should_exclude() helper function + # ========================================================================= + + def test_should_exclude_no_exclusions(self): + """Test _should_exclude when no exclusions apply""" + # WHEN + is_excluded, reason = _should_exclude( + 'doc1', + metadata=None, + excludes=None, + lang=None + ) + + # THEN + self.assertFalse(is_excluded) + self.assertIsNone(reason) + + def 
test_should_exclude_by_language(self): + """Test _should_exclude excluding by language""" + # GIVEN: Create metadata + metadata = pd.DataFrame({ + 'lang': ['en', 'fr', 'es'] + }, index=['doc1', 'doc2', 'doc3']) + + # WHEN: Check doc with wrong language + is_excluded, reason = _should_exclude( + 'doc2', # fr + metadata=metadata, + excludes=None, + lang='en' # Only want en + ) + + # THEN + self.assertTrue(is_excluded) + self.assertIn('not in: en', reason) + self.assertIn('doc2', reason) + + def test_should_exclude_by_language_passes(self): + """Test _should_exclude NOT excluding when language matches""" + # GIVEN: Create metadata + metadata = pd.DataFrame({ + 'lang': ['en', 'fr', 'es'] + }, index=['doc1', 'doc2', 'doc3']) + + # WHEN: Check doc with correct language + is_excluded, reason = _should_exclude( + 'doc1', # en + metadata=metadata, + excludes=None, + lang='en' # Want en + ) + + # THEN + self.assertFalse(is_excluded) + self.assertIsNone(reason) + + def test_should_exclude_by_excludes_list(self): + """Test _should_exclude excluding by excludes list""" + # WHEN + is_excluded, reason = _should_exclude( + 'doc3', + metadata=None, + excludes=['doc3', 'doc5'], + lang=None + ) + + # THEN + self.assertTrue(is_excluded) + self.assertIn('Wilhelmus', reason) + self.assertIn('doc3', reason) + + def test_should_exclude_not_in_excludes_list(self): + """Test _should_exclude NOT excluding when not in list""" + # WHEN + is_excluded, reason = _should_exclude( + 'doc1', + metadata=None, + excludes=['doc3', 'doc5'], + lang=None + ) + + # THEN + self.assertFalse(is_excluded) + self.assertIsNone(reason) + + def test_should_exclude_keyerror_handling(self): + """Test _should_exclude handling missing keys gracefully""" + # GIVEN: Metadata without doc99 + metadata = pd.DataFrame({ + 'lang': ['en', 'fr'] + }, index=['doc1', 'doc2']) + + # WHEN: Check a doc not in metadata + is_excluded, reason = _should_exclude( + 'doc99', # Not in metadata + metadata=metadata, + excludes=None, + 
lang='en' + ) + + # THEN: Should not crash, should not exclude + self.assertFalse(is_excluded) + self.assertIsNone(reason) + + # ========================================================================= + # Tests for read_clean() function + # ========================================================================= + + def test_read_clean_no_split(self): + """Test read_clean without splitting""" + # GIVEN + output_json = os.path.join(self.temp_path, "selection.json") + + # WHEN + read_clean( + self.csv_path, + savesplit=output_json, + split=False + ) + + # THEN: Should create _selected.csv + selected_path = self.csv_path.replace('.csv', '_selected.csv') + self.assertTrue(os.path.exists(selected_path)) + + # Check content + with open(selected_path, 'r') as f: + lines = f.readlines() + self.assertEqual(len(lines), 6) # Header + 5 data rows + + # Check selection JSON + with open(output_json, 'r') as f: + selection = json.load(f) + self.assertIn('train', selection) + self.assertIn('elim', selection) + self.assertNotIn('valid', selection) # No split + self.assertEqual(len(selection['train']), 5) + self.assertEqual(len(selection['elim']), 0) + + def test_read_clean_with_split(self): + """Test read_clean with splitting into train/valid""" + # GIVEN + output_json = os.path.join(self.temp_path, "selection_split.json") + + # WHEN + read_clean( + self.csv_path, + savesplit=output_json, + split=True, + split_ratio=0.5 # 50% for easier testing + ) + + # THEN: Should create _train.csv and _valid.csv + train_path = self.csv_path.replace('.csv', '_train.csv') + valid_path = self.csv_path.replace('.csv', '_valid.csv') + self.assertTrue(os.path.exists(train_path)) + self.assertTrue(os.path.exists(valid_path)) + + # Check selection JSON + with open(output_json, 'r') as f: + selection = json.load(f) + self.assertIn('train', selection) + self.assertIn('valid', selection) + self.assertIn('elim', selection) + # With 50% split, should have some in train and some in valid + total = 
len(selection['train']) + len(selection['valid']) + self.assertEqual(total, 5) # All 5 docs processed + + def test_read_clean_with_language_filter(self): + """Test read_clean filtering by language""" + # GIVEN + output_json = os.path.join(self.temp_path, "selection_lang.json") + + # WHEN: Only keep English docs + read_clean( + self.csv_path, + savesplit=output_json, + lang='en', + split=False + ) + + # THEN + with open(output_json, 'r') as f: + selection = json.load(f) + # Should keep 2 English docs (doc1, doc4) + self.assertEqual(len(selection['train']), 2) + # Should eliminate 3 non-English docs (doc2, doc3, doc5) + self.assertEqual(len(selection['elim']), 3) + + def test_read_clean_with_excludes(self): + """Test read_clean with excludes list""" + # GIVEN: Create an excludes file + excludes_path = os.path.join(self.temp_path, "excludes.csv") + with open(excludes_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['id']) + writer.writerow(['doc2']) + writer.writerow(['doc4']) + + output_json = os.path.join(self.temp_path, "selection_excl.json") + + # WHEN + read_clean( + self.csv_path, + excludes_path=excludes_path, + savesplit=output_json, + split=False + ) + + # THEN + with open(output_json, 'r') as f: + selection = json.load(f) + # Should keep 3 docs (doc1, doc3, doc5) + self.assertEqual(len(selection['train']), 3) + # Should eliminate 2 docs (doc2, doc4) + self.assertEqual(len(selection['elim']), 2) + self.assertIn('doc2', selection['elim']) + self.assertIn('doc4', selection['elim']) + + def test_read_clean_with_metadata_file(self): + """Test read_clean with separate metadata file""" + # GIVEN: Create a metadata file + metadata_path = os.path.join(self.temp_path, "metadata.csv") + with open(metadata_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['id', 'true']) + writer.writerow(['doc1', 'en']) + writer.writerow(['doc2', 'en']) + writer.writerow(['doc3', 'fr']) + writer.writerow(['doc4', 'en']) + 
writer.writerow(['doc5', 'fr']) + + output_json = os.path.join(self.temp_path, "selection_meta.json") + + # WHEN: Filter by English using metadata file + read_clean( + self.csv_path, + metadata_path=metadata_path, + savesplit=output_json, + lang='en', + split=False + ) + + # THEN + with open(output_json, 'r') as f: + selection = json.load(f) + # Should keep 3 English docs according to metadata + self.assertEqual(len(selection['train']), 3) + self.assertIn('doc1', selection['train']) + self.assertIn('doc2', selection['train']) + self.assertIn('doc4', selection['train']) + + def test_read_clean_combined_filters(self): + """Test read_clean with both language filter and excludes""" + # GIVEN: Create an excludes file + excludes_path = os.path.join(self.temp_path, "excludes_combined.csv") + with open(excludes_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['id']) + writer.writerow(['doc4']) # Exclude one English doc + + output_json = os.path.join(self.temp_path, "selection_combined.json") + + # WHEN: Filter by English AND exclude doc4 + read_clean( + self.csv_path, + excludes_path=excludes_path, + savesplit=output_json, + lang='en', + split=False + ) + + # THEN + with open(output_json, 'r') as f: + selection = json.load(f) + # Should keep only doc1 (doc4 excluded, others wrong lang) + self.assertEqual(len(selection['train']), 1) + self.assertIn('doc1', selection['train']) + # Should eliminate 4 docs + self.assertEqual(len(selection['elim']), 4) + + # ========================================================================= + # Tests for apply_selection() function + # ========================================================================= + + def test_apply_selection(self): + """Test apply_selection applying a pre-existing selection""" + # GIVEN: Create a selection JSON + selection_path = os.path.join(self.temp_path, "presplit.json") + selection = { + 'train': ['doc1', 'doc3'], + 'valid': ['doc2', 'doc4'], + 'elim': ['doc5'] + } + with 
open(selection_path, 'w') as f: + json.dump(selection, f) + + # WHEN + apply_selection(self.csv_path, selection_path) + + # THEN: Should create _train.csv and _valid.csv + train_path = self.csv_path.replace('.csv', '_train.csv') + valid_path = self.csv_path.replace('.csv', '_valid.csv') + self.assertTrue(os.path.exists(train_path)) + self.assertTrue(os.path.exists(valid_path)) + + # Check train file content + with open(train_path, 'r') as f: + reader = csv.reader(f) + next(reader) # Skip header + train_ids = [row[0] for row in reader] + self.assertEqual(set(train_ids), {'doc1', 'doc3'}) + + # Check valid file content + with open(valid_path, 'r') as f: + reader = csv.reader(f) + next(reader) # Skip header + valid_ids = [row[0] for row in reader] + self.assertEqual(set(valid_ids), {'doc2', 'doc4'}) + + def test_apply_selection_eliminates_correctly(self): + """Test apply_selection correctly eliminates specified docs""" + # GIVEN: Create a selection JSON with eliminations + selection_path = os.path.join(self.temp_path, "presplit_elim.json") + selection = { + 'train': ['doc1', 'doc2'], + 'valid': ['doc3'], + 'elim': ['doc4', 'doc5'] + } + with open(selection_path, 'w') as f: + json.dump(selection, f) + + # WHEN + apply_selection(self.csv_path, selection_path) + + # THEN: Eliminated docs should not appear in either file + train_path = self.csv_path.replace('.csv', '_train.csv') + valid_path = self.csv_path.replace('.csv', '_valid.csv') + + with open(train_path, 'r') as f: + content = f.read() + self.assertNotIn('doc4', content) + self.assertNotIn('doc5', content) + + with open(valid_path, 'r') as f: + content = f.read() + self.assertNotIn('doc4', content) + self.assertNotIn('doc5', content) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/train_svm.py b/train_svm.py index c1a30333..747ff587 100755 --- a/train_svm.py +++ b/train_svm.py @@ -1,27 +1,48 @@ import superstyl.svm +from superstyl.config import Config import pandas import 
joblib if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Train SVM models for stylometric analysis.") + + # Configuration file + parser.add_argument('--config', action='store', type=str, default=None, + help="Path to a JSON configuration file" + ) + + # Data paths parser.add_argument('train_path', action='store', help="Path to train file", type=str) parser.add_argument('--test_path', action='store', help="Path to test file", type=str, required=False, default=None) + + # Output options parser.add_argument('-o', action='store', help="optional prefix for output files", type=str, default=None) + + + # Cross-validation options parser.add_argument('--cross_validate', action='store', help="perform cross validation (test_path will be used only for final prediction)." "If group_k-fold is chosen, each source file will be considered a group " "(only relevant if sampling was performed and more than one file per class was provided)", default=None, choices=['leave-one-out', 'k-fold', 'group-k-fold'], type=str) - parser.add_argument('--k', action='store', help="k for k-fold (default: 10 folds for k-fold; k=n groups for group-k-fold)", default=0, type=int) - parser.add_argument('--dim_reduc', action='store', choices=['pca'], help="optional dimensionality " - "reduction of input data", default=None) + + parser.add_argument('--k', action='store', + help="k for k-fold (default: 10 folds for k-fold; k=n groups for group-k-fold)", + default=0, type=int) + + # Preprocessing options + parser.add_argument('--dim_reduc', action='store', choices=['pca'], + help="Dimensionality reduction of input data", default=None) + parser.add_argument('--norms', action='store_true', help="perform normalisations? 
(default: True)", default=True) + + # Balancing options parser.add_argument('--balance', action='store', choices=["downsampling", "Tomek", "upsampling", "SMOTE", "SMOTETomek"], help="which " @@ -32,31 +53,43 @@ "SMOTE (upsampling with SMOTE), " "SMOTETomek (over+undersampling with SMOTE+Tomek)", default=None) + parser.add_argument('--class_weights', action='store_true', help="whether to use class weights in imbalanced datasets " - "(inversely proportional to total class samples)", - default=False - ) + "(inversely proportional to total class samples)", default=False) + parser.add_argument('--kernel', action='store', help="type of kernel to use (default and recommended choice is LinearSVC; " "possible alternatives are linear, sigmoid, rbf and poly, as per sklearn.svm.SVC)", default="LinearSVC", choices=['LinearSVC', 'linear', 'sigmoid', 'rbf', 'poly'], type=str) + + # Output options parser.add_argument('--final', action='store_true', help="final analysis on unknown dataset (no evaluation)?", default=False) + parser.add_argument('--get_coefs', action='store_true', help="switch to write to disk and plot the most important coefficients" " for the training feats for each class", default=False) - # New arguments for rolling stylometry plotting + + # Rolling stylometry plotting parser.add_argument('--plot_rolling', action='store_true', help="If final predictions are produced, also plot rolling stylometry.") + parser.add_argument('--plot_smoothing', action='store', type=int, default=3, help="Smoothing window size for rolling stylometry plot (default:3)." "Set to 0 or None to disable smoothing.") + + # Save configuration + parser.add_argument( + '--save_config', action='store', type=str, default=None, + help="Save the configuration to a JSON file" + ) args = parser.parse_args() + # Load data print(".......... 
loading data ........") train = pandas.read_csv(args.train_path, index_col=0) @@ -65,31 +98,49 @@ else: test = None - svm = superstyl.svm.train_svm(train, test, cross_validate=args.cross_validate, k=args.k, dim_reduc=args.dim_reduc, - norms=args.norms, balance=args.balance, class_weights=args.class_weights, - kernel=args.kernel, final_pred=args.final, get_coefs=args.get_coefs) + # Determine how to run the SVM + if args.config: + config = Config.from_json(args.config) + # Override with CLI arguments if provided - if args.o is not None: - args.o = args.o + "_" else: - args.o = '' + config = Config.from_kwargs( + cross_validate=args.cross_validate, + dim_reduc=args.dim_reduc, + norms=args.norms, + balance=args.balance, + class_weights=args.class_weights, + kernel=args.kernel, + final_pred=args.final, + get_coefs=args.get_coefs + ) + + # Save config if requested + if args.save_config: + config.save(args.save_config) + print(f"Configuration saved to {args.save_config}") + + # Train SVM + svm = superstyl.svm.train_svm(train, test, config=config) + + # Output prefix + prefix = f"{args.o}_" if args.o else "" - if args.cross_validate is not None or (args.test_path is not None and not args.final): - svm["confusion_matrix"].to_csv(args.o+"confusion_matrix.csv") - svm["misattributions"].to_csv(args.o+"misattributions.csv") + # Save results + if args.cross_validate or (args.test_path and not args.final): + svm["confusion_matrix"].to_csv(f"{prefix}confusion_matrix.csv") + svm["misattributions"].to_csv(f"{prefix}misattributions.csv") - joblib.dump(svm["pipeline"], args.o+'mySVM.joblib') + joblib.dump(svm["pipeline"], f"{prefix}mySVM.joblib") if args.final: - print(".......... Writing final predictions to " + args.o + "FINAL_PREDICTIONS.csv ........") - svm["final_predictions"].to_csv(args.o+"FINAL_PREDICTIONS.csv") + print(f".......... 
Writing final predictions to {prefix}FINAL_PREDICTIONS.csv ........") + svm["final_predictions"].to_csv(f"{prefix}FINAL_PREDICTIONS.csv") - # If user requested rolling stylometry plot if args.plot_rolling: print(".......... Plotting rolling stylometry ........") - smoothing = args.plot_smoothing if args.plot_smoothing is not None else 0 - superstyl.svm.plot_rolling(svm["final_predictions"], smoothing=smoothing) + superstyl.svm.plot_rolling(svm["final_predictions"], smoothing=args.plot_smoothing) if args.get_coefs: print(".......... Writing coefficients to disk ........") - svm["coefficients"].to_csv(args.o+"coefficients.csv") + svm["coefficients"].to_csv(f"{prefix}coefficients.csv") \ No newline at end of file