diff --git a/docs/explanations/names_and_concepts.md b/docs/explanations/names_and_concepts.md
index 6db976a3..f1b7b554 100644
--- a/docs/explanations/names_and_concepts.md
+++ b/docs/explanations/names_and_concepts.md
@@ -96,17 +96,16 @@ tuning knob differs between estimators.
`AFEstimationOptions` (from `skillmodels.af`) controls the sequential MLE:
- **n_halton_points**, **n_halton_points_shock**: quadrature counts.
-- **n_mixture_components**: number of components in the latent-factor mixture.
- **optimizer_algorithm**: the optimagic algorithm name passed to
`optimagic.minimize(algorithm=...)` (default `"fides"`; use
`"scipy_lbfgsb"` for MC sweeps).
-- **initialization_strategy**: `"amn"`, `"spearman"`, or `"constant"`. Same
- meaning as in CHS.
+- **start_params_strategy**: `"amn"`, `"spearman"`, `"constant"`, or `"none"`.
+ Same meaning as in CHS. (The mixture-component count is not an estimator option;
+ it is the structural field `ModelSpec.n_mixtures`.)
`AMNEstimationOptions` (from `skillmodels.amn`) controls the three-stage
pipeline:
-- **n_mixture_components**: Stage-1 EM components.
- **em_max_iter**, **em_tol**, **em_n_init**, **em_reg_covar**: Stage-1 EM
numerical knobs.
- **n_simulation_draws**: Stage-3 synthetic-panel size.
@@ -115,17 +114,11 @@ pipeline:
full covariance matrices, and is currently the only implemented option.
`"optimal"` is reserved for a future Avar-weighted criterion and raises
`NotImplementedError`.
-- **investment_endogeneity**: apply the AMN (2020) eq. 7-8 / AF Sec. 3.5
- investment control-function correction in Stage 3. Defaults to `False`. When
- `True` and the model has an endogenous (investment) factor, a first-stage
- investment equation is OLS-fit per period and its residual is added as an
- additive `cf` covariate (coefficient `kappa_t`) to each state factor's
- production regression; observed factors are then excluded from the production
- function and act as instruments (at least one observed instrument is
- required). The default stays `False` because `estimate_af` calls
- `estimate_amn` for start values and the AF likelihood implements only
- `kappa=0`; opt into the correction at the application call site. A no-op for
- models without endogenous factors.
+- Investment-endogeneity correction is no longer an estimation-option flag.
+ Attach a `CorrectionSpec` to the endogenous investment `FactorSpec`
+ (`FactorSpec.correction`); its presence triggers the Stage-3 control-function
+ correction. See
+ [Endogeneity Corrections](../reference_guides/endogeneity_corrections.md).
The shared structural field — number of mixture components in the latent
distribution — lives directly on `ModelSpec.n_mixtures`, since it changes the
diff --git a/docs/explanations/notes_on_factor_scales.md b/docs/explanations/notes_on_factor_scales.md
index 8bec8dcb..09e38e80 100644
--- a/docs/explanations/notes_on_factor_scales.md
+++ b/docs/explanations/notes_on_factor_scales.md
@@ -74,17 +74,30 @@ However, we don't have formal identification results for this. **We advise cauti
when using CES or log_CES functions—think carefully about your normalizations rather
than relying on automatic generation.
-## Normalizations and Development Stages
-
-When using development stages (periods with identical transition parameters), the
-normalization requirements change.
-
-The key insight: you can identify scale from the first period of a stage, so no later
-normalizations are needed until the next stage begins.
-
-**Recommendations:**
-- Normalize only in the first period of each stage
-- For the initial stage, normalize the first two periods
-- Use automatic normalizations when working with stages to avoid confusion
-
-This reveals another type of over-normalization in the original CHS paper.
+## Normalizations and Identification
+
+The library distinguishes three separate things, and only the first two are
+mechanical:
+
+1. a **syntactic normalization** you supply through `Normalizations` or
+ `fixed_params`;
+2. an estimator **precheck** that catches some missing initial scale/location
+ anchors — it is a precheck, **not** a proof of identification;
+3. a **transition-family identification argument**, which the library does not
+ establish for arbitrary models.
+
+Because step 3 is on you, there is no single stage-level rule of thumb that is safe
+across transition functions. Use the template that matches your production function:
+
+- **Direct trans-log** (`translog`, `translog_af`): anchor one nonzero loading and
+ one intercept/location for every independently scaled factor-period.
+- **Restricted CES** (`log_ces_af` with $\psi_t = 1$): relative skill/investment
+ scales are identified through the production restrictions, so pinning *every* first
+ loading can impose testable restrictions — follow the CES templates in the
+ estimator-specific guides rather than normalizing mechanically.
+- **Intentionally restricted (original-AMN) benchmark**: a deliberately
+ over-restricted spec used only as a comparison point; label it as such so the extra
+ restrictions are not mistaken for identification requirements.
+
+For custom transitions you must establish identification yourself (or add a
+model-specific diagnostic); the automatic checker will not do it for you.
diff --git a/docs/getting_started/tutorial.ipynb b/docs/getting_started/tutorial.ipynb
index 269397c2..46082e77 100644
--- a/docs/getting_started/tutorial.ipynb
+++ b/docs/getting_started/tutorial.ipynb
@@ -13,9 +13,9 @@
"- **AF** (Antweiler-Freyberger 2025): sequential MLE with Halton quadrature over the latent posterior, period by period.\n",
"- **AMN** (Attanasio-Meghir-Nix 2020): three-stage estimator (EM on a mixture-of-normals → minimum distance → simulate-and-regress).\n",
"\n",
- "This tutorial walks all three on the **CNLSY** dataset from the AF 2025 application: three waves (ages 7 / 9 / 11), a CES production function for the latent skill, and an endogenous investment factor.\n",
+ "This tutorial walks **CHS** and **AF** on the **CNLSY** dataset from the AF 2025 application — three waves (ages 7 / 9 / 11), a CES production function for the latent skill, and an endogenous investment factor. **AMN** appears in its role as the CHS seed: it cannot consistently fit the *restricted* CES standalone, so CHS uses AMN's three stages to build start values (see the AMN section).\n",
"\n",
- "The notebook is committed with pre-rendered outputs (`execute: false`); rerun it to refresh the numbers if the API drifts. Single-machine wallclock at the time of writing: ~10–15 minutes total."
+ "The notebook is committed with pre-rendered outputs (`execute: false`); rerun it to refresh the numbers if the API drifts. Single-machine wallclock (CPU, JAX CPU backend): roughly 1–2 hours, dominated by the CHS MLE; AF and the AMN seeding step are faster."
]
},
{
@@ -30,18 +30,49 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "2",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:13:11.517683Z",
- "iopub.status.busy": "2026-05-14T07:13:11.517454Z",
- "iopub.status.idle": "2026-05-14T07:13:13.015271Z",
- "shell.execute_reply": "2026-05-14T07:13:13.014780Z"
+ "iopub.execute_input": "2026-06-29T08:32:39.752336Z",
+ "iopub.status.busy": "2026-06-29T08:32:39.752128Z",
+ "iopub.status.idle": "2026-06-29T08:32:42.714390Z",
+ "shell.execute_reply": "2026-06-29T08:32:42.713366Z"
}
},
- "outputs": [],
- "source": "import warnings\n\nimport optimagic as om\nimport pandas as pd\nimport plotly.graph_objects as go\n\nfrom skillmodels import FactorSpec, ModelSpec, Normalizations\nfrom skillmodels.af import AFEstimationOptions, estimate_af\nfrom skillmodels.af.posterior_states import get_af_posterior_states\nfrom skillmodels.amn import AMNEstimationOptions, estimate_amn\nfrom skillmodels.amn.posterior_states import get_amn_posterior_states\nfrom skillmodels.chs import (\n CHSEstimationOptions,\n get_maximization_inputs,\n)\nfrom skillmodels.common.config import CNLSY_DATA_PATH\nfrom skillmodels.common.individual_states import get_individual_states_from_params\nfrom skillmodels.common.variance_decomposition import (\n decompose_measurement_variance,\n summarize_measurement_reliability,\n)\n\nwarnings.filterwarnings(\"ignore\", category=DeprecationWarning)\npd.options.display.float_format = \"{:.3f}\".format"
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import warnings\n",
+ "\n",
+ "import optimagic as om\n",
+ "import pandas as pd\n",
+ "import plotly.graph_objects as go\n",
+ "\n",
+ "from skillmodels import FactorSpec, ModelSpec, Normalizations\n",
+ "from skillmodels.af import AFEstimationOptions, estimate_af\n",
+ "from skillmodels.af.posterior_states import get_af_posterior_states\n",
+ "from skillmodels.chs import (\n",
+ " CHSEstimationOptions,\n",
+ " get_maximization_inputs,\n",
+ ")\n",
+ "from skillmodels.common.config import CNLSY_DATA_PATH\n",
+ "from skillmodels.common.individual_states import get_individual_states_from_params\n",
+ "from skillmodels.common.variance_decomposition import (\n",
+ " decompose_measurement_variance,\n",
+ " summarize_measurement_reliability,\n",
+ ")\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
+ "pd.options.display.float_format = \"{:.3f}\".format"
+ ]
},
{
"cell_type": "code",
@@ -49,10 +80,10 @@
"id": "3",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:13:13.016446Z",
- "iopub.status.busy": "2026-05-14T07:13:13.016229Z",
- "iopub.status.idle": "2026-05-14T07:13:13.027837Z",
- "shell.execute_reply": "2026-05-14T07:13:13.027469Z"
+ "iopub.execute_input": "2026-06-29T08:32:42.717004Z",
+ "iopub.status.busy": "2026-06-29T08:32:42.716582Z",
+ "iopub.status.idle": "2026-06-29T08:32:42.738749Z",
+ "shell.execute_reply": "2026-06-29T08:32:42.737986Z"
}
},
"outputs": [
@@ -79,10 +110,10 @@
"id": "4",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:13:13.029013Z",
- "iopub.status.busy": "2026-05-14T07:13:13.028919Z",
- "iopub.status.idle": "2026-05-14T07:13:13.036687Z",
- "shell.execute_reply": "2026-05-14T07:13:13.036255Z"
+ "iopub.execute_input": "2026-06-29T08:32:42.741371Z",
+ "iopub.status.busy": "2026-06-29T08:32:42.741161Z",
+ "iopub.status.idle": "2026-06-29T08:32:42.755812Z",
+ "shell.execute_reply": "2026-06-29T08:32:42.754929Z"
}
},
"outputs": [
@@ -129,9 +160,13 @@
"factors = {\n",
" \"skills\": FactorSpec(\n",
" measurements=_measurements(SKILL_MEASURES),\n",
- " normalizations=_normalizations(SKILL_MEASURES, normalize_periods=(0,)),\n",
+ " normalizations=_normalizations(SKILL_MEASURES, normalize_periods=(0, 1, 2)),\n",
" transition_function=\"log_ces\",\n",
" ),\n",
+ " # Cognitive (MC) and non-cognitive (MN) skills are time-invariant inputs to the\n",
+ " # skills CES. `af_state_role=\"static_persistent\"` tells the AF calendar adapter to\n",
+ " # re-apply their period-0 measurement density as a static importance block at every\n",
+ " # step (a no-op for CHS/AMN, which read the same spec).\n",
" \"MC\": FactorSpec(\n",
" measurements=_measurements(MC_MEASURES, active_periods=(0,)),\n",
" normalizations=_normalizations(\n",
@@ -139,6 +174,7 @@
" ),\n",
" transition_function=\"linear\",\n",
" has_production_shock=False,\n",
+ " af_state_role=\"static_persistent\",\n",
" ),\n",
" \"MN\": FactorSpec(\n",
" measurements=_measurements(MN_MEASURES, active_periods=(0,)),\n",
@@ -147,13 +183,21 @@
" ),\n",
" transition_function=\"linear\",\n",
" has_production_shock=False,\n",
+ " af_state_role=\"static_persistent\",\n",
" ),\n",
+ " # Investment is endogenous and reconstructed from its own measurement equation\n",
+ " # rather than drawn from the period-0 latent mixture, so it sets\n",
+ " # `is_endogenous=True, has_initial_distribution=False`. This activates the AF\n",
+ " # source/destination calendar adapter (period-t indicators measure investment in\n",
+ " # period t).\n",
" \"investment\": FactorSpec(\n",
" measurements=_measurements(INV_MEASURES, active_periods=(0, 1)),\n",
" normalizations=_normalizations(\n",
- " INV_MEASURES, active_periods=(0, 1), normalize_periods=(0,)\n",
+ " INV_MEASURES, active_periods=(0, 1), normalize_periods=(0, 1)\n",
" ),\n",
" transition_function=\"linear\",\n",
+ " is_endogenous=True,\n",
+ " has_initial_distribution=False,\n",
" ),\n",
"}\n",
"\n",
@@ -200,10 +244,10 @@
"id": "6",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:13:13.038140Z",
- "iopub.status.busy": "2026-05-14T07:13:13.037984Z",
- "iopub.status.idle": "2026-05-14T07:13:39.788024Z",
- "shell.execute_reply": "2026-05-14T07:13:39.787579Z"
+ "iopub.execute_input": "2026-06-29T08:32:42.757835Z",
+ "iopub.status.busy": "2026-06-29T08:32:42.757648Z",
+ "iopub.status.idle": "2026-06-29T08:32:48.703368Z",
+ "shell.execute_reply": "2026-06-29T08:32:48.702321Z"
}
},
"outputs": [
@@ -211,48 +255,48 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "params_template shape: (158, 3)\n",
- "n_constraints: 32\n"
+ "params_template shape: (214, 3)\n",
+ "n_constraints: 84\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:431: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = float(features[rows, f_idx].mean())\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:434: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[w_loc, \"value\"] = float(rows.mean())\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
- " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:336: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:339: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
" params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:363: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
- " params.loc[loc, \"value\"] = sd_factor\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:559: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
- " params.loc[loc, \"value\"] = float(beta[col_idx])\n",
- "/home/hmg/econ/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:579: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
- " params.loc[loc_sd, \"value\"] = shock_sd\n"
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:576: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = sd_factor\n"
]
}
],
@@ -277,10 +321,10 @@
"id": "7",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:13:39.789235Z",
- "iopub.status.busy": "2026-05-14T07:13:39.789034Z",
- "iopub.status.idle": "2026-05-14T07:19:28.096303Z",
- "shell.execute_reply": "2026-05-14T07:19:28.095804Z"
+ "iopub.execute_input": "2026-06-29T08:32:48.705394Z",
+ "iopub.status.busy": "2026-06-29T08:32:48.705163Z",
+ "iopub.status.idle": "2026-06-29T09:18:35.831694Z",
+ "shell.execute_reply": "2026-06-29T09:18:35.826818Z"
}
},
"outputs": [
@@ -288,7 +332,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "CHS success: True, loglike: -39620.92\n"
+ "CHS success: True, loglike: -39542.42\n"
]
}
],
@@ -306,90 +350,519 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"id": "8",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:19:28.097610Z",
- "iopub.status.busy": "2026-05-14T07:19:28.097505Z",
- "iopub.status.idle": "2026-05-14T07:19:57.433481Z",
- "shell.execute_reply": "2026-05-14T07:19:57.432980Z"
+ "iopub.execute_input": "2026-06-29T09:18:35.835312Z",
+ "iopub.status.busy": "2026-06-29T09:18:35.834946Z",
+ "iopub.status.idle": "2026-06-29T09:18:53.202115Z",
+ "shell.execute_reply": "2026-06-29T09:18:53.200935Z"
}
},
- "outputs": [],
- "source": "chs_filtered = get_individual_states_from_params(\n model_spec=model, data=data, params=chs_params\n)\nchs_states = chs_filtered[\"unanchored_states\"][\"states\"]\nchs_decomp = decompose_measurement_variance(\n model_spec=model, params=chs_params, filtered_states=chs_states\n)\nchs_reliability = summarize_measurement_reliability(chs_decomp)\nchs_reliability"
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:431: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = float(features[rows, f_idx].mean())\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:434: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[w_loc, \"value\"] = float(rows.mean())\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:552: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_sd, \"value\"] = float(result.meas_sds[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:549: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc_load, \"value\"] = float(result.loadings[local_idx])\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/amn/start_values.py:576: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = sd_factor\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " mean_signal | \n",
+ " min_signal | \n",
+ " max_signal | \n",
+ "
\n",
+ " \n",
+ " | measurement | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | skill_recog | \n",
+ " 0.801 | \n",
+ " 0.739 | \n",
+ " 0.918 | \n",
+ "
\n",
+ " \n",
+ " | mc_2 | \n",
+ " 0.707 | \n",
+ " 0.707 | \n",
+ " 0.707 | \n",
+ "
\n",
+ " \n",
+ " | skill_comp | \n",
+ " 0.703 | \n",
+ " 0.619 | \n",
+ " 0.861 | \n",
+ "
\n",
+ " \n",
+ " | mc_1 | \n",
+ " 0.690 | \n",
+ " 0.690 | \n",
+ " 0.690 | \n",
+ "
\n",
+ " \n",
+ " | mc_3 | \n",
+ " 0.676 | \n",
+ " 0.676 | \n",
+ " 0.676 | \n",
+ "
\n",
+ " \n",
+ " | mc_6 | \n",
+ " 0.673 | \n",
+ " 0.673 | \n",
+ " 0.673 | \n",
+ "
\n",
+ " \n",
+ " | mc_4 | \n",
+ " 0.526 | \n",
+ " 0.526 | \n",
+ " 0.526 | \n",
+ "
\n",
+ " \n",
+ " | skill_math | \n",
+ " 0.489 | \n",
+ " 0.475 | \n",
+ " 0.510 | \n",
+ "
\n",
+ " \n",
+ " | mn_neg | \n",
+ " 0.469 | \n",
+ " 0.469 | \n",
+ " 0.469 | \n",
+ "
\n",
+ " \n",
+ " | mc_5 | \n",
+ " 0.432 | \n",
+ " 0.432 | \n",
+ " 0.432 | \n",
+ "
\n",
+ " \n",
+ " | mn_pos | \n",
+ " 0.425 | \n",
+ " 0.425 | \n",
+ " 0.425 | \n",
+ "
\n",
+ " \n",
+ " | mn_rotter | \n",
+ " 0.084 | \n",
+ " 0.084 | \n",
+ " 0.084 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " mean_signal min_signal max_signal\n",
+ "measurement \n",
+ "skill_recog 0.801 0.739 0.918\n",
+ "mc_2 0.707 0.707 0.707\n",
+ "skill_comp 0.703 0.619 0.861\n",
+ "mc_1 0.690 0.690 0.690\n",
+ "mc_3 0.676 0.676 0.676\n",
+ "mc_6 0.673 0.673 0.673\n",
+ "mc_4 0.526 0.526 0.526\n",
+ "skill_math 0.489 0.475 0.510\n",
+ "mn_neg 0.469 0.469 0.469\n",
+ "mc_5 0.432 0.432 0.432\n",
+ "mn_pos 0.425 0.425 0.425\n",
+ "mn_rotter 0.084 0.084 0.084"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "chs_filtered = get_individual_states_from_params(\n",
+ " model_spec=model, data=data, params=chs_params\n",
+ ")\n",
+ "chs_states = chs_filtered[\"unanchored_states\"][\"states\"]\n",
+ "chs_decomp = decompose_measurement_variance(\n",
+ " model_spec=model, params=chs_params, filtered_states=chs_states\n",
+ ")\n",
+ "chs_reliability = summarize_measurement_reliability(chs_decomp)\n",
+ "chs_reliability"
+ ]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
- "source": "## AF: sequential Halton-quadrature MLE\n\nAF estimates each period in turn: period 0 fits the joint mixture-of-normals + period-0 measurement system; each subsequent period takes the previous period's posterior and runs a period-specific MLE via `optimagic.minimize` with the algorithm in `AFEstimationOptions.optimizer_algorithm` (default `\"fides\"`; pass `\"scipy_lbfgsb\"` for Monte Carlo sweeps where a deterministic stopping rule matters).\n\nDefault `start_params_strategy=\"amn\"` runs the full AMN three-stage estimator upfront and uses its parameters to seed each AF period. The number of mixture components is read from `ModelSpec.n_mixtures` (set to 2 above), shared across all three estimators."
+ "source": [
+ "## AF: sequential Halton-quadrature MLE\n",
+ "\n",
+ "AF estimates each period in turn: period 0 fits the joint mixture-of-normals + period-0 measurement system; each subsequent period takes the previous period's posterior and runs a period-specific MLE via `optimagic.minimize` with the algorithm in `AFEstimationOptions.optimizer_algorithm` (default `\"fides\"`; pass `\"scipy_lbfgsb\"` for Monte Carlo sweeps where a deterministic stopping rule matters).\n",
+ "\n",
+ "Default `start_params_strategy=\"amn\"` runs the full AMN three-stage estimator upfront and uses its parameters to seed each AF period. The number of mixture components is read from `ModelSpec.n_mixtures` (set to 2 above), shared across all three estimators."
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"id": "10",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:19:57.434856Z",
- "iopub.status.busy": "2026-05-14T07:19:57.434760Z",
- "iopub.status.idle": "2026-05-14T07:20:56.673081Z",
- "shell.execute_reply": "2026-05-14T07:20:56.672613Z"
+ "iopub.execute_input": "2026-06-29T09:18:53.205706Z",
+ "iopub.status.busy": "2026-06-29T09:18:53.205420Z",
+ "iopub.status.idle": "2026-06-29T09:22:10.970887Z",
+ "shell.execute_reply": "2026-06-29T09:22:10.969782Z"
}
},
- "outputs": [],
- "source": "af_options = AFEstimationOptions(\n n_halton_points=100,\n n_halton_points_shock=50,\n optimizer_algorithm=\"scipy_lbfgsb\",\n)\naf_result = estimate_af(\n model_spec=model,\n data=data,\n options=af_options,\n fixed_params=fixed_params,\n)\naf_lls = [pr.loglikelihood for pr in af_result.period_results]\nprint(\n f\"AF per-period success: {[bool(pr.success) for pr in af_result.period_results]}, \"\n f\"per-period log-likelihoods: {[f'{ll:.2f}' for ll in af_lls]}\"\n)"
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/estimate.py:345: UserWarning: Restricted-CES scale system ('MC', 'MN', 'skills') period 0: 3 independent scale pins (loadings) across the system, but the CES restrictions identify the relative scales from a SINGLE primitive anchor. 2 of them are testable restrictions, not normalizations.\n",
+ " validate_af_model(model_spec, fixed_params, constraints)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/estimate.py:345: UserWarning: Factor 'skills': built-in transition 'log_ces' enumerates parameters over ALL factors including observed factors ('log_income_observed',), so they enter the production function with free coefficients. The AF model assumes observed factors affect skills only through the investment equation. Use a production-factors-only transition ('translog_af' or 'log_ces_af'), or pin every observed-factor transition coefficient to 0.0 via `fixed_params`, to avoid changing the production estimand.\n",
+ " validate_af_model(model_spec, fixed_params, constraints)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/estimate.py:345: UserWarning: Factor 'MC': built-in transition 'linear' enumerates parameters over ALL factors including observed factors ('log_income_observed',), so they enter the production function with free coefficients. The AF model assumes observed factors affect skills only through the investment equation. Use a production-factors-only transition ('translog_af' or 'log_ces_af'), or pin every observed-factor transition coefficient to 0.0 via `fixed_params`, to avoid changing the production estimand.\n",
+ " validate_af_model(model_spec, fixed_params, constraints)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/estimate.py:345: UserWarning: Factor 'MN': built-in transition 'linear' enumerates parameters over ALL factors including observed factors ('log_income_observed',), so they enter the production function with free coefficients. The AF model assumes observed factors affect skills only through the investment equation. Use a production-factors-only transition ('translog_af' or 'log_ces_af'), or pin every observed-factor transition coefficient to 0.0 via `fixed_params`, to avoid changing the production estimand.\n",
+ " validate_af_model(model_spec, fixed_params, constraints)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:357: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"lower_bound\"] = bounds_distance\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:364: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:365: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"lower_bound\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:366: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"upper_bound\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:372: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"value\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:373: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"lower_bound\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/params.py:374: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[loc, \"upper_bound\"] = val\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/initial_period.py:429: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = obs_sds.get(parts[0], meas_sd * 0.5)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/initial_period.py:431: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 0.0\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/initial_period.py:377: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = max(obs_sd * 0.5, 0.01)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/initial_period.py:383: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 1.0\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/initial_period.py:392: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 0.0\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1079: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 0.5\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1090: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = max(obs_sd * 0.5, 0.01)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1096: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 1.0\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1079: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 0.5\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1090: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = max(obs_sd * 0.5, 0.01)\n",
+ "/home/hmga/skillmodels-applications/skillmodels/src/skillmodels/af/transition_period.py:1096: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
+ " params.loc[idx, \"value\"] = 1.0\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AF per-period success: [True, True, True], per-period log-likelihoods: ['-14.93', '-21.16', '-22.86']\n"
+ ]
+ }
+ ],
+ "source": [
+ "af_options = AFEstimationOptions(\n",
+ " n_halton_points=100,\n",
+ " n_halton_points_shock=50,\n",
+ " optimizer_algorithm=\"scipy_lbfgsb\",\n",
+ ")\n",
+ "af_result = estimate_af(\n",
+ " model_spec=model,\n",
+ " data=data,\n",
+ " options=af_options,\n",
+ " fixed_params=fixed_params,\n",
+ ")\n",
+ "af_lls = [pr.loglikelihood for pr in af_result.period_results]\n",
+ "print(\n",
+ " f\"AF per-period success: {[bool(pr.success) for pr in af_result.period_results]}, \"\n",
+ " f\"per-period log-likelihoods: {[f'{ll:.2f}' for ll in af_lls]}\"\n",
+ ")"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"id": "11",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:20:56.674410Z",
- "iopub.status.busy": "2026-05-14T07:20:56.674311Z",
- "iopub.status.idle": "2026-05-14T07:21:01.658274Z",
- "shell.execute_reply": "2026-05-14T07:21:01.657922Z"
+ "iopub.execute_input": "2026-06-29T09:22:10.972958Z",
+ "iopub.status.busy": "2026-06-29T09:22:10.972726Z",
+ "iopub.status.idle": "2026-06-29T09:22:16.257808Z",
+ "shell.execute_reply": "2026-06-29T09:22:16.256615Z"
}
},
- "outputs": [],
- "source": "af_posterior = get_af_posterior_states(\n af_result=af_result,\n model_spec=model,\n data=data,\n n_halton_points=200,\n)\naf_states = af_posterior[\"unanchored_states\"][\"states\"]\naf_decomp = decompose_measurement_variance(\n model_spec=model, params=af_result.params, filtered_states=af_states\n)\naf_reliability = summarize_measurement_reliability(af_decomp)\naf_reliability"
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " mean_signal | \n",
+ " min_signal | \n",
+ " max_signal | \n",
+ "
\n",
+ " \n",
+ " | measurement | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | skill_recog | \n",
+ " 0.773 | \n",
+ " 0.726 | \n",
+ " 0.819 | \n",
+ "
\n",
+ " \n",
+ " | skill_comp | \n",
+ " 0.714 | \n",
+ " 0.612 | \n",
+ " 0.816 | \n",
+ "
\n",
+ " \n",
+ " | mc_2 | \n",
+ " 0.708 | \n",
+ " 0.708 | \n",
+ " 0.708 | \n",
+ "
\n",
+ " \n",
+ " | mc_3 | \n",
+ " 0.685 | \n",
+ " 0.685 | \n",
+ " 0.685 | \n",
+ "
\n",
+ " \n",
+ " | mc_1 | \n",
+ " 0.676 | \n",
+ " 0.676 | \n",
+ " 0.676 | \n",
+ "
\n",
+ " \n",
+ " | mc_6 | \n",
+ " 0.663 | \n",
+ " 0.663 | \n",
+ " 0.663 | \n",
+ "
\n",
+ " \n",
+ " | mc_4 | \n",
+ " 0.531 | \n",
+ " 0.531 | \n",
+ " 0.531 | \n",
+ "
\n",
+ " \n",
+ " | mc_5 | \n",
+ " 0.435 | \n",
+ " 0.435 | \n",
+ " 0.435 | \n",
+ "
\n",
+ " \n",
+ " | mn_neg | \n",
+ " 0.404 | \n",
+ " 0.404 | \n",
+ " 0.404 | \n",
+ "
\n",
+ " \n",
+ " | mn_pos | \n",
+ " 0.400 | \n",
+ " 0.400 | \n",
+ " 0.400 | \n",
+ "
\n",
+ " \n",
+ " | skill_math | \n",
+ " 0.398 | \n",
+ " 0.377 | \n",
+ " 0.419 | \n",
+ "
\n",
+ " \n",
+ " | mn_rotter | \n",
+ " 0.082 | \n",
+ " 0.082 | \n",
+ " 0.082 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " mean_signal min_signal max_signal\n",
+ "measurement \n",
+ "skill_recog 0.773 0.726 0.819\n",
+ "skill_comp 0.714 0.612 0.816\n",
+ "mc_2 0.708 0.708 0.708\n",
+ "mc_3 0.685 0.685 0.685\n",
+ "mc_1 0.676 0.676 0.676\n",
+ "mc_6 0.663 0.663 0.663\n",
+ "mc_4 0.531 0.531 0.531\n",
+ "mc_5 0.435 0.435 0.435\n",
+ "mn_neg 0.404 0.404 0.404\n",
+ "mn_pos 0.400 0.400 0.400\n",
+ "skill_math 0.398 0.377 0.419\n",
+ "mn_rotter 0.082 0.082 0.082"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af_posterior = get_af_posterior_states(\n",
+ " af_result=af_result,\n",
+ " model_spec=model,\n",
+ " data=data,\n",
+ " n_halton_points=200,\n",
+ ")\n",
+ "af_states = af_posterior[\"unanchored_states\"][\"states\"]\n",
+ "af_decomp = decompose_measurement_variance(\n",
+ " model_spec=model, params=af_result.params, filtered_states=af_states\n",
+ ")\n",
+ "af_reliability = summarize_measurement_reliability(af_decomp)\n",
+ "af_reliability"
+ ]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
- "source": "## AMN: three-stage mixture-of-normals\n\nAMN's three stages are EM on the augmented measurement vector (mixture-of-normals), minimum-distance recovery of the structural parameters, and a simulate-and-regress step on the fitted mixture for the production function. With `ModelSpec.n_mixtures=2` the Stage-1 EM fits a 2-component Gaussian mixture on the joint $(M, X)$ vector — the kind of non-Gaussian latent structure AMN is designed for."
+ "source": [
+ "## AMN: three-stage mixture-of-normals\n",
+ "\n",
+ "AMN's three stages are EM on the augmented measurement vector (mixture-of-normals), minimum-distance recovery of the structural parameters, and a simulate-and-regress step on the fitted mixture for the production function. With `ModelSpec.n_mixtures=2` the Stage-1 EM fits a 2-component Gaussian mixture on the joint $(M, X)$ vector — the kind of non-Gaussian latent structure AMN is designed for.\n",
+ "\n",
+ "On **this** model AMN does not run standalone: `estimate_amn` deliberately refuses the *restricted* CES (`log_ces`) because its Stage-3 regression omits the Freyberger (2025) primitive-scale recovery, so a standalone fit would be inconsistent. Its role here is instead to **seed CHS** — `CHSEstimationOptions` defaults to `start_params_strategy=\"amn\"`, so the CHS fit above already ran AMN's three stages to build its start values. For a standalone AMN fit of a CES production function, use `log_ces_general` (AMN fits the transformed general CES and recovers the primitive scales via `recover_primitive_ces_scales`); see the AMN how-to for AMN on its native mixture-of-normals ground."
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"id": "13",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:21:01.659567Z",
- "iopub.status.busy": "2026-05-14T07:21:01.659466Z",
- "iopub.status.idle": "2026-05-14T07:21:26.698296Z",
- "shell.execute_reply": "2026-05-14T07:21:26.695286Z"
+ "iopub.execute_input": "2026-06-29T09:22:16.260184Z",
+ "iopub.status.busy": "2026-06-29T09:22:16.259967Z",
+ "iopub.status.idle": "2026-06-29T09:22:16.264768Z",
+ "shell.execute_reply": "2026-06-29T09:22:16.263687Z"
}
},
- "outputs": [],
- "source": "amn_options = AMNEstimationOptions(\n em_max_iter=500,\n n_simulation_draws=50_000,\n seed=0,\n)\namn_result = estimate_amn(\n model_spec=model,\n data=data,\n options=amn_options,\n fixed_params=fixed_params,\n)\nprint(f\"AMN success: {amn_result.success}\")"
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "14",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-05-14T07:21:26.699696Z",
- "iopub.status.busy": "2026-05-14T07:21:26.699564Z",
- "iopub.status.idle": "2026-05-14T07:21:26.794622Z",
- "shell.execute_reply": "2026-05-14T07:21:26.793933Z"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CHS start-value strategy: 'amn' (AMN seeds CHS)\n"
+ ]
}
- },
- "outputs": [],
- "source": "amn_posterior = get_amn_posterior_states(amn_result=amn_result, data=data)\namn_states = amn_posterior[\"unanchored_states\"][\"states\"]\namn_decomp = decompose_measurement_variance(\n model_spec=model, params=amn_result.params, filtered_states=amn_states\n)\namn_reliability = summarize_measurement_reliability(amn_decomp)\namn_reliability"
+ ],
+ "source": [
+ "# AMN cannot consistently fit the restricted CES (`log_ces`) standalone, so here it\n",
+ "# seeds CHS instead: `CHSEstimationOptions` defaults to `start_params_strategy=\"amn\"`,\n",
+ "# i.e. the CHS fit above already ran AMN's three stages to build its start values.\n",
+ "seed_strategy = chs_options.start_params_strategy\n",
+ "print(f\"CHS start-value strategy: {seed_strategy!r} (AMN seeds CHS)\")"
+ ]
},
{
"cell_type": "markdown",
@@ -407,29 +880,118 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"id": "16",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:21:26.795949Z",
- "iopub.status.busy": "2026-05-14T07:21:26.795823Z",
- "iopub.status.idle": "2026-05-14T07:21:26.809737Z",
- "shell.execute_reply": "2026-05-14T07:21:26.809254Z"
+ "iopub.execute_input": "2026-06-29T09:22:16.266706Z",
+ "iopub.status.busy": "2026-06-29T09:22:16.266472Z",
+ "iopub.status.idle": "2026-06-29T09:22:16.293008Z",
+ "shell.execute_reply": "2026-06-29T09:22:16.291814Z"
}
},
- "outputs": [],
- "source": "estimators = {\n \"CHS\": chs_params,\n \"AF\": af_result.params,\n \"AMN\": amn_result.params,\n}\n\n\ndef _free_loading(params, period, meas):\n loc = (\"loadings\", period, meas, \"skills\")\n if loc not in params.index:\n return None\n return float(params.loc[loc, \"value\"])\n\n\nloading_rows = []\nfor est, params in estimators.items():\n for meas in SKILL_MEASURES:\n value = _free_loading(params, period=0, meas=meas)\n if value is None:\n continue\n loading_rows.append({\"estimator\": est, \"measurement\": meas, \"loading\": value})\nloadings_df = pd.DataFrame(loading_rows)\nloadings_pivot = loadings_df.pivot(\n index=\"measurement\", columns=\"estimator\", values=\"loading\"\n)\nloadings_pivot"
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | estimator | \n",
+ " AF | \n",
+ " CHS | \n",
+ "
\n",
+ " \n",
+ " | measurement | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | skill_comp | \n",
+ " 1.373 | \n",
+ " 1.496 | \n",
+ "
\n",
+ " \n",
+ " | skill_math | \n",
+ " 1.000 | \n",
+ " 1.000 | \n",
+ "
\n",
+ " \n",
+ " | skill_recog | \n",
+ " 1.376 | \n",
+ " 1.564 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "estimator AF CHS\n",
+ "measurement \n",
+ "skill_comp 1.373 1.496\n",
+ "skill_math 1.000 1.000\n",
+ "skill_recog 1.376 1.564"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "estimators = {\n",
+ " \"CHS\": chs_params,\n",
+ " \"AF\": af_result.params,\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def _free_loading(params, period, meas):\n",
+ " loc = (\"loadings\", period, meas, \"skills\")\n",
+ " if loc not in params.index:\n",
+ " return None\n",
+ " return float(params.loc[loc, \"value\"])\n",
+ "\n",
+ "\n",
+ "loading_rows = []\n",
+ "for est, params in estimators.items():\n",
+ " for meas in SKILL_MEASURES:\n",
+ " value = _free_loading(params, period=0, meas=meas)\n",
+ " if value is None:\n",
+ " continue\n",
+ " loading_rows.append({\"estimator\": est, \"measurement\": meas, \"loading\": value})\n",
+ "loadings_df = pd.DataFrame(loading_rows)\n",
+ "loadings_pivot = loadings_df.pivot(\n",
+ " index=\"measurement\", columns=\"estimator\", values=\"loading\"\n",
+ ")\n",
+ "loadings_pivot"
+ ]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"id": "17",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:21:26.810916Z",
- "iopub.status.busy": "2026-05-14T07:21:26.810787Z",
- "iopub.status.idle": "2026-05-14T07:21:29.003649Z",
- "shell.execute_reply": "2026-05-14T07:21:29.003242Z"
+ "iopub.execute_input": "2026-06-29T09:22:16.294928Z",
+ "iopub.status.busy": "2026-06-29T09:22:16.294687Z",
+ "iopub.status.idle": "2026-06-29T09:22:18.823631Z",
+ "shell.execute_reply": "2026-06-29T09:22:18.822946Z"
}
},
"outputs": [
@@ -449,7 +1011,7 @@
"skill_comp"
],
"y": {
- "bdata": "AAAAAAAA8D/DIYazpav2Pyn68owbpfU/",
+ "bdata": "AAAAAAAA8D+mai3Q7wT5PwWoimC57fc/",
"dtype": "f8"
}
},
@@ -462,20 +1024,7 @@
"skill_comp"
],
"y": {
- "bdata": "AAAAAAAA8D9fbAPZXff1P7wQMLJ+EvY/",
- "dtype": "f8"
- }
- },
- {
- "name": "AMN",
- "type": "bar",
- "x": [
- "skill_math",
- "skill_recog",
- "skill_comp"
- ],
- "y": {
- "bdata": "AAAAAAAA8D+wS/CPTL/yP0IjsiJDuvI/",
+ "bdata": "AAAAAAAA8D/RnhnsWgL2PwusBrQN9/U/",
"dtype": "f8"
}
}
@@ -558,7 +1107,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -594,7 +1143,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -618,7 +1167,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -654,7 +1203,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -681,7 +1230,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -717,7 +1266,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -732,7 +1281,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -768,7 +1317,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -924,7 +1473,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -960,7 +1509,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -1051,7 +1600,7 @@
],
"sequential": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1087,13 +1636,13 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
"sequentialminus": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1129,7 +1678,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
]
@@ -1259,7 +1808,7 @@
}
},
"title": {
- "text": "Period-0 skill loadings across estimators"
+ "text": "Period-0 skill loadings: CHS vs AF"
},
"yaxis": {
"title": {
@@ -1275,11 +1824,11 @@
],
"source": [
"fig = go.Figure()\n",
- "for est in (\"CHS\", \"AF\", \"AMN\"):\n",
+ "for est in (\"CHS\", \"AF\"):\n",
" sub = loadings_df[loadings_df[\"estimator\"] == est]\n",
" fig.add_trace(go.Bar(name=est, x=sub[\"measurement\"], y=sub[\"loading\"]))\n",
"fig.update_layout(\n",
- " title=\"Period-0 skill loadings across estimators\",\n",
+ " title=\"Period-0 skill loadings: CHS vs AF\",\n",
" barmode=\"group\",\n",
" template=\"plotly_white\",\n",
" yaxis_title=\"loading\",\n",
@@ -1289,14 +1838,14 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"id": "18",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:21:29.004727Z",
- "iopub.status.busy": "2026-05-14T07:21:29.004623Z",
- "iopub.status.idle": "2026-05-14T07:21:29.011824Z",
- "shell.execute_reply": "2026-05-14T07:21:29.011388Z"
+ "iopub.execute_input": "2026-06-29T09:22:18.826321Z",
+ "iopub.status.busy": "2026-06-29T09:22:18.825917Z",
+ "iopub.status.idle": "2026-06-29T09:22:18.838705Z",
+ "shell.execute_reply": "2026-06-29T09:22:18.837944Z"
}
},
"outputs": [
@@ -1335,21 +1884,15 @@
" \n",
" \n",
" | CHS | \n",
- " 0.893 | \n",
- " 0.000 | \n",
- " 1.337 | \n",
+ " 0.818 | \n",
+ " 0.026 | \n",
+ " -0.145 | \n",
"
\n",
" \n",
" | AF | \n",
- " 0.683 | \n",
- " 0.271 | \n",
- " -1.178 | \n",
- "
\n",
- " \n",
- " | AMN | \n",
- " 0.894 | \n",
- " 0.000 | \n",
- " 0.457 | \n",
+ " 0.792 | \n",
+ " 0.118 | \n",
+ " -1.063 | \n",
"
\n",
" \n",
"\n",
@@ -1358,12 +1901,11 @@
"text/plain": [
" gamma_skills gamma_investment phi\n",
"estimator \n",
- "CHS 0.893 0.000 1.337\n",
- "AF 0.683 0.271 -1.178\n",
- "AMN 0.894 0.000 0.457"
+ "CHS 0.818 0.026 -0.145\n",
+ "AF 0.792 0.118 -1.063"
]
},
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -1388,14 +1930,14 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"id": "19",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-05-14T07:21:29.012809Z",
- "iopub.status.busy": "2026-05-14T07:21:29.012715Z",
- "iopub.status.idle": "2026-05-14T07:21:29.034918Z",
- "shell.execute_reply": "2026-05-14T07:21:29.034535Z"
+ "iopub.execute_input": "2026-06-29T09:22:18.840717Z",
+ "iopub.status.busy": "2026-06-29T09:22:18.840521Z",
+ "iopub.status.idle": "2026-06-29T09:22:18.881311Z",
+ "shell.execute_reply": "2026-06-29T09:22:18.880444Z"
}
},
"outputs": [
@@ -1414,28 +1956,11 @@
"name": "AF vs CHS",
"type": "scatter",
"x": {
- "bdata": "VYuwyapJwj8gz+Pl1hDFP6FHVF1E08E/fOy9wklt5j8ek2zX6YPmP/tmS6+3d+U/0w3gSUi94D+c1rhUBG/bP1kVmZ0hnuU/HTrXCtZl3D+X/WuijgXbP1v1EI2/YLY/LhzMiSiK6T9F2GxXZFHbPwqj62I8KOw/ZpAJHzRtvz9bFq9nRbjCP0pk1hjuxsA/CRuL2MRi5D9u3Qm/mOfdP+1zpbocEek/Vr3pBiag4z9RXCjpI9HdPxyn66by8uc/",
+ "bdata": "WW7/C5EW5j+ARTDW8J/mP4h4UM78n+U/rMhQdUPT4D8S0w+4PKXbP3YpLcXdieU/ANkXfV4C3j9poQ4OpznbPw5BgIpXi7U/qUjzm9aM6z9L0qwiwMjeP+Z7X8puYO0/UwoEU+sV5D8tW1RFPlPgPwB2VOA65uc/",
"dtype": "f8"
},
"y": {
- "bdata": "/kmPhxceuz+FzaYg15+3PwI8WSPocLY/4ioSrKlH5T+eTVdFpKrmP5IR6kvQ6+U/vLiFAlCX4D9OiwwoQFXbP5PK4PeP1OQ/7RECqMrj1j93mPTaN/bVP+xFKBFmQLI/4xH0KIZK6j/SNCKMT8raP7uJ2SXqBeo/PBHp2xAptz8YwU9Xy7m0PwXhcSSVW7I/chUV98084j+Yx0VYEKLcP3HOoulxzuU/eRyhAt3X5D84zKh3/3vlP3ja63vySeg/",
- "dtype": "f8"
- }
- },
- {
- "marker": {
- "size": 8,
- "symbol": "diamond"
- },
- "mode": "markers",
- "name": "AMN vs CHS",
- "type": "scatter",
- "x": {
- "bdata": "VYuwyapJwj8gz+Pl1hDFP6FHVF1E08E/fOy9wklt5j8ek2zX6YPmP/tmS6+3d+U/0w3gSUi94D+c1rhUBG/bP1kVmZ0hnuU/HTrXCtZl3D+X/WuijgXbP1v1EI2/YLY/LhzMiSiK6T9F2GxXZFHbPwqj62I8KOw/ZpAJHzRtvz9bFq9nRbjCP0pk1hjuxsA/CRuL2MRi5D9u3Qm/mOfdP+1zpbocEek/Vr3pBiag4z9RXCjpI9HdPxyn66by8uc/",
- "dtype": "f8"
- },
- "y": {
- "bdata": "ID4Z1si/zT8tg8+KHOS+P0xKjbQvCcw/AIgHCYvx4z+0QdDWLtDnP/GDYl+yL+Y/anh/+Bbj4j/5y1gmMmvfP8FGlWE7muM/HfVRvmhm3j97iD20ELnUP8R5Sw5wfbs/xNGfp2iU5z+p23XHJUzhP7X8nQPNJug/B1DCbD5bwT/yRhQLpM++PzzYqbveNcU/zG1Y9vFc5D8N4y3kTYXiP9EGbUQFC+Y/TOPLhSZj5D95JK2X26ThP5WinuhjnOU/",
+ "bdata": "OJLE5dad5T+TptDH9KjmP7+a36k86eU/KDyWztL/4D+i7FeJDtfbP1qXYwviNOU/8RpIRBbZ2T8qw3D75ZzZP2lzoUMuBbU/muHrcEAb6j9iNjwW883aP+6VVfQkOOo/ecqS7pqV4z9FPpEejh3YP7beDHPsPec/",
"dtype": "f8"
}
},
@@ -1448,12 +1973,12 @@
"showlegend": false,
"type": "scatter",
"x": [
- 0.07129514616851623,
- 0.8799116069832007
+ 0.08211030150546396,
+ 0.9180215790364599
],
"y": [
- 0.07129514616851623,
- 0.8799116069832007
+ 0.08211030150546396,
+ 0.9180215790364599
]
}
],
@@ -1535,7 +2060,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1571,7 +2096,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -1595,7 +2120,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1631,7 +2156,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -1658,7 +2183,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1694,7 +2219,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -1709,7 +2234,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1745,7 +2270,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -1901,7 +2426,7 @@
},
"colorscale": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -1937,7 +2462,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
@@ -2028,7 +2553,7 @@
],
"sequential": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -2064,13 +2589,13 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
],
"sequentialminus": [
[
- 0,
+ 0.0,
"#0d0887"
],
[
@@ -2106,7 +2631,7 @@
"#fdca26"
],
[
- 1,
+ 1.0,
"#f0f921"
]
]
@@ -2236,7 +2761,7 @@
}
},
"title": {
- "text": "Signal share by measurement (CHS reference vs AF/AMN)"
+ "text": "Signal share by measurement (CHS reference vs AF)"
},
"xaxis": {
"title": {
@@ -2245,7 +2770,7 @@
},
"yaxis": {
"title": {
- "text": "signal share (AF / AMN)"
+ "text": "signal share (AF)"
}
}
}
@@ -2256,7 +2781,7 @@
}
],
"source": [
- "decomp_by_est = {\"CHS\": chs_decomp, \"AF\": af_decomp, \"AMN\": amn_decomp}\n",
+ "decomp_by_est = {\"CHS\": chs_decomp, \"AF\": af_decomp}\n",
"signal_share = pd.DataFrame(\n",
" {est: d[\"fraction_signal\"] for est, d in decomp_by_est.items()}\n",
")\n",
@@ -2271,15 +2796,6 @@
" marker={\"size\": 8},\n",
" )\n",
")\n",
- "fig.add_trace(\n",
- " go.Scatter(\n",
- " x=signal_share[\"CHS\"],\n",
- " y=signal_share[\"AMN\"],\n",
- " mode=\"markers\",\n",
- " name=\"AMN vs CHS\",\n",
- " marker={\"size\": 8, \"symbol\": \"diamond\"},\n",
- " )\n",
- ")\n",
"lo = float(signal_share.min().min())\n",
"hi = float(signal_share.max().max())\n",
"fig.add_trace(\n",
@@ -2292,10 +2808,10 @@
" )\n",
")\n",
"fig.update_layout(\n",
- " title=\"Signal share by measurement (CHS reference vs AF/AMN)\",\n",
+ " title=\"Signal share by measurement (CHS reference vs AF)\",\n",
" template=\"plotly_white\",\n",
" xaxis_title=\"signal share (CHS)\",\n",
- " yaxis_title=\"signal share (AF / AMN)\",\n",
+ " yaxis_title=\"signal share (AF)\",\n",
" height=520,\n",
")\n",
"fig"
@@ -2305,7 +2821,15 @@
"cell_type": "markdown",
"id": "20",
"metadata": {},
- "source": "## Next steps\n\n- [How to estimate AF](../how_to_guides/how_to_estimate_af.md) covers the AF API in depth.\n- [How to estimate AMN](../how_to_guides/how_to_estimate_amn.md) shows AMN on its native ground (a synthetic 2-mixture DGP where the mixture-of-normals advantage is visible).\n- [How to compare estimators](../how_to_guides/how_to_compare_estimators.md) extends this tutorial with 95% confidence intervals (CHS analytic sandwich; AF and AMN cluster bootstrap) and overlaid posterior-factor trajectories.\n- [Architecture](../explanations/architecture.md) maps the `common/chs/af/amn` subpackage layout."
+ "source": [
+ "## Next steps\n",
+ "\n",
+ "- [How to estimate AF](../how_to_guides/how_to_estimate_af.md) covers the AF API in depth.\n",
+ "- [How to estimate AMN](../how_to_guides/how_to_estimate_amn.md) shows AMN on its native ground (a synthetic 2-mixture DGP where the mixture-of-normals advantage is visible).\n",
+ "- [How to compare estimators](../how_to_guides/how_to_compare_estimators.md) extends this tutorial with 95% confidence intervals (CHS analytic OPG / inverse-score; AF a propagated influence-function score bootstrap; AMN a cluster bootstrap) and overlaid posterior-factor trajectories.\n",
+ "- [Estimator prerequisites](../reference_guides/estimator_prerequisites.md) compares what data features and model constructs each estimator supports.\n",
+ "- [Architecture](../explanations/architecture.md) maps the `common/chs/af/amn` subpackage layout."
+ ]
}
],
"metadata": {
@@ -2322,7 +2846,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.14.3"
+ "version": "3.14.6"
}
},
"nbformat": 4,
diff --git a/docs/how_to_guides/how_to_compare_estimators.md b/docs/how_to_guides/how_to_compare_estimators.md
index b79b829b..89411c2e 100644
--- a/docs/how_to_guides/how_to_compare_estimators.md
+++ b/docs/how_to_guides/how_to_compare_estimators.md
@@ -18,10 +18,18 @@ The guide assumes the three estimation results from the tutorial are in scope:
`chs_result`, `af_result`, and `amn_result`. The corresponding model and data
fixtures (`model`, `data`) are the same across all three.
+The three estimators support different data features and model constructs; see
+[Estimator Prerequisites](../reference_guides/estimator_prerequisites.md) for the
+support matrix (measurement families, missing data, corrections, anchoring).
+
## Why each estimator gets a different inference
-Each estimator computes the same point estimate of the same model, but the
-sampling-distribution machinery differs:
+The three estimators can be applied to the same structural `ModelSpec`, but they
+use different estimating criteria — CHS a joint Gaussian-component likelihood, AF
+a sequential Halton-integrated likelihood, AMN a three-stage
+EM / minimum-distance / simulate-and-regress fit — so their point estimates may
+differ even before inference is considered. The sampling-distribution machinery
+also differs:
| Estimator | Inference | Why this and not bootstrap (CHS) / not OPG (AF, AMN) |
| --------- | ---------------------------------------------------- | --------------------------------------------------------- |
diff --git a/docs/how_to_guides/how_to_estimate_af.md b/docs/how_to_guides/how_to_estimate_af.md
index c7a3b003..07708e76 100644
--- a/docs/how_to_guides/how_to_estimate_af.md
+++ b/docs/how_to_guides/how_to_estimate_af.md
@@ -91,10 +91,14 @@ af_options = AFEstimationOptions(
)
```
-All optimagic constraint kinds are supported: `FixedConstraintWithValue`
-(from normalisations / `fixed_params`), `ProbabilityConstraint` (from
-`log_ces` `gamma` simplex), and `EqualityConstraint` (within-step and
-cross-period equalities passed through `estimate_af(constraints=...)`).
+AF internally creates the fixed and probability constraints implied by the
+processed model: `FixedConstraintWithValue` (from normalisations / `fixed_params`)
+and `ProbabilityConstraint` (from the `log_ces` `gamma` simplex). The public
+`constraints=` argument is narrower — it honours only `om.EqualityConstraint`
+groups whose selector is built with
+`skillmodels.common.constraints.select_by_loc` (used for within-step and
+cross-period equality restrictions). Other optimagic constraint kinds passed there
+are ignored; supply any further fixed values through `fixed_params` instead.
## Start-values strategy
diff --git a/docs/how_to_guides/how_to_estimate_amn.md b/docs/how_to_guides/how_to_estimate_amn.md
index 4cb22f25..c84ca977 100644
--- a/docs/how_to_guides/how_to_estimate_amn.md
+++ b/docs/how_to_guides/how_to_estimate_amn.md
@@ -15,11 +15,13 @@ three stages:
from the fitted mixture and estimate the production function by
regression on the synthetic data.
-AMN shines when the data are non-Gaussian in the latent factor distribution.
-CHS assumes Gaussian latent factors (one mixture component); AF supports
-multiple mixture components but fits them jointly with the period-specific
-optimizer; AMN cleanly separates the mixture from the structural recovery and
-explicitly models the non-Gaussianity through the EM step.
+AMN shines when the latent factor distribution is non-Gaussian. CHS supports a
+finite mixture of Gaussian initial states via `ModelSpec.n_mixtures`, but filters
+each component with a Gaussian (square-root Kalman) recursion; setting
+`n_mixtures=1` is the deliberately restricted single-Gaussian benchmark. AF also
+supports multiple mixture components but fits them jointly with the
+period-specific optimizer. AMN cleanly separates the mixture (Stage-1 EM) from
+the structural recovery, so it models the non-Gaussianity explicitly.
## Minimal example
@@ -62,8 +64,9 @@ result.success # AND across stage convergence flags
The smallest example that lets AMN's non-Gaussian fit show its advantage is a
1-factor / 3-period model where the latent skill is drawn from a non-trivial
-mixture-of-normals. CHS, restricted to Gaussian latents, produces biased
-production-function estimates on this DGP; AMN's Stage 1 EM recovers the
+mixture-of-normals. CHS with a single Gaussian component (`n_mixtures=1`)
+produces biased production-function estimates on this DGP; AMN's Stage 1 EM
+recovers the
mixture and the structural step undoes the bias.
```python
@@ -140,9 +143,9 @@ result.stages.mixture.weights # tau, should be near (0.6, 0.4) up to label
result.stages.mixture.means # Pi_k for the augmented measure vector
```
-Compare against a CHS fit of the same model (1 mixture component, since CHS
-assumes Gaussian latents) and verify that the slope estimate from CHS is
-biased downward — that's the signal AMN was designed to capture.
+Compare against a CHS fit of the same model with `n_mixtures=1` (the deliberately
+restricted single-Gaussian benchmark) and verify that the slope estimate from CHS
+is biased downward — that's the signal AMN was designed to capture.
## Tuning knobs
@@ -163,6 +166,22 @@ models; if the EM warns about degenerate covariances, bump `em_reg_covar` to
loadings, then projected back to the augmented-measure space; that
data-driven start beats random init by a wide margin.
+### Stage-1 missing data
+
+`mixture_em_method` selects how Stage 1 handles incomplete measurement rows:
+
+- `"complete_case"` (default) fits `sklearn.mixture.GaussianMixture` on
+ listwise-complete rows and raises `InsufficientCompleteCasesError` when fewer
+ than `n_mixtures` rows are complete.
+- `"missing_data"` fits an EM that marginalises over each row's missing entries,
+ valid under an ignorable (MAR) missingness assumption even when no row is
+ complete.
+
+A column that is never observed in any sampled row has unidentified moments; the
+missing-data EM raises unless `allow_never_observed_measurements=True` is set
+(intended only for seeding-style uses where such a column is tolerated).
+`mixture_em_max_rows` optionally subsamples the EM input for speed.
+
### Stage-2 weighting
`minimum_distance_weighting="identity"` (the paper's default, and currently the
@@ -250,10 +269,15 @@ factor's production regression. Under the correction:
- **Anchoring** is not wired through the AMN stages. The model spec's
`AnchoringSpec` is accepted (so the spec stays compatible with CHS), but
the AMN result reports unanchored factor scales.
-- **Within-stage user constraints.** `estimate_amn(constraints=...)` is a
- pass-through hook for forward compatibility; the AMN stages do not yet
- honour `om.EqualityConstraint`. User `fixed_params` are applied
- post-hoc to the combined params DataFrame.
+- **`start_params` and `constraints`.** `estimate_amn` does not honour either
+ and raises `NotImplementedError` if you pass them — the three-stage pipeline
+ has no single parameter vector to seed or constrain.
+- **`fixed_params`** are honoured only for the categories each stage estimates,
+ and pinned inside that stage rather than post-hoc: Stage 2 honours `loadings`,
+ measurement intercepts, and measurement SDs; Stage 3 honours `transition`.
+ Passing any other category raises `NotImplementedError`. (The generic Stage-3
+ NLS path supports pinning for most transitions, but `log_ces` /
+ `log_ces_with_constant` reject a non-empty fixed set.)
See [How to compare estimators](how_to_compare_estimators.md) for an
overlay of CHS, AF, and AMN on the same data with confidence intervals.
diff --git a/docs/how_to_guides/model_specs.md b/docs/how_to_guides/model_specs.md
index f6d7ef70..e3d302e9 100644
--- a/docs/how_to_guides/model_specs.md
+++ b/docs/how_to_guides/model_specs.md
@@ -78,11 +78,15 @@ Each factor requires:
`estimate_af` emits a `UserWarning` if you use a general built-in transition on
a production factor while observed factors are present.
- **normalizations** (optional): Fixed values for loadings and intercepts to identify
- the model.
+ the model. The model checker validates these syntactically but does not prove
+ transition-specific identification; see
+ [Notes on factor scales](../explanations/notes_on_factor_scales.md).
- **is_endogenous** (optional): Whether this factor is endogenous (default: false).
See [Endogeneity Corrections](../reference_guides/endogeneity_corrections.md).
-- **is_correction** (optional): Whether this is a correction factor (default: false).
- Must also be endogenous.
+- **correction** (optional): A `CorrectionSpec | None` attached to an endogenous
+ investment factor, adding a control-function correction for investment
+ endogeneity. See
+ [Endogeneity Corrections](../reference_guides/endogeneity_corrections.md).
## Anchoring
@@ -124,11 +128,14 @@ model = ModelSpec(
## Estimation Options
-Fine-tune the estimation:
+`n_mixtures` is a structural field on `ModelSpec` itself — the number of components
+in the latent-factor mixture (default 1). The numerical knobs below are
+**CHS-specific** and live on `CHSEstimationOptions` (`skillmodels.chs`), not on
+`ModelSpec`; AF and AMN have their own option dataclasses (`AFEstimationOptions`,
+`AMNEstimationOptions`).
- **robust_bounds**: Make bounds stricter to avoid numerical issues (default: true)
- **bounds_distance**: How much stricter to make bounds (default: 0.001)
-- **n_mixtures**: Number of mixture components (default: 1)
- **sigma_points_scale**: Scaling for Julier sigma points (default: 2)
- **clipping_lower_bound**: Clip log-likelihood from below (default: -1e30)
- **clipping_upper_bound**: Clip log-likelihood from above (default: null)
diff --git a/docs/index.md b/docs/index.md
index 5ba41a65..9c07d585 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -17,6 +17,11 @@ nonlinear latent-factor model. Key features:
period-by-period.
- `amn` — three-stage mixture-of-normals (Attanasio, Meghir & Nix 2020):
EM, minimum distance, simulated regression.
+
+ The three estimators do **not** support the same data features and model
+ constructs (measurement families, missing data, corrections, anchoring). See
+ [Estimator Prerequisites](reference_guides/estimator_prerequisites.md) for the
+ full comparison before choosing one.
- **Strongly-typed, immutable model spec**: frozen dataclasses with
`MappingProxyType` containers throughout.
- **JAX everywhere**: jitted likelihoods, autodiff gradients, optional GPU.
diff --git a/docs/myst.yml b/docs/myst.yml
index 34ff08a8..f936dea8 100644
--- a/docs/myst.yml
+++ b/docs/myst.yml
@@ -47,6 +47,7 @@ project:
- file: explanations/linear_predict.md
- title: Reference Guides
children:
+ - file: reference_guides/estimator_prerequisites.md
- file: reference_guides/transition_functions.md
- file: reference_guides/endogeneity_corrections.md
error_rules:
diff --git a/docs/reference_guides/estimator_prerequisites.md b/docs/reference_guides/estimator_prerequisites.md
new file mode 100644
index 00000000..a6e50ec3
--- /dev/null
+++ b/docs/reference_guides/estimator_prerequisites.md
@@ -0,0 +1,46 @@
+# Estimator Prerequisites — When to Use Which
+
+skillmodels ships three estimators for the same family of nonlinear dynamic latent
+factor models — CHS, AF, and AMN. They accept the same structural `ModelSpec`, but they
+differ in what data features and model constructs they actually support. Pick an
+estimator by checking your model against the prerequisites below **before** estimating;
+a feature one estimator handles natively may be rejected, ignored, or silently
+restricted by another.
+
+A statement that a model uses, say, `ModelSpec.measurement_models` does **not** imply
+uniform support: measurement-family handling in particular differs sharply across the
+three estimators.
+
+## Prerequisites matrix
+
+| Dimension | CHS | AF | AMN |
+| --------- | --- | -- | --- |
+| Estimator | Joint maximum likelihood with square-root Gaussian-component filtering. | Sequential period-by-period likelihood with Halton integration. | Three stages: Gaussian-mixture EM → minimum distance → simulate-and-regress. |
+| Latent distribution | Finite mixture of Gaussian initial states via `ModelSpec.n_mixtures`; Gaussian component filtering thereafter. | Initial finite mixture plus sequentially carried distributions; Halton nodes approximate the integrals. | Stage-1 Gaussian mixture over the augmented measurement vector; `n_mixtures` sets the component count. |
+| Measurement families | **Gaussian only** (standard Kalman update). Probit/Tobit measurements are not consumed by the CHS path. | **Initial period only**: probit/Tobit measurement families are honoured at the period-0 measurement system; transition periods ($t \geq 1$) fall back to an all-Gaussian measurement kernel. | **Gaussian only**: `estimate_amn` raises `NotImplementedError` if `ModelSpec.measurement_models` declares any probit/Tobit measurement. |
+| Missing data | Gaussian measurement updates skip individually missing measurements. | Measurement masks skip missing rows in the per-step likelihood contributions. | `mixture_em_method="complete_case"` (default) or `"missing_data"` (marginalises over missing entries under MAR); never-observed columns require `allow_never_observed_measurements=True`. |
+| Endogenous investment | Reconstructed endogenous factors plus a full `CorrectionSpec` control-function basis. | Reconstructed endogenous investment with independent shocks (source/destination calendar adapter); rejects `CorrectionSpec` / nonzero `kappa`. | Endogenous investment in Stage 3; correction support is a linearised control-function term, narrower than CHS. |
+| Corrections | Full `CorrectionSpec` / `kappa` polynomial basis in the processed transition DAG. | Not implemented — the validator raises if a `CorrectionSpec` is attached. | Linear `cf` term only; a higher-order `kappa_terms` request raises. |
+| Custom transitions | Built-ins and `@register_params` callables. | Built-ins and `@register_params` callables. | Built-ins and `@register_params` callables via the Stage-3 generic NLS path, but `log_ces` / `log_ces_with_constant` reject fixed parameters there. |
+| `fixed_params` | Honoured through the shared parameter index. | Honoured (with public `constraints=` limited to `select_by_loc` equality groups). | Honoured only for the categories each stage estimates (Stage-2 loadings/intercepts/SDs, Stage-3 `transition`); other categories raise. `start_params` and `constraints` raise. |
+| Normalization | Shared `Normalizations` / `fixed_params` / equality constraints. The checker is a precheck, not an identification proof. | Same public `ModelSpec`, with the AF-specific source/destination calendar and `af_state_role` metadata. | Stage-2 imposes its own structural moment restrictions and a mean-zero mixture convention. |
+| Anchoring | Supported through `get_maximization_inputs` / the CHS path. | Not part of the AF likelihood; downstream visualization only. | Not wired through the AMN stages (the result reports unanchored scales). |
+| Cost / scaling | Potentially expensive joint ML; the JAX square-root filter helps numerical stability. | Sequential but quadrature-heavy; cost grows with node count and state dimension. | Fast when the stages are well behaved; the cluster bootstrap is expensive because it re-estimates every stage. |
+| Use it for | Likelihood benchmark, anchoring, correction-heavy models, Gaussian measurement systems. | Sequential AF-style models, initial-period limited measurements, period-by-period diagnostics. | Mixture-heavy Gaussian-measurement models, fast start values for CHS, Stage-1/2 structural diagnostics. |
+
+## Reading the matrix
+
+- **Measurement families** are the sharpest divide. If any measurement is probit or
+ Tobit, AMN rejects the model outright, CHS ignores the family and treats it as
+ Gaussian, and AF honours it only at the initial period. Only models with
+ initial-period-only limited measurements have genuine (partial) support, via AF.
+- **Endogeneity corrections** flow through `FactorSpec.correction` (a `CorrectionSpec`).
+ CHS reads the full polynomial `kappa` basis; AMN keeps only the linear term; AF
+ rejects the correction entirely (its endogenous-investment handling is the
+ reconstructed-factor calendar adapter, not a control function).
+- **`start_params` / `constraints`** are CHS/AF concepts; AMN raises on both.
+
+See [How to compare estimators](../how_to_guides/how_to_compare_estimators.md) for an
+overlay of all three on the same data with confidence intervals, and
+[Endogeneity Corrections](endogeneity_corrections.md) for the `CorrectionSpec`
+interface.
diff --git a/docs/reference_guides/transition_functions.md b/docs/reference_guides/transition_functions.md
index b8b67ec5..488d9461 100644
--- a/docs/reference_guides/transition_functions.md
+++ b/docs/reference_guides/transition_functions.md
@@ -5,8 +5,11 @@ several pre-built functions and supports custom functions.
The same transition functions work for all three estimators (CHS, AF, AMN) — they live
in `skillmodels.common.transition_functions` and are dispatched by name through each
-estimator's pipeline. AMN's Stage 3 currently supports the pre-built set listed below;
-custom `@register_params` transitions work with CHS and AF but not yet with AMN.
+estimator's pipeline. CHS and AF support both the pre-built set and custom
+`@register_params` transitions. AMN also supports custom callables through its Stage-3
+generic nonlinear-least-squares path, but with narrower correction- and
+fixed-parameter support than CHS (for example, `log_ces` / `log_ces_with_constant`
+reject fixed parameters in that path).
## Pre-built Transition Functions
diff --git a/src/skillmodels/__init__.py b/src/skillmodels/__init__.py
index e012af61..e470120f 100644
--- a/src/skillmodels/__init__.py
+++ b/src/skillmodels/__init__.py
@@ -35,6 +35,11 @@
get_individual_states,
get_individual_states_from_params,
)
+from skillmodels.common.measurement_models import ( # noqa: E402
+ GaussianMeasurement,
+ ProbitMeasurement,
+ TobitMeasurement,
+)
from skillmodels.common.model_spec import ( # noqa: E402
AnchoringSpec,
CorrectionSpec,
@@ -55,8 +60,11 @@
"CommonEstimationResult",
"CorrectionSpec",
"FactorSpec",
+ "GaussianMeasurement",
"ModelSpec",
"Normalizations",
+ "ProbitMeasurement",
+ "TobitMeasurement",
"estimate_af",
"estimate_amn",
"estimate_chs",
diff --git a/src/skillmodels/af/estimate.py b/src/skillmodels/af/estimate.py
index c219776a..3d69b555 100644
--- a/src/skillmodels/af/estimate.py
+++ b/src/skillmodels/af/estimate.py
@@ -1,5 +1,7 @@
"""Main driver for the AF estimation procedure."""
+from typing import cast
+
import jax
import jax.numpy as jnp
import numpy as np
@@ -11,6 +13,10 @@
from skillmodels._beartype_conf import ESTIMATION_CONF
from skillmodels.af.initial_period import estimate_initial_period
from skillmodels.af.params import get_measurements_per_factor
+from skillmodels.af.step_layout import (
+ HistoricalParams,
+ fail_if_spearman_unsupported_on_adapter,
+)
from skillmodels.af.transition_period import estimate_transition_period
from skillmodels.af.types import (
AFEstimationOptions,
@@ -101,8 +107,7 @@ def estimate_af(
options = AFEstimationOptions()
af_options = options
- validate_af_model(model_spec, fixed_params, constraints)
- fail_if_unsupported_kappa_params(start_params, fixed_params, constraints)
+ _validate_af_inputs(model_spec, fixed_params, constraints, start_params, af_options)
processed_model = process_model(model_spec)
# If AMN-based starts are requested, run the full AMN three-stage
@@ -140,6 +145,7 @@ def estimate_af(
n_halton_points_posterior_summary=(
af_options.n_halton_points_posterior_summary
),
+ bounds_distance=af_options.bounds_distance,
)
# Extract data arrays per period
@@ -157,6 +163,13 @@ def estimate_af(
)
state_factors = tuple(f for f in factors if f not in endogenous_factors)
+ # Align the full per-observation payload to one ID order on the adapter path, so the
+ # mixed-calendar measurement rows pair with the right individual's controls and
+ # latent payload (and reject an unbalanced panel that AF cannot positionally align).
+ data = _maybe_align_adapter_panel(
+ data, n_periods, model_spec, endogenous_factors, state_factors
+ )
+
period_data = _extract_period_data(
data,
n_periods,
@@ -166,6 +179,8 @@ def estimate_af(
observed_factors=observed_factors,
)
+ frames_by_period = _frames_by_period_from_data(data, n_periods)
+
equality_groups = _extract_equality_groups(constraints)
step_constraints = constraints
@@ -199,6 +214,19 @@ def estimate_af(
break
prev_period_params = period_results[-1].params
+ # The cumulative registry feeds the calendar adapter's fixed importance
+ # block (source skills + static-persistent period-0 rows). Build it whenever
+ # the layout path is active (reconstructed investment or static factors).
+ needs_historical = any(
+ (spec.is_endogenous and not spec.has_initial_distribution)
+ or spec.af_state_role == "static_persistent"
+ for spec in model_spec.factors.values()
+ )
+ historical = (
+ HistoricalParams.from_param_frames([r.params for r in period_results])
+ if needs_historical
+ else None
+ )
period_t_result, cond_dist = estimate_transition_period(
period=t,
@@ -210,6 +238,8 @@ def estimate_af(
prev_controls=period_data[t - 1]["controls"],
prev_period_params=prev_period_params,
prev_distribution=cond_dist,
+ frames_by_period=frames_by_period,
+ historical=historical,
af_options=af_options,
endogenous_factors=endogenous_factors,
observed_factors=observed_factors,
@@ -242,14 +272,102 @@ def estimate_af(
else:
conditional_dists_out = ()
+ period_means = tuple(float(r.loglikelihood) for r in period_results)
return AFEstimationResult(
period_results=tuple(period_results),
params=all_params,
model_spec=model_spec,
conditional_distributions=conditional_dists_out,
success=all(r.success for r in period_results),
- loglikelihood=float(sum(r.loglikelihood for r in period_results)),
+ loglikelihood=float(sum(period_means)),
+ sequential_criterion=float(sum(period_means)),
+ period_mean_criteria=period_means,
+ )
+
+
+def _frames_by_period_from_data(
+ data: pd.DataFrame, n_periods: int
+) -> dict[int, pd.DataFrame]:
+ """Build per-period individual-ID-indexed frames for the calendar adapter.
+
+ The adapter pairs each individual's destination-skill and source-investment rows by
+ ID across periods, so the frames are kept ID-indexed rather than positionally
+ concatenated. Level 1 is the period level (level 0 is the individual ID).
+ """
+ available_periods = set(data.index.get_level_values(1))
+ return {
+ p: cast("pd.DataFrame", data.xs(p, level=1))
+ for p in range(n_periods)
+ if p in available_periods
+ }
+
+
+def _align_adapter_panel(data: pd.DataFrame, n_periods: int) -> pd.DataFrame:
+ """Canonicalise a balanced panel to one ID-sorted order for the adapter path.
+
+ The calendar adapter sources mixed-calendar measurements on a sorted individual-ID
+ intersection, while controls, observed factors, the period-0 conditional
+ distribution, and the chain-link payloads are consumed positionally in input-row
+ order. Sorting by `(id, period)` puts every per-observation array in the same ID
+ order so the measurement row for an individual is paired with that same
+ individual's controls and latent payload.
+
+ AF aligns periods positionally, so the panel must be balanced: every individual
+ must appear in every period. Raise otherwise.
+ """
+ period_level = str(data.index.names[1])
+ period_values = data.index.get_level_values(period_level)
+ present = [p for p in range(n_periods) if bool((period_values == p).any())]
+ id_sets = {p: set(data.xs(p, level=period_level).index) for p in present}
+ reference = id_sets[present[0]]
+ for p in present[1:]:
+ if id_sets[p] != reference:
+ diff = sorted(reference.symmetric_difference(id_sets[p]))
+ msg = (
+ "AF calendar adapter requires a balanced panel: the individual set in "
+ f"period {p} differs from period {present[0]} (e.g. {diff[:5]}). "
+ "Provide every individual in every period."
+ )
+ raise ValueError(msg)
+ return data.sort_index()
+
+
+def _validate_af_inputs(
+ model_spec: ModelSpec,
+ fixed_params: pd.DataFrame | None,
+ constraints: list[om.constraints.Constraint] | None,
+ start_params: pd.DataFrame | None,
+ af_options: AFEstimationOptions,
+) -> None:
+ """Reject unsupported AF input combinations before any fitting begins."""
+ fail_if_spearman_unsupported_on_adapter(
+ model_spec, af_options.start_params_strategy
+ )
+ validate_af_model(model_spec, fixed_params, constraints)
+ fail_if_unsupported_kappa_params(start_params, fixed_params, constraints)
+
+
+def _maybe_align_adapter_panel(
+ data: pd.DataFrame,
+ n_periods: int,
+ model_spec: ModelSpec,
+ endogenous_factors: tuple[str, ...],
+ state_factors: tuple[str, ...],
+) -> pd.DataFrame:
+ """Canonicalise the panel when the source/destination calendar adapter is active.
+
+ The adapter path is taken for a reconstructed endogenous factor
+ (`has_initial_distribution=False`) or any static-persistent state factor. Plain AF
+ models are returned unchanged.
+ """
+ reconstructed_endog = any(
+ not model_spec.factors[f].has_initial_distribution for f in endogenous_factors
+ )
+ uses_layout = reconstructed_endog or any(
+ model_spec.factors[f].af_state_role == "static_persistent"
+ for f in state_factors
)
+ return _align_adapter_panel(data, n_periods) if uses_layout else data
def _extract_period_data(
diff --git a/src/skillmodels/af/inference.py b/src/skillmodels/af/inference.py
index 15481ef1..7b2ec7f5 100644
--- a/src/skillmodels/af/inference.py
+++ b/src/skillmodels/af/inference.py
@@ -65,6 +65,7 @@
build_optimagic_inputs,
get_measurements_per_factor,
)
+from skillmodels.af.step_layout import fail_if_calendar_adapter_unsupported
from skillmodels.af.transition_period import (
_extract_prev_measurement_params,
_get_raw_transition_functions,
@@ -180,6 +181,9 @@ def compute_af_standard_errors(
replicate-by-parameter DataFrame.
"""
+ fail_if_calendar_adapter_unsupported(
+ result.model_spec, "AF standard-error inference"
+ )
if af_options is None:
af_options = AFEstimationOptions()
diff --git a/src/skillmodels/af/initial_period.py b/src/skillmodels/af/initial_period.py
index 3dee3fa0..9ecc1f4b 100644
--- a/src/skillmodels/af/initial_period.py
+++ b/src/skillmodels/af/initial_period.py
@@ -39,6 +39,7 @@
filter_within_step_constraints,
reconcile_start_to_equality,
)
+from skillmodels.common.measurement_models import measurement_family_arrays
from skillmodels.common.model_spec import ModelSpec
from skillmodels.common.types import ProcessedModel, to_plain_dict
@@ -132,6 +133,7 @@ def estimate_initial_period( # noqa: PLR0915
params_index,
normalizations,
period=0,
+ bounds_distance=af_options.bounds_distance,
)
# Initialize parameters via simple heuristics
@@ -188,6 +190,10 @@ def estimate_initial_period( # noqa: PLR0915
all_measures, state_latent_factors, measurements_p0_filtered
)
+ families, lowers, uppers = measurement_family_arrays(
+ model_spec.measurement_models, all_measures
+ )
+
# Halton quadrature nodes: dimension equals the state-latent count
# (observed factors are conditioned on, not integrated over, via the
# Schur complement).
@@ -226,6 +232,9 @@ def estimate_initial_period( # noqa: PLR0915
"weights": weights,
"stability_floor": af_options.stability_floor,
"n_obs_per_batch": n_obs_per_batch,
+ "measurement_families": jnp.asarray(families),
+ "measurement_lowers": jnp.asarray(lowers),
+ "measurement_uppers": jnp.asarray(uppers),
}
loglike_and_grad = create_loglike_and_gradient(
diff --git a/src/skillmodels/af/likelihood.py b/src/skillmodels/af/likelihood.py
index 6366227e..6f04bf0a 100644
--- a/src/skillmodels/af/likelihood.py
+++ b/src/skillmodels/af/likelihood.py
@@ -13,6 +13,53 @@
from jax import Array
from skillmodels.af.types import ChainLink
+from skillmodels.common.measurement_models import (
+ MeasurementFamily,
+ measurement_loglik_vec,
+)
+
+
+def _gaussian_family_arrays(n_measures: int) -> tuple[Array, Array, Array]:
+ """All-Gaussian `(families, lowers, uppers)` -- exact parity with `_log_normal_pdf`.
+
+ The default when a caller does not supply measurement families, so every
+ existing AF call evaluates each measurement as a Gaussian density unchanged.
+ """
+ return (
+ jnp.full(n_measures, int(MeasurementFamily.GAUSSIAN)),
+ jnp.full(n_measures, -jnp.inf),
+ jnp.full(n_measures, jnp.inf),
+ )
+
+
+def _resolve_families(
+ n_measures: int,
+ families: Array | None,
+ lowers: Array | None,
+ uppers: Array | None,
+) -> tuple[Array, Array, Array]:
+ """Return supplied family arrays, or all-Gaussian defaults when absent."""
+ if families is None or lowers is None or uppers is None:
+ return _gaussian_family_arrays(n_measures)
+ return families, lowers, uppers
+
+
+def _measurement_log_density(
+ y_meas: Array,
+ residual: Array,
+ meas_sds: Array,
+ families: Array,
+ lowers: Array,
+ uppers: Array,
+) -> Array:
+ """Per-measurement log density/probability through the shared family kernel.
+
+ `residual = y - eta` is the Gaussian residual, so `eta = y_meas - residual`.
+ For the Gaussian family this reduces exactly to `N(residual; 0, meas_sds)`, so
+ swapping `_log_normal_pdf(residual, 0, sd)` for this call is parity-preserving.
+ """
+ eta = y_meas - residual
+ return measurement_loglik_vec(y_meas, eta, meas_sds, families, lowers, uppers)
def af_per_obs_loglike_initial(
@@ -31,6 +78,9 @@ def af_per_obs_loglike_initial(
n_latent_factors: int | None = None,
observed_factor_values: Array | None = None,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
) -> Array:
"""Per-observation log-likelihood for the initial period (Step 0).
@@ -64,6 +114,9 @@ def af_per_obs_loglike_initial(
weights=weights,
stability_floor=stability_floor,
n_obs_per_batch=n_obs_per_batch,
+ measurement_families=measurement_families,
+ measurement_lowers=measurement_lowers,
+ measurement_uppers=measurement_uppers,
)
assert observed_factor_values is not None # noqa: S101
return _initial_loglike_per_obs_conditional(
@@ -82,6 +135,9 @@ def af_per_obs_loglike_initial(
n_latent=n_latent,
stability_floor=stability_floor,
n_obs_per_batch=n_obs_per_batch,
+ measurement_families=measurement_families,
+ measurement_lowers=measurement_lowers,
+ measurement_uppers=measurement_uppers,
)
@@ -101,6 +157,9 @@ def af_loglike_initial(
n_latent_factors: int | None = None,
observed_factor_values: Array | None = None,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
) -> Array:
"""Negative log-likelihood for the initial period (Step 0).
@@ -174,6 +233,13 @@ def af_loglike_initial(
``None`` falls back to ``jax.vmap`` (single kernel); a positive
integer uses ``jax.lax.map`` so the backward-pass tape only
retains one chunk at a time.
+ measurement_families: Optional shape-``(n_measures,)`` `MeasurementFamily`
+ codes per period-0 measurement. ``None`` evaluates every measurement as
+ Gaussian (exact parity with the previous behaviour).
+ measurement_lowers: Optional shape-``(n_measures,)`` Tobit lower bounds
+ (``-inf`` where not Tobit / not censored below).
+ measurement_uppers: Optional shape-``(n_measures,)`` Tobit upper bounds
+ (``+inf`` where not Tobit / not censored above).
Return:
Scalar negative log-likelihood.
@@ -194,6 +260,9 @@ def af_loglike_initial(
n_latent_factors=n_latent_factors,
observed_factor_values=observed_factor_values,
n_obs_per_batch=n_obs_per_batch,
+ measurement_families=measurement_families,
+ measurement_lowers=measurement_lowers,
+ measurement_uppers=measurement_uppers,
)
return -jnp.mean(log_likes)
@@ -290,6 +359,9 @@ def _initial_loglike_per_obs(
weights: Array,
n_obs_per_batch: int | None = None,
stability_floor: float,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
) -> Array:
"""Compute log-likelihood for each observation at the initial period.
@@ -302,6 +374,10 @@ def _initial_loglike_per_obs(
full_loadings = jnp.zeros((n_measures, n_factors))
full_loadings = full_loadings.at[loading_mask].set(loadings)
+ families, lowers, uppers = _resolve_families(
+ n_measures, measurement_families, measurement_lowers, measurement_uppers
+ )
+
# NaN-safety: build per-obs measurement mask and replace NaN entries
# with 0 so residuals stay finite. The mask is used inside the
# integral to zero out missing-measurement contributions.
@@ -315,7 +391,9 @@ def _initial_loglike_per_obs(
residuals_base = safe_measurements - control_contrib
@jax.checkpoint
- def _single_obs_loglike(residual_base: Array, mask_i: Array) -> Array:
+ def _single_obs_loglike(
+ residual_base: Array, y_meas: Array, mask_i: Array
+ ) -> Array:
"""Log-likelihood for a single observation, integrated over factors.
`jax.checkpoint` keeps the forward pass small: the per-observation
@@ -325,9 +403,13 @@ def _single_obs_loglike(residual_base: Array, mask_i: Array) -> Array:
"""
return _integrate_initial_single_obs(
residual_base=residual_base,
+ y_meas=y_meas,
meas_mask=mask_i,
full_loadings=full_loadings,
meas_sds=meas_sds,
+ measurement_families=families,
+ measurement_lowers=lowers,
+ measurement_uppers=uppers,
mixture_weights=mixture_weights,
mixture_means=mixture_means,
mixture_chol_covs=mixture_chol_covs,
@@ -339,6 +421,7 @@ def _single_obs_loglike(residual_base: Array, mask_i: Array) -> Array:
return _map_over_obs(
_single_obs_loglike,
residuals_base,
+ safe_measurements,
meas_mask,
n_obs_per_batch=n_obs_per_batch,
)
@@ -361,6 +444,9 @@ def _initial_loglike_per_obs_conditional(
n_latent: int,
stability_floor: float,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
) -> Array:
"""Per-observation log-likelihood with Schur-complement conditioning.
@@ -388,6 +474,10 @@ def _initial_loglike_per_obs_conditional(
full_loadings = jnp.zeros((n_measures, n_latent))
full_loadings = full_loadings.at[loading_mask].set(loadings)
+ families, lowers, uppers = _resolve_families(
+ n_measures, measurement_families, measurement_lowers, measurement_uppers
+ )
+
# NaN-safety for measurements (see `_initial_loglike_per_obs`).
meas_mask = jnp.isfinite(measurements)
safe_measurements = jnp.where(meas_mask, measurements, 0.0)
@@ -396,13 +486,19 @@ def _initial_loglike_per_obs_conditional(
residuals_base = safe_measurements - control_contrib
@jax.checkpoint
- def _single_obs_loglike(residual_base: Array, y_i: Array, mask_i: Array) -> Array:
+ def _single_obs_loglike(
+ residual_base: Array, y_meas: Array, y_i: Array, mask_i: Array
+ ) -> Array:
return _integrate_initial_single_obs_conditional(
residual_base=residual_base,
y_i=y_i,
+ y_meas=y_meas,
meas_mask=mask_i,
full_loadings=full_loadings,
meas_sds=meas_sds,
+ measurement_families=families,
+ measurement_lowers=lowers,
+ measurement_uppers=uppers,
mixture_weights=mixture_weights,
mixture_means=mixture_means,
mixture_chol_covs=mixture_chol_covs,
@@ -415,6 +511,7 @@ def _single_obs_loglike(residual_base: Array, y_i: Array, mask_i: Array) -> Arra
return _map_over_obs(
_single_obs_loglike,
residuals_base,
+ safe_measurements,
observed_factor_values,
meas_mask,
n_obs_per_batch=n_obs_per_batch,
@@ -425,9 +522,13 @@ def _integrate_initial_single_obs_conditional(
*,
residual_base: Array,
y_i: Array,
+ y_meas: Array,
meas_mask: Array,
full_loadings: Array,
meas_sds: Array,
+ measurement_families: Array,
+ measurement_lowers: Array,
+ measurement_uppers: Array,
mixture_weights: Array,
mixture_means: Array,
mixture_chol_covs: Array,
@@ -481,7 +582,14 @@ def _component_log_kernel(l_idx: Array) -> Array:
def _log_node(z_q: Array) -> Array:
theta_q = cond_mean + cond_chol @ z_q
residuals = residual_base - full_loadings @ theta_q
- log_pdf = _log_normal_pdf(residuals, jnp.zeros_like(residuals), meas_sds)
+ log_pdf = _measurement_log_density(
+ y_meas,
+ residuals,
+ meas_sds,
+ measurement_families,
+ measurement_lowers,
+ measurement_uppers,
+ )
return jnp.sum(jnp.where(meas_mask, log_pdf, 0.0))
log_meas = jax.vmap(_log_node)(nodes)
@@ -509,9 +617,13 @@ def _log_mvn_pdf_chol(x: Array, mean: Array, chol: Array) -> Array:
def _integrate_initial_single_obs(
*,
residual_base: Array,
+ y_meas: Array,
meas_mask: Array,
full_loadings: Array,
meas_sds: Array,
+ measurement_families: Array,
+ measurement_lowers: Array,
+ measurement_uppers: Array,
mixture_weights: Array,
mixture_means: Array,
mixture_chol_covs: Array,
@@ -556,9 +668,17 @@ def _node_contribution(z_q: Array) -> Array:
# Measurement residuals: obs - control_contrib - loadings @ theta
residuals = residual_base - full_loadings @ theta_q
- # Log measurement density: sum of log N(residual_m, 0, sd_m),
- # masking out missing measurements (NaN replaced by 0 upstream).
- log_pdf = _log_normal_pdf(residuals, jnp.zeros_like(residuals), meas_sds)
+ # Log measurement density per family (Gaussian density, probit/Tobit
+ # probability), masking out missing measurements (NaN replaced by 0
+ # upstream so a masked row's contribution is dropped by `jnp.where`).
+ log_pdf = _measurement_log_density(
+ y_meas,
+ residuals,
+ meas_sds,
+ measurement_families,
+ measurement_lowers,
+ measurement_uppers,
+ )
log_meas_density = jnp.sum(jnp.where(meas_mask, log_pdf, 0.0))
total = total + mixture_weights[l_idx] * jnp.exp(log_meas_density)
@@ -603,6 +723,14 @@ def af_per_obs_loglike_transition(
n_shock_factors: int | None = None,
shock_factor_indices: Array | None = None,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
+ prev_measurement_families: Array | None = None,
+ prev_measurement_lowers: Array | None = None,
+ prev_measurement_uppers: Array | None = None,
+ target_control_tensor: Array | None = None,
+ prev_control_contrib: Array | None = None,
) -> Array:
"""Per-observation log-likelihood for a transition period (Step t).
@@ -637,7 +765,11 @@ def af_per_obs_loglike_transition(
prev_full_loadings = prev_full_loadings.at[prev_loading_mask].set(
prev_loadings_flat
)
- prev_control_contrib = prev_controls @ prev_control_params.T
+ # On the calendar-adapter path the importance control contribution is precompiled
+ # per row at its own control_period (static factors at 0, source skills at s); the
+ # plain path falls back to one shared source-period matmul.
+ if prev_control_contrib is None:
+ prev_control_contrib = prev_controls @ prev_control_params.T
# NaN-safety for prev-period measurements (see `_initial_loglike_per_obs`).
prev_meas_mask = jnp.isfinite(prev_measurements)
safe_prev_measurements = jnp.where(prev_meas_mask, prev_measurements, 0.0)
@@ -653,8 +785,10 @@ def af_per_obs_loglike_transition(
meas_sds=parsed["meas_sds"],
measurements=measurements,
controls=controls,
+ target_control_tensor=target_control_tensor,
loading_mask=loading_mask,
prev_residuals_base=prev_residuals_base,
+ prev_measurements_safe=safe_prev_measurements,
prev_meas_mask=prev_meas_mask,
prev_full_loadings=prev_full_loadings,
prev_meas_sds=prev_meas_sds,
@@ -672,6 +806,12 @@ def af_per_obs_loglike_transition(
observed_factor_values=observed_factor_values,
stability_floor=stability_floor,
n_obs_per_batch=n_obs_per_batch,
+ measurement_families=measurement_families,
+ measurement_lowers=measurement_lowers,
+ measurement_uppers=measurement_uppers,
+ prev_measurement_families=prev_measurement_families,
+ prev_measurement_lowers=prev_measurement_lowers,
+ prev_measurement_uppers=prev_measurement_uppers,
)
@@ -706,6 +846,14 @@ def af_loglike_transition(
n_shock_factors: int | None = None,
shock_factor_indices: Array | None = None,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
+ prev_measurement_families: Array | None = None,
+ prev_measurement_lowers: Array | None = None,
+ prev_measurement_uppers: Array | None = None,
+ target_control_tensor: Array | None = None,
+ prev_control_contrib: Array | None = None,
) -> Array:
"""Negative log-likelihood for a transition period (Step t).
@@ -788,6 +936,22 @@ def af_loglike_transition(
``None`` falls back to ``jax.vmap`` (single kernel); a positive
integer uses ``jax.lax.map`` so the backward-pass tape only
retains one chunk at a time.
+ measurement_families: Optional shape-``(n_measures,)`` `MeasurementFamily`
+ codes for the current period. ``None`` is all-Gaussian (parity).
+ measurement_lowers: Optional shape-``(n_measures,)`` current-period Tobit
+ lower bounds (``-inf`` where not censored below).
+ measurement_uppers: Optional shape-``(n_measures,)`` current-period Tobit
+ upper bounds (``+inf`` where not censored above).
+ prev_measurement_families: Optional shape-``(n_prev_measures,)`` family
+ codes for the previous-period measurement block. ``None`` is Gaussian.
+ prev_measurement_lowers: Optional previous-period Tobit lower bounds.
+ prev_measurement_uppers: Optional previous-period Tobit upper bounds.
+ target_control_tensor: Optional shape-``(n_obs, n_measures, n_controls)``
+ per-row target control data, each row at its own ``control_period``
+ (calendar-adapter path). ``None`` keeps the single shared-matrix matmul.
+ prev_control_contrib: Optional shape-``(n_obs, n_prev_measures)`` precompiled
+ importance control contribution, each row at its ``control_period``.
+ ``None`` computes it from the shared previous-period controls matrix.
Return:
Scalar negative log-likelihood.
@@ -823,6 +987,14 @@ def af_loglike_transition(
n_shock_factors=n_shock_factors,
shock_factor_indices=shock_factor_indices,
n_obs_per_batch=n_obs_per_batch,
+ measurement_families=measurement_families,
+ measurement_lowers=measurement_lowers,
+ measurement_uppers=measurement_uppers,
+ prev_measurement_families=prev_measurement_families,
+ prev_measurement_lowers=prev_measurement_lowers,
+ prev_measurement_uppers=prev_measurement_uppers,
+ target_control_tensor=target_control_tensor,
+ prev_control_contrib=prev_control_contrib,
)
return -jnp.mean(log_likes)
@@ -895,7 +1067,9 @@ def _transition_loglike_per_obs(
measurements: Array,
controls: Array,
loading_mask: Array,
+ target_control_tensor: Array | None = None,
prev_residuals_base: Array,
+ prev_measurements_safe: Array,
prev_meas_mask: Array,
prev_full_loadings: Array,
prev_meas_sds: Array,
@@ -913,6 +1087,12 @@ def _transition_loglike_per_obs(
observed_factor_values: Array,
stability_floor: float,
n_obs_per_batch: int | None = None,
+ measurement_families: Array | None = None,
+ measurement_lowers: Array | None = None,
+ measurement_uppers: Array | None = None,
+ prev_measurement_families: Array | None = None,
+ prev_measurement_lowers: Array | None = None,
+ prev_measurement_uppers: Array | None = None,
) -> Array:
"""Compute per-observation log-likelihood for a transition period.
@@ -936,11 +1116,31 @@ def _transition_loglike_per_obs(
full_loadings = jnp.zeros((n_measures, n_loading_factors))
full_loadings = full_loadings.at[loading_mask].set(loadings_flat)
+ families, lowers, uppers = _resolve_families(
+ n_measures, measurement_families, measurement_lowers, measurement_uppers
+ )
+ n_prev_measures = prev_full_loadings.shape[0]
+ prev_families, prev_lowers, prev_uppers = _resolve_families(
+ n_prev_measures,
+ prev_measurement_families,
+ prev_measurement_lowers,
+ prev_measurement_uppers,
+ )
+
# NaN-safety for current-period measurements (see `_initial_loglike_per_obs`).
meas_mask = jnp.isfinite(measurements)
safe_measurements = jnp.where(meas_mask, measurements, 0.0)
- control_contrib = controls @ control_params.T
+ # On the calendar-adapter path each target row carries its own control_period, so
+ # the control data is a per-row tensor `(n_obs, n_measures, n_controls)` and the
+ # contribution is an einsum with the free control params. The plain path keeps the
+ # single shared-matrix matmul (one destination-period controls matrix).
+ if target_control_tensor is None:
+ control_contrib = controls @ control_params.T
+ else:
+ control_contrib = jnp.einsum(
+ "omc,mc->om", target_control_tensor, control_params
+ )
residuals_base = safe_measurements - control_contrib
cond_weights = prev_distribution["cond_weights"]
@@ -953,7 +1153,9 @@ def _transition_loglike_per_obs(
@jax.checkpoint
def _single_obs(
residual_base: Array,
+ y_meas: Array,
prev_residual_base: Array,
+ prev_y_meas: Array,
obs_cond_weights: Array,
obs_factor_values: Array,
obs_cond_means: Array,
@@ -963,13 +1165,21 @@ def _single_obs(
) -> Array:
return _integrate_transition_single_obs(
residual_base=residual_base,
+ y_meas=y_meas,
meas_mask=meas_mask_i,
full_loadings=full_loadings,
meas_sds=meas_sds,
+ measurement_families=families,
+ measurement_lowers=lowers,
+ measurement_uppers=uppers,
prev_residual_base=prev_residual_base,
+ prev_y_meas=prev_y_meas,
prev_meas_mask=prev_meas_mask_i,
prev_full_loadings=prev_full_loadings,
prev_meas_sds=prev_meas_sds,
+ prev_measurement_families=prev_families,
+ prev_measurement_lowers=prev_lowers,
+ prev_measurement_uppers=prev_uppers,
obs_cond_weights=obs_cond_weights,
obs_cond_means=obs_cond_means,
cond_chols=cond_chols,
@@ -994,7 +1204,9 @@ def _single_obs(
return _map_over_obs(
_single_obs,
residuals_base,
+ safe_measurements,
prev_residuals_base,
+ prev_measurements_safe,
cond_weights,
observed_factor_values,
cond_means_by_obs,
@@ -1114,13 +1326,21 @@ def _rebuild_chain_at_period(
def _integrate_transition_single_obs(
*,
residual_base: Array,
+ y_meas: Array,
meas_mask: Array,
full_loadings: Array,
meas_sds: Array,
+ measurement_families: Array,
+ measurement_lowers: Array,
+ measurement_uppers: Array,
prev_residual_base: Array,
+ prev_y_meas: Array,
prev_meas_mask: Array,
prev_full_loadings: Array,
prev_meas_sds: Array,
+ prev_measurement_families: Array,
+ prev_measurement_lowers: Array,
+ prev_measurement_uppers: Array,
obs_cond_weights: Array | np.ndarray,
obs_cond_means: Array | np.ndarray,
cond_chols: Array | np.ndarray,
@@ -1245,10 +1465,13 @@ def _log_draw_contribution(j_idx: Array) -> Array:
# a per-obs constant that is invariant under the parameters.
prev_state_loadings = prev_full_loadings[:, state_factor_indices_in_latent]
prev_residuals = prev_residual_base - prev_state_loadings @ theta_prev
- prev_log_pdf = _log_normal_pdf(
+ prev_log_pdf = _measurement_log_density(
+ prev_y_meas,
prev_residuals,
- jnp.zeros_like(prev_residuals),
prev_meas_sds,
+ prev_measurement_families,
+ prev_measurement_lowers,
+ prev_measurement_uppers,
)
log_prev_inv_meas = jnp.sum(jnp.where(prev_meas_mask, prev_log_pdf, 0.0))
@@ -1264,16 +1487,24 @@ def _log_draw_contribution(j_idx: Array) -> Array:
transition_func(full_prev_with_obs, transition_params)
+ state_shock_contrib
)
- # Investment calendar (audit F8): `inv` is the endogenous investment
- # GENERATED from theta_{t-1} (it just drove the transition to theta_t),
- # i.e. I_{t-1}. The period-t measurement block below scores it, so an
- # endogenous factor's period-t indicators measure I_{t-1}, not the
- # contemporaneous I_t. This is the MATLAB-faithful AF convention;
- # `validate_af_model` warns so it is not silently compared against the
- # CHS/AMN reading (period-t indicators measure I_t).
+ # Investment calendar: `inv` is the endogenous investment GENERATED from
+ # theta_{t-1} that drove the transition to theta_t -- i.e. I_s for step
+ # s=t-1 -> d=t. @pro: this kernel is shared and UNCHANGED by the calendar
+ # adapter; correctness now comes from the inputs. Under the adapter the
+ # current block's `y_meas`/`full_loadings`/`residual_base` are the layout's
+ # SOURCE-period investment indicators, so `full_loadings @ [theta_t, inv]`
+ # scores I_s indicators on I_s (contemporaneous in the source period).
+ # Confirm the re-sourced inputs flow correctly into this unchanged kernel.
all_factors_t = jnp.concatenate([theta_t, inv])
residuals = residual_base - full_loadings @ all_factors_t
- log_pdf = _log_normal_pdf(residuals, jnp.zeros_like(residuals), meas_sds)
+ log_pdf = _measurement_log_density(
+ y_meas,
+ residuals,
+ meas_sds,
+ measurement_families,
+ measurement_lowers,
+ measurement_uppers,
+ )
log_meas = jnp.sum(jnp.where(meas_mask, log_pdf, 0.0))
log_kernel = (
diff --git a/src/skillmodels/af/params.py b/src/skillmodels/af/params.py
index 45a00a08..f2013d49 100644
--- a/src/skillmodels/af/params.py
+++ b/src/skillmodels/af/params.py
@@ -121,6 +121,7 @@ def get_transition_period_params_index(
endogenous_factors: tuple[str, ...] = (),
observed_factors: tuple[str, ...] = (),
shock_factors: tuple[str, ...] | None = None,
+ measurement_index_tuples: list[tuple[str, int, str, str]] | None = None,
) -> pd.MultiIndex:
"""Build parameter index for a transition period (Step t, t >= 1).
@@ -141,6 +142,9 @@ def get_transition_period_params_index(
SD is estimated. Factors omitted here get no shock SD parameter
and are integrated deterministically (dropping their shock
dimension from the Halton draw). Defaults to `latent_factors`.
+ measurement_index_tuples: Pre-compiled mixed-calendar measurement index
+ rows (source/destination calendar adapter). When given, used verbatim
+ for the measurement block instead of single-period emission.
Return:
MultiIndex with levels (category, period, name1, name2).
@@ -173,17 +177,23 @@ def get_transition_period_params_index(
# Investment shock SD
ind_tups.append(("investment_sds", period - 1, endog_factor, "-"))
- # Measurement params for period t (loadings for ALL factors, not just state)
- all_factor_measurements = dict(measurements_at_period)
- all_latent = (*latent_factors, *endogenous_factors)
- ind_tups.extend(
- _measurement_index_tuples(
- period=period,
- latent_factors=all_latent,
- measurements=all_factor_measurements,
- controls=controls,
+ # Measurement params. By default emit period-t rows for all factors; when the
+ # source/destination calendar adapter supplies a pre-compiled mixed-calendar
+ # measurement index (destination skills at d, source investment at s, in global
+ # category order), use it verbatim instead.
+ if measurement_index_tuples is not None:
+ ind_tups.extend(measurement_index_tuples)
+ else:
+ all_factor_measurements = dict(measurements_at_period)
+ all_latent = (*latent_factors, *endogenous_factors)
+ ind_tups.extend(
+ _measurement_index_tuples(
+ period=period,
+ latent_factors=all_latent,
+ measurements=all_factor_measurements,
+ controls=controls,
+ )
)
- )
return pd.MultiIndex.from_tuples(
ind_tups,
diff --git a/src/skillmodels/af/posterior_states.py b/src/skillmodels/af/posterior_states.py
index 599f25f0..0aebf72f 100644
--- a/src/skillmodels/af/posterior_states.py
+++ b/src/skillmodels/af/posterior_states.py
@@ -56,6 +56,12 @@ def get_af_posterior_states(
(columns: id, period, factor1, ...) and "state_ranges".
"""
+ if model_spec != af_result.model_spec:
+ msg = (
+ "get_af_posterior_states received a model_spec that does not match "
+ "af_result.model_spec; pass the same specification used for estimation."
+ )
+ raise ValueError(msg)
jax.config.update("jax_enable_x64", val=True)
idx_names = data.index.names
@@ -81,21 +87,42 @@ def get_af_posterior_states(
if not measurements_pt:
continue
+ period_mask = data.index.get_level_values(period_col) == t
+ period_df = data.loc[period_mask]
+ ids = period_df.index.get_level_values(id_col)
+
+ # Keep only the measurements that load on a STATE factor and are present in
+ # the data. A state factor's own indicators are parsed correctly at their own
+ # period, so the single-period extraction below is valid for them. Under the
+ # source/destination calendar adapter the reconstructed investment's indicators
+ # load on no state factor and are sourced from a different step (their params
+ # live at a different `param_period`), so the single-period parse would misread
+ # them; drop them. Investment is endogenous, not a state coordinate, so it is
+ # not reported either way. The resulting posterior conditions on the
+ # state-loading measurements plus the period-0 chained sample; it does not fold
+ # in the investment indicators' information about the state, so it is slightly
+ # less precise than (but unbiased relative to) the full-information posterior.
+ # For plain models every measurement loads on a state factor, so this is a
+ # no-op and the result is byte-identical to the pre-migration output.
+ state_meas_pt = {
+ f: tuple(m for m in measurements_pt[f] if m in period_df.columns)
+ for f in state_factors
+ if f in measurements_pt
+ and any(m in period_df.columns for m in measurements_pt[f])
+ }
+ kept = _get_ordered_measures(state_meas_pt)
+ if not kept:
+ continue
+
meas_info = _extract_period_measurement_info(
period_result.params,
- model_spec,
state_factors,
t,
+ kept,
+ state_meas_pt,
)
-
- period_mask = data.index.get_level_values(period_col) == t
- period_df = data.loc[period_mask]
- ids = period_df.index.get_level_values(id_col)
-
- all_measures = _get_ordered_measures(measurements_pt)
- meas_cols = [c for c in all_measures if c in period_df.columns]
measurements = jnp.array(
- period_df[meas_cols].to_numpy(dtype=np.float64, na_value=np.nan),
+ period_df[kept].to_numpy(dtype=np.float64, na_value=np.nan),
)
# Build per-observation control contribution
@@ -140,17 +167,23 @@ def get_af_posterior_states(
def _extract_period_measurement_info(
period_params: pd.DataFrame,
- model_spec: ModelSpec,
factors: tuple[str, ...],
period: int,
+ measures: list[str],
+ measurements_pt: dict[str, tuple[str, ...]],
) -> dict[str, Any]:
- """Extract measurement loadings, control contribution, and SDs."""
- measurements_pt = get_measurements_per_factor(model_spec.factors, period=period)
- all_measures = _get_ordered_measures(measurements_pt)
- loading_mask = _build_loading_mask(all_measures, factors, measurements_pt)
+ """Extract measurement loadings, control contribution, and SDs.
+
+ `measures` is the ordered list of measurement columns to score against the
+ state `factors`, and `measurements_pt` maps each state factor to those of its
+ measures. Reconstructed investment indicators, which load on no state factor,
+ are excluded by the caller so the single-period parse never reaches for
+ source-period rows (and `_build_loading_mask` never indexes a non-state factor).
+ """
+ loading_mask = _build_loading_mask(measures, factors, measurements_pt)
loadings_list = []
- for mi, meas in enumerate(all_measures):
+ for mi, meas in enumerate(measures):
for fi, factor in enumerate(factors):
if loading_mask[mi, fi]:
loc = ("loadings", period, meas, factor)
@@ -159,7 +192,7 @@ def _extract_period_measurement_info(
float(period_params.loc[loc, "value"]) # ty: ignore[invalid-argument-type]
)
- full_loadings = jnp.zeros((len(all_measures), len(factors)))
+ full_loadings = jnp.zeros((len(measures), len(factors)))
full_loadings = full_loadings.at[jnp.array(loading_mask)].set( # noqa: PD008
jnp.array(loadings_list)
)
@@ -174,22 +207,20 @@ def _extract_period_measurement_info(
else ["constant"]
)
ctrl_params_list = []
- for meas in all_measures:
+ for meas in measures:
for ctrl in ctrl_names:
loc = ("controls", period, meas, ctrl)
if loc in period_params.index:
ctrl_params_list.append(float(period_params.loc[loc, "value"]))
else:
ctrl_params_list.append(0.0)
- control_params = jnp.array(ctrl_params_list).reshape(
- len(all_measures), len(ctrl_names)
- )
+ control_params = jnp.array(ctrl_params_list).reshape(len(measures), len(ctrl_names))
sd_list = [
float(period_params.loc[loc, "value"]) # ty: ignore[invalid-argument-type]
if (loc := ("meas_sds", period, meas, "-")) in period_params.index
else 0.5
- for meas in all_measures
+ for meas in measures
]
return {
diff --git a/src/skillmodels/af/step_assembly.py b/src/skillmodels/af/step_assembly.py
new file mode 100644
index 00000000..509e7eb7
--- /dev/null
+++ b/src/skillmodels/af/step_assembly.py
@@ -0,0 +1,204 @@
+"""Assemble an AF transition step's target + importance arrays from a compiled layout.
+
+Combines the layout (which says *which* measurement belongs in each block and from which
+period) with the per-period ID-indexed data frames and the cumulative `HistoricalParams`
+to produce the arrays the integrand consumes:
+
+- the free **target** block (destination skills + source investment) -- data + loading
+ mask;
+- the fixed **importance** block (source skills + every static-persistent factor's
+ period-0 rows) -- data, loading mask, and fixed loadings/controls/SDs read from
+ history by each row's true `param_period`.
+
+All columns are sourced on one common individual-ID set (the intersection across every
+period the step touches), so target and importance rows refer to the same individuals.
+"""
+
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
+
+import numpy as np
+import pandas as pd
+
+from skillmodels.af.step_data import build_block_loading_mask
+from skillmodels.af.step_layout import (
+ AFMeasurementTerm,
+ AFStepLayout,
+ HistoricalParams,
+)
+
+
+@dataclass(frozen=True)
+class AFStepArrays:
+ """Target + importance measurement arrays for one AF transition step."""
+
+ ids: np.ndarray
+ """The step's common individual-ID vector (sorted intersection)."""
+ target_measurements: np.ndarray
+ """`(n_ids, n_target)` free target measurement values."""
+ target_loading_mask: np.ndarray
+ """`(n_target, n_latent)` target loading mask."""
+ target_order: list[tuple[int, tuple[str, int]]]
+ """`(column, (measurement, data_period))` for each target column."""
+ importance_measurements: np.ndarray
+ """`(n_ids, n_importance)` fixed importance measurement values."""
+ importance_loading_mask: np.ndarray
+ """`(n_importance, n_latent)` importance loading mask."""
+ imp_order: list[tuple[int, tuple[str, int]]]
+ """`(column, (measurement, data_period))` for each importance column."""
+ importance_loadings_flat: np.ndarray
+ """Fixed importance loadings, packed in loading-mask order, from history."""
+ importance_control_params: np.ndarray
+ """`(n_importance, n_controls)` fixed importance control params, from history."""
+ importance_meas_sds: np.ndarray
+ """`(n_importance,)` fixed importance measurement SDs, from history."""
+ target_controls: np.ndarray
+ """`(n_ids, n_target, n_controls)` per-row target control data, each row sourced
+ from its term's own `control_period` (the free target control coefficients are
+ estimated, so only the data -- not the contribution -- can be precompiled)."""
+ importance_control_contrib: np.ndarray
+ """`(n_ids, n_importance)` fixed importance control contribution, each row's control
+ data (at its `control_period`) dotted with its fixed control params from history."""
+
+
+def assemble_step_arrays(
+ layout: AFStepLayout,
+ frames_by_period: Mapping[int, pd.DataFrame],
+ latent_factors: Sequence[str],
+ historical: HistoricalParams,
+ controls: tuple[str, ...],
+) -> AFStepArrays:
+ """Assemble the target + importance arrays for one transition step.
+
+ Args:
+ layout: The compiled step layout.
+ frames_by_period: Calendar period -> frame indexed by individual ID.
+ latent_factors: Latent-factor ordering the integrand dots loadings against.
+ historical: Cumulative parameter registry for fixed importance params.
+ controls: Control variable names (for importance control params).
+
+ Return:
+ An `AFStepArrays` with target/importance data, masks, and fixed params.
+
+ """
+ # @pro: calendar split for step s->d. The FREE target block = {destination skills
+ # on theta_d, source investment I_s on period s}; the FIXED importance block =
+ # {source skills on theta_s, every static-persistent factor's period-0 rows}. The
+ # importance loadings/controls/SDs are read from `historical` by each row's true
+ # param_period (period-0 for static factors), so the carry-over density is fixed at
+ # its originally estimated value. Confirm target/importance membership and the
+ # history sourcing are the right ones for the MATLAB sequential likelihood.
+ targets = layout.target_terms()
+ importances = layout.importance_terms()
+
+ needed = {term.data_period for term in (*targets, *importances)}
+ common: set | None = None
+ for period in needed:
+ period_ids = set(frames_by_period[period].index)
+ common = period_ids if common is None else (common & period_ids)
+ # Preserve the individual-ID values as-is (object dtype) rather than coercing to
+ # int64: the adapter's contract is stable ID alignment, which must hold for string,
+ # UUID, or other non-integer identifiers, not just CNLSY-style numeric case IDs.
+ ids = np.array(sorted(common or set()), dtype=object)
+
+ target_order, target_values = _source_block(targets, frames_by_period, ids)
+ imp_order, imp_values = _source_block(importances, frames_by_period, ids)
+ target_mask = build_block_loading_mask(targets, latent_factors)
+ imp_mask = build_block_loading_mask(importances, latent_factors)
+
+ loadings_flat: list[float] = []
+ for row, term in enumerate(importances):
+ for factor_idx, factor in enumerate(latent_factors):
+ if imp_mask[row, factor_idx]:
+ loadings_flat.append(
+ historical.value(
+ "loadings", term.param_period, term.measurement, factor
+ )
+ )
+ control_params = np.array(
+ [
+ [
+ historical.value(
+ "controls", term.control_period, term.measurement, ctrl
+ )
+ for ctrl in controls
+ ]
+ for term in importances
+ ],
+ dtype=np.float64,
+ ).reshape(len(importances), len(controls))
+ meas_sds = np.array(
+ [
+ historical.value("meas_sds", term.param_period, term.measurement, "-")
+ for term in importances
+ ],
+ dtype=np.float64,
+ )
+
+ target_controls = _source_control_tensor(targets, frames_by_period, ids, controls)
+ imp_controls = _source_control_tensor(importances, frames_by_period, ids, controls)
+ # Importance control params are fixed (from history), so the contribution can be
+ # precompiled per row at its own control_period: (n_ids, n_imp, n_ctrl) . (n_imp,
+ # n_ctrl) -> (n_ids, n_imp). Target control params are estimated, so only the data
+ # tensor is precompiled; the kernel forms its contribution each iteration.
+ importance_control_contrib = np.einsum(
+ "imc,mc->im", imp_controls, control_params, optimize=False
+ )
+
+ return AFStepArrays(
+ ids=ids,
+ target_measurements=target_values,
+ target_loading_mask=target_mask,
+ target_order=list(enumerate(target_order)),
+ importance_measurements=imp_values,
+ importance_loading_mask=imp_mask,
+ imp_order=list(enumerate(imp_order)),
+ importance_loadings_flat=np.array(loadings_flat, dtype=np.float64),
+ importance_control_params=control_params,
+ importance_meas_sds=meas_sds,
+ target_controls=target_controls,
+ importance_control_contrib=importance_control_contrib,
+ )
+
+
+def _source_block(
+ terms: Sequence[AFMeasurementTerm],
+ frames_by_period: Mapping[int, pd.DataFrame],
+ ids: np.ndarray,
+) -> tuple[list[tuple[str, int]], np.ndarray]:
+ """Source a block's columns on a shared ID set, each from its term's data period."""
+ order = [(term.measurement, term.data_period) for term in terms]
+ if not terms:
+ return order, np.zeros((len(ids), 0), dtype=np.float64)
+ columns = [
+ frames_by_period[term.data_period]
+ .loc[ids, term.measurement]
+ .to_numpy(dtype=np.float64)
+ for term in terms
+ ]
+ return order, np.column_stack(columns)
+
+
+def _source_control_tensor(
+ terms: Sequence[AFMeasurementTerm],
+ frames_by_period: Mapping[int, pd.DataFrame],
+ ids: np.ndarray,
+ controls: tuple[str, ...],
+) -> np.ndarray:
+ """Source each term's control values on the shared IDs from its `control_period`.
+
+ Returns shape `(n_ids, n_terms, n_controls)`. A `"constant"` control is a column of
+ ones; a named control is read from that term's `control_period` frame; an absent
+ control is zero. Reading each row at its own `control_period` is what keeps a
+ mixed-calendar block (e.g. destination skills at `d` with source investment at `s`)
+ from evaluating every row against one shared period's controls.
+ """
+ tensor = np.zeros((len(ids), len(terms), len(controls)), dtype=np.float64)
+ for row, term in enumerate(terms):
+ frame = frames_by_period[term.control_period]
+ for col, ctrl in enumerate(controls):
+ if ctrl == "constant":
+ tensor[:, row, col] = 1.0
+ elif ctrl in frame.columns:
+ tensor[:, row, col] = frame.loc[ids, ctrl].to_numpy(dtype=np.float64)
+ return tensor
diff --git a/src/skillmodels/af/step_data.py b/src/skillmodels/af/step_data.py
new file mode 100644
index 00000000..c509888b
--- /dev/null
+++ b/src/skillmodels/af/step_data.py
@@ -0,0 +1,84 @@
+"""ID-aligned assembly of an AF step's mixed-calendar measurement arrays.
+
+A compiled `AFStepLayout` draws measurement columns from different calendar periods
+(destination skills at `d`, source investment at `s`, static factors at 0). Their
+per-period frames can differ in individual order and sample, so columns must be joined
+on individual ID -- never concatenated positionally -- and each column sourced from its
+term's own `data_period`.
+"""
+
+from collections.abc import Mapping, Sequence
+
+import numpy as np
+import pandas as pd
+
+from skillmodels.af.step_layout import AFMeasurementTerm
+
+
+def build_step_measurement_array(
+ terms: Sequence[AFMeasurementTerm],
+ frames_by_period: Mapping[int, pd.DataFrame],
+) -> tuple[np.ndarray, np.ndarray, list[tuple[str, int]]]:
+ """Assemble one ID-aligned measurement value array for a set of step terms.
+
+ Each term's column is read from `frames_by_period[term.data_period]` at the term's
+ measurement variable. The step sample is the intersection of individual IDs across
+ every sourced period, sorted ascending, so columns from different periods are paired
+ by ID rather than by row position.
+
+ Args:
+ terms: The measurement terms (e.g. a layout's target or importance terms).
+ frames_by_period: Calendar period -> frame indexed by individual ID with
+ measurement-variable columns.
+
+ Return:
+ `(ids, values, term_order)` where `ids` is the sorted common ID vector, `values`
+ is `(n_ids, n_terms)` with `values[:, j]` the `j`-th term's column, and
+ `term_order` lists `(measurement, data_period)` per column.
+
+ """
+ if not terms:
+ ids_empty = np.array([], dtype=object)
+ return ids_empty, np.zeros((0, 0), dtype=np.float64), []
+
+ needed_periods = {term.data_period for term in terms}
+ common_ids: set | None = None
+ for period in needed_periods:
+ period_ids = set(frames_by_period[period].index)
+ common_ids = period_ids if common_ids is None else (common_ids & period_ids)
+ # Preserve ID values as-is (object dtype) rather than coercing to int64, so
+ # string/UUID identifiers align rather than raise (matches `assemble_step_arrays`).
+ ids = np.array(sorted(common_ids or set()), dtype=object)
+
+ columns = []
+ term_order: list[tuple[str, int]] = []
+ for term in terms:
+ frame = frames_by_period[term.data_period]
+ column = frame.loc[ids, term.measurement].to_numpy(dtype=np.float64)
+ columns.append(column)
+ term_order.append((term.measurement, term.data_period))
+
+ values = (
+ np.column_stack(columns)
+ if columns
+ else np.zeros((len(ids), 0), dtype=np.float64)
+ )
+ return ids, values, term_order
+
+
+def build_block_loading_mask(
+ terms: Sequence[AFMeasurementTerm],
+ latent_factors: Sequence[str],
+) -> np.ndarray:
+ """Build the `(n_terms, n_latent_factors)` boolean loading mask for a block.
+
+ Row `j` is `True` in the `latent_factors` columns that term `j` loads on, matching
+ the latent-factor ordering the integrand dots loadings against. The returned array
+ has boolean dtype and shape `(len(terms), len(latent_factors))`.
+ """
+ factor_index = {name: i for i, name in enumerate(latent_factors)}
+ mask = np.zeros((len(terms), len(latent_factors)), dtype=np.bool_)
+ for row, term in enumerate(terms):
+ for factor in term.factor_loadings:
+ mask[row, factor_index[factor]] = True
+ return mask
diff --git a/src/skillmodels/af/step_layout.py b/src/skillmodels/af/step_layout.py
new file mode 100644
index 00000000..4b90b105
--- /dev/null
+++ b/src/skillmodels/af/step_layout.py
@@ -0,0 +1,342 @@
+"""Source/destination calendar layout for the AF estimator.
+
+The AF estimator is sequential: step `s -> d = s+1` jointly estimates the `s -> d`
+transition, the period-`s` investment equation, and a measurement block. The public
+`ModelSpec` is contemporaneous -- an investment indicator declared at calendar period
+`c` measures `I_c` -- so a calendar-to-step adapter is needed to feed AF's blocks. This
+module compiles that adapter as row-level `AFMeasurementTerm`s grouped into one
+`AFStepLayout` per transition.
+
+For step `s -> d` the compiled blocks (validated against MATLAB `likelihood_01`/
+`likelihood_12` + `create_nodes_weights_12`) are:
+
+- FREE **target**: destination dynamic-state (skill) indicators scored on `theta_d`,
+ and source endogenous (investment) indicators scored on `I_s`.
+- FIXED **importance**: source dynamic-state indicators scored on `theta_s`, plus every
+ static-persistent factor's period-0 indicators scored on its time-invariant value
+ (re-applied at every step -- the dropped-MC/MN importance fix).
+
+Carrying `family`/`lower`/`upper` on each term lets limited-dependent (probit/Tobit)
+measurements compose with the shared measurement kernel without a second calendar pass.
+"""
+
+import enum
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from dataclasses import dataclass
+from typing import Literal, SupportsFloat, cast
+
+import pandas as pd
+
+from skillmodels.common.measurement_models import MeasurementFamily
+from skillmodels.common.model_spec import ModelSpec
+
+# A term's role: a free, this-step target density vs a fixed, historical importance
+# reweighting density.
+AFTermRole = Literal["target", "importance"]
+
+
+def model_uses_calendar_adapter(model_spec: ModelSpec) -> bool:
+ """Return whether a model takes the AF source/destination calendar-adapter path.
+
+ The adapter activates for a reconstructed endogenous factor (`is_endogenous` with
+ `has_initial_distribution=False`) or any static-persistent factor. Plain AF models
+ keep the single-period path.
+ """
+ specs = model_spec.factors.values()
+ reconstructed_endog = any(
+ spec.is_endogenous and not spec.has_initial_distribution for spec in specs
+ )
+ static_persistent = any(spec.af_state_role == "static_persistent" for spec in specs)
+ return reconstructed_endog or static_persistent
+
+
+def fail_if_calendar_adapter_unsupported(model_spec: ModelSpec, feature: str) -> None:
+ """Raise if `feature` is requested for a calendar-adapter model.
+
+ AF standard-error inference and posterior-state extraction still reconstruct the
+ single-period measurement layout, which is wrong for the mixed-calendar adapter:
+ source-investment params live in the next step's result and the per-step parser
+ counts differ. Reject adapter models on those paths -- rather than return standard
+ errors that differentiate the wrong objective or posterior rows attached to the
+ wrong latent -- until they consume the compiled `AFStepLayout`.
+ """
+ if model_uses_calendar_adapter(model_spec):
+ msg = (
+ f"{feature} is not yet supported for AF models that use the "
+ "source/destination calendar adapter (a reconstructed-endogenous or "
+ "static-persistent factor). That path must be migrated to the compiled "
+ "AFStepLayout before it returns valid results."
+ )
+ raise NotImplementedError(msg)
+
+
+def fail_if_spearman_unsupported_on_adapter(
+ model_spec: ModelSpec, strategy: str
+) -> None:
+ """Raise if a Spearman moment start is requested for a calendar-adapter model.
+
+ The Spearman/OLS start routine discovers measurements at the destination period and
+ writes loading/SD rows there, so it does not seed the mixed-calendar target
+ (destination skills at `d` plus source investment at `s`). Rather than silently fall
+ back to the constant start -- which contradicts the requested option -- reject it on
+ the adapter path until a layout-aware moment initializer exists.
+ """
+ if strategy == "spearman" and model_uses_calendar_adapter(model_spec):
+ msg = (
+ "start_params_strategy='spearman' is not supported for AF models that use "
+ "the source/destination calendar adapter: the moment start discovers "
+ "measurements at the destination period and would mis-seed the "
+ "source-investment block. Use start_params_strategy='constant' (or 'amn') "
+ "until a calendar-aware moment initializer is implemented."
+ )
+ raise NotImplementedError(msg)
+
+
+class AFFactorRole(enum.Enum):
+ """How a latent factor participates in AF's sequential calendar."""
+
+ DYNAMIC = "dynamic"
+ """Real transition, measured across periods (e.g. skills)."""
+ STATIC_PERSISTENT = "static_persistent"
+ """Time-invariant; its period-0 measurement density is re-applied as an importance
+ factor at every transition step (e.g. MC, MN)."""
+ ENDOGENOUS = "endogenous"
+ """Reconstructed from the investment equation (`has_initial_distribution=False`)."""
+
+
+class AFEval(enum.Enum):
+ """Which assembled latent vector a measurement term is scored against."""
+
+ THETA_DEST = "theta_dest"
+ """Destination dynamic state `theta_d`."""
+ THETA_SRC = "theta_src"
+ """Source dynamic state `theta_s`."""
+ INV_SRC = "inv_src"
+ """Source endogenous reconstruction `I_s = g(theta_s, Y_s)`."""
+ STATIC = "static"
+ """Time-invariant static-persistent latent coordinate."""
+
+
+@dataclass(frozen=True)
+class AFFactorInfo:
+ """A factor's AF role and its per-calendar-period measurement declarations."""
+
+ name: str
+ """Factor name."""
+ role: AFFactorRole
+ """The factor's AF calendar role."""
+ measurements_by_period: tuple[tuple[str, ...], ...]
+ """Per-period tuple of measurement variable names (empty where inactive)."""
+
+
+@dataclass(frozen=True)
+class AFMeasurementTerm:
+ """One observed measurement equation instance within an AF step."""
+
+ measurement: str
+ """Observed measurement variable name."""
+ factor_loadings: tuple[str, ...]
+ """Latent factors this row loads on (a single factor for non-cross-loaded rows)."""
+ data_period: int
+ """Calendar period the observed values are read from."""
+ param_period: int
+ """Calendar period the measurement params are indexed under."""
+ control_period: int
+ """Calendar period the controls are read from."""
+ eval_node: AFEval
+ """Which assembled latent vector the loadings are scored against."""
+ family: MeasurementFamily
+ """Measurement family (Gaussian/probit/Tobit)."""
+ lower: float
+ """Tobit lower censoring bound (`-inf` if not applicable)."""
+ upper: float
+ """Tobit upper censoring bound (`+inf` if not applicable)."""
+ role: AFTermRole
+ """`"target"` (free, this step) or `"importance"` (fixed, historical)."""
+ free: bool
+ """Whether this row's params are estimated at this step vs fixed from history."""
+
+
+@dataclass(frozen=True)
+class AFStepLayout:
+ """The compiled measurement layout for one AF transition step `s -> d`."""
+
+ source_period: int
+ """The source calendar period `s`."""
+ destination_period: int
+ """The destination calendar period `d = s + 1`."""
+ terms: tuple[AFMeasurementTerm, ...]
+ """All target and importance measurement terms for the step."""
+ equation_period: int
+ """Calendar period indexing the transition, investment-eq, and shock params."""
+ observed_factor_period: int
+ """Calendar period the observed factors (income) entering `g` are read from."""
+
+ def target_terms(self) -> tuple[AFMeasurementTerm, ...]:
+ """Return the free, this-step target terms."""
+ return tuple(t for t in self.terms if t.role == "target")
+
+ def importance_terms(self) -> tuple[AFMeasurementTerm, ...]:
+ """Return the fixed, historical importance terms."""
+ return tuple(t for t in self.terms if t.role == "importance")
+
+
+@dataclass(frozen=True)
+class HistoricalParams:
+ """Cumulative AF parameter registry keyed by the full param MultiIndex.
+
+ Importance terms re-apply fixed period-0 static-factor (MC/MN) densities, whose
+ params live in the initial-step result, not the immediately-previous step result.
+ A single concatenated registry makes every earlier estimate reachable.
+ """
+
+ table: pd.DataFrame
+ """Concatenated per-step params with a unique MultiIndex and a `value` column."""
+
+ @classmethod
+ def from_param_frames(cls, frames: Iterable[pd.DataFrame]) -> HistoricalParams:
+ """Concatenate per-step param frames, rejecting duplicate index entries."""
+ combined = pd.concat(list(frames))
+ if combined.index.has_duplicates:
+ dups = sorted(set(combined.index[combined.index.duplicated()].tolist()))
+ msg = f"Duplicate parameter index entries across AF steps: {dups}"
+ raise ValueError(msg)
+ return cls(table=combined)
+
+ def value(self, category: str, period: int, name1: str, name2: str) -> float:
+ """Return the estimated value at one full-index coordinate."""
+ # The index is unique (validated), so this `.loc` is a scalar; pandas-stubs
+ # still type it as a broad union, hence the cast to a float-coercible value.
+ cell = self.table.loc[(category, period, name1, name2), "value"]
+ return float(cast("SupportsFloat", cell))
+
+
+def compile_target_measurement_index(
+ layout: AFStepLayout,
+ controls: tuple[str, ...],
+) -> list[tuple[str, int, str, str]]:
+ """Build the free target measurement param index in the flat parser's global order.
+
+ Emits all control rows, then all loading rows, then all measurement-SD rows (the
+ order `_parse_transition_params` expects), each tagged with its term's true
+ `param_period` so a mixed-calendar target (destination skills at `d`, source
+ investment at `s`) parses correctly.
+ """
+ targets = layout.target_terms()
+ ind_tups: list[tuple[str, int, str, str]] = []
+ for term in targets:
+ for ctrl in controls:
+ ind_tups.append(("controls", term.param_period, term.measurement, ctrl))
+ for term in targets:
+ for factor in term.factor_loadings:
+ ind_tups.append(("loadings", term.param_period, term.measurement, factor))
+ for term in targets:
+ ind_tups.append(("meas_sds", term.param_period, term.measurement, "-"))
+ return ind_tups
+
+
+def compile_af_step_layouts(
+ factor_infos: Sequence[AFFactorInfo],
+ n_periods: int,
+ families: Mapping[str, tuple[MeasurementFamily, float, float]] | None = None,
+) -> tuple[AFStepLayout, ...]:
+ """Compile one `AFStepLayout` per transition step from contemporaneous declarations.
+
+ Args:
+ factor_infos: Per-factor AF role + per-period measurement declarations.
+ n_periods: Number of calendar periods (transitions are `0..n_periods-2`).
+ families: Optional measurement -> `(family, lower, upper)`; absent measurements
+ default to Gaussian.
+
+ Return:
+ One `AFStepLayout` per transition step `s -> s+1`.
+
+ """
+ fam_map = dict(families) if families is not None else {}
+ layouts = []
+ for source in range(n_periods - 1):
+ destination = source + 1
+ terms: list[AFMeasurementTerm] = []
+ for info in factor_infos:
+ terms.extend(_terms_for_factor(info, source, destination, fam_map))
+ layouts.append(
+ AFStepLayout(
+ source_period=source,
+ destination_period=destination,
+ terms=tuple(terms),
+ equation_period=source,
+ observed_factor_period=source,
+ )
+ )
+ return tuple(layouts)
+
+
+def _terms_for_factor(
+ info: AFFactorInfo,
+ source: int,
+ destination: int,
+ fam_map: Mapping[str, tuple[MeasurementFamily, float, float]],
+) -> list[AFMeasurementTerm]:
+ """Build the step's measurement terms contributed by one factor."""
+ if info.role == AFFactorRole.DYNAMIC:
+ return [
+ *_terms_at(
+ info, destination, AFEval.THETA_DEST, "target", free=True, fam=fam_map
+ ),
+ *_terms_at(
+ info, source, AFEval.THETA_SRC, "importance", free=False, fam=fam_map
+ ),
+ ]
+ if info.role == AFFactorRole.ENDOGENOUS:
+ # @pro: THE calendar fix. The endogenous (investment) indicators are sourced
+ # from the SOURCE period s and scored on I_s -- the investment that drives the
+ # s->d transition -- not from destination d (the prior I_{d-1} mispairing).
+ # Confirm this matches MATLAB likelihood_01/12, where the period-s investment
+ # block enters transition_{s->d}.
+ return _terms_at(info, source, AFEval.INV_SRC, "target", free=True, fam=fam_map)
+ # STATIC_PERSISTENT: only the period-0 indicators are re-applied as importance.
+ # @pro: a static factor's period-0 measurement block is re-emitted as a FIXED
+ # importance term at every step s->d (period-0 rows only, re-applied each step), so
+ # its time-invariant density carries over in the importance weight without any later
+ # period's declaration leaking into an earlier step. Is re-anchoring the static
+ # block at its declared period-0 the correct carry-over weight at every later step?
+ return _terms_at(info, 0, AFEval.STATIC, "importance", free=False, fam=fam_map)
+
+
+def _terms_at(
+ info: AFFactorInfo,
+ period: int,
+ eval_node: AFEval,
+ role: AFTermRole,
+ *,
+ free: bool,
+ fam: Mapping[str, tuple[MeasurementFamily, float, float]],
+) -> list[AFMeasurementTerm]:
+ """Build terms for one factor's measurements at a single calendar period."""
+ measures = (
+ info.measurements_by_period[period]
+ if period < len(info.measurements_by_period)
+ else ()
+ )
+ terms = []
+ for measure in measures:
+ family, lower, upper = fam.get(
+ measure, (MeasurementFamily.GAUSSIAN, -math.inf, math.inf)
+ )
+ terms.append(
+ AFMeasurementTerm(
+ measurement=measure,
+ factor_loadings=(info.name,),
+ data_period=period,
+ param_period=period,
+ control_period=period,
+ eval_node=eval_node,
+ family=family,
+ lower=lower,
+ upper=upper,
+ role=role,
+ free=free,
+ )
+ )
+ return terms
diff --git a/src/skillmodels/af/transition_period.py b/src/skillmodels/af/transition_period.py
index 37eaab8b..6d736e49 100644
--- a/src/skillmodels/af/transition_period.py
+++ b/src/skillmodels/af/transition_period.py
@@ -35,6 +35,15 @@
get_normalizations_for_period,
get_transition_period_params_index,
)
+from skillmodels.af.step_assembly import AFStepArrays, assemble_step_arrays
+from skillmodels.af.step_layout import (
+ AFFactorInfo,
+ AFFactorRole,
+ HistoricalParams,
+ compile_af_step_layouts,
+ compile_target_measurement_index,
+ model_uses_calendar_adapter,
+)
from skillmodels.af.types import (
AFEstimationOptions,
AFPeriodResult,
@@ -51,6 +60,7 @@
filter_within_step_constraints,
reconcile_start_to_equality,
)
+from skillmodels.common.measurement_models import GaussianMeasurement
from skillmodels.common.model_spec import ModelSpec
from skillmodels.common.types import ProcessedModel, TransitionInfo, to_plain_dict
@@ -72,6 +82,8 @@ def estimate_transition_period(
start_params: pd.DataFrame | None = None,
fixed_params: pd.DataFrame | None = None,
user_constraints: list[om.constraints.Constraint] | None = None,
+ frames_by_period: Mapping[int, pd.DataFrame] | None = None,
+ historical: HistoricalParams | None = None,
) -> tuple[AFPeriodResult, ConditionalDistribution]:
"""Estimate a transition period (Step t, t >= 1) of the AF procedure.
@@ -102,6 +114,12 @@ def estimate_transition_period(
from `estimate_af(constraints=...)`. Entries whose members
all sit in this step's params index are appended to the
step's `om.minimize` call (within-step equalities).
+ frames_by_period: Per-period individual-ID-indexed measurement frames,
+ required for endogenous / static-persistent models (the source/
+ destination calendar adapter sources the target block from them).
+ historical: Cumulative parameter registry; supplies the fixed importance
+ block (source skills + static-persistent period-0 rows) when the
+ calendar adapter is active.
Return:
Tuple of (AFPeriodResult, ConditionalDistribution). The returned
@@ -123,9 +141,6 @@ def estimate_transition_period(
factors = processed_model.labels.latent_factors
controls_names = processed_model.labels.controls
- measurements_pt = get_measurements_per_factor(model_spec.factors, period=period)
- all_measures = _get_ordered_measures(measurements_pt)
-
transition_info = processed_model.transition_info
state_factors = tuple(f for f in factors if f not in endogenous_factors)
@@ -149,21 +164,25 @@ def estimate_transition_period(
[factors.index(f) for f in state_factors], dtype=jnp.int32
)
- params_index = get_transition_period_params_index(
- period=period,
- latent_factors=state_factors,
- transition_info=transition_info,
- measurements_at_period=measurements_pt,
- controls=controls_names,
- endogenous_factors=endogenous_factors,
- observed_factors=observed_factors,
- shock_factors=shock_factors,
- )
normalizations = get_normalizations_for_period(model_spec.factors, period=period)
- params_template = create_af_params_template(
- params_index,
- normalizations,
- period=period,
+
+ measurements, all_measures, loading_mask, params_template, step_arrays = (
+ _assemble_target_measurements(
+ period=period,
+ model_spec=model_spec,
+ factors=factors,
+ controls_names=controls_names,
+ state_factors=state_factors,
+ shock_factors=shock_factors,
+ transition_info=transition_info,
+ endogenous_factors=endogenous_factors,
+ observed_factors=observed_factors,
+ measurements=measurements,
+ frames_by_period=frames_by_period,
+ historical=historical,
+ normalizations=normalizations,
+ bounds_distance=af_options.bounds_distance,
+ )
)
params_template = _initialize_transition_params(
@@ -194,9 +213,6 @@ def estimate_transition_period(
params_template, transition_constraints, fixed_params
)
- # Build loading mask
- loading_mask = _build_loading_mask(all_measures, factors, measurements_pt)
-
# JOINT Halton design covering ALL randomness needed at this step,
# mirroring MATLAB's `create_nodes_weights_01/12`. The chained sample
# θ_0 → θ_{period-1} is rebuilt on-demand inside the integrand from
@@ -319,6 +335,7 @@ def combined_transition(
transition_constraints=transition_constraints,
fixed_params=fixed_params,
user_constraints=user_constraints,
+ importance=step_arrays,
)
# Build the next ChainLink from the just-fitted period parameters and
@@ -365,6 +382,234 @@ def combined_transition(
return period_result, updated_dist
+def _assemble_target_measurements(
+ *,
+ period: int,
+ model_spec: ModelSpec,
+ factors: tuple[str, ...],
+ controls_names: tuple[str, ...],
+ state_factors: tuple[str, ...],
+ shock_factors: tuple[str, ...],
+ transition_info: TransitionInfo,
+ endogenous_factors: tuple[str, ...],
+ observed_factors: tuple[str, ...],
+ measurements: Array,
+ frames_by_period: Mapping[int, pd.DataFrame] | None,
+ historical: HistoricalParams | None,
+ normalizations: dict[str, dict[tuple[str, str], float]],
+ bounds_distance: float = 0.001,
+) -> tuple[Array, list[str], np.ndarray, pd.DataFrame, AFStepArrays | None]:
+ """Build the target measurements, names, loading mask, params template, arrays.
+
+ With endogenous or static-persistent factors the target block mixes calendars
+ (destination skills at period d, source investment at period s), sourced from the
+ compiled layout; plain models keep the single-period path, byte-identical to before.
+ """
+ # The calendar adapter applies to RECONSTRUCTED investment (endogenous with
+ # has_initial_distribution=False): its period-0 indicators are excluded from the
+ # initial step and scored once, at the 0->1 step, against I_0. Legacy endogenous
+ # factors that keep an initial distribution stay on the single-period path.
+ reconstructed_endog = tuple(
+ f
+ for f in endogenous_factors
+ if not model_spec.factors[f].has_initial_distribution
+ )
+ use_layout = bool(reconstructed_endog) or any(
+ model_spec.factors[f].af_state_role == "static_persistent"
+ for f in state_factors
+ )
+ if not use_layout:
+ measurements_pt = get_measurements_per_factor(model_spec.factors, period=period)
+ all_measures = _get_ordered_measures(measurements_pt)
+ loading_mask = _build_loading_mask(all_measures, factors, measurements_pt)
+ params_index = get_transition_period_params_index(
+ period=period,
+ latent_factors=state_factors,
+ transition_info=transition_info,
+ measurements_at_period=measurements_pt,
+ controls=controls_names,
+ endogenous_factors=endogenous_factors,
+ observed_factors=observed_factors,
+ shock_factors=shock_factors,
+ )
+ params_template = create_af_params_template(
+ params_index, normalizations, period=period, bounds_distance=bounds_distance
+ )
+ return measurements, all_measures, loading_mask, params_template, None
+
+ if frames_by_period is None or historical is None:
+ msg = (
+ "frames_by_period and historical are required for AF models with "
+ "endogenous or static-persistent factors."
+ )
+ raise ValueError(msg)
+ _fail_if_endogenous_precedes_state(factors, endogenous_factors)
+ _fail_if_unsupported_adapter_measurements(model_spec)
+ n_calendar_periods = max(
+ len(spec.measurements) for spec in model_spec.factors.values()
+ )
+ layout = compile_af_step_layouts(
+ _factor_infos_from_spec(model_spec, endogenous_factors),
+ n_periods=n_calendar_periods,
+ )[period - 1]
+ target_terms = layout.target_terms()
+ # @pro: the free target block for this step is the mixed-calendar set {destination
+ # skills at d, source investment I_s at s}; the params index, template, and
+ # normalizations below are indexed at each row's own param_period so the mixed
+ # calendar parses correctly. The matching fixed importance block is built in
+ # `_run_transition_optimization`.
+ step_arrays = assemble_step_arrays(
+ layout, frames_by_period, factors, historical, controls_names
+ )
+ measurements = jnp.asarray(step_arrays.target_measurements)
+ all_measures = [term.measurement for term in target_terms]
+ loading_mask = step_arrays.target_loading_mask
+ params_index = get_transition_period_params_index(
+ period=period,
+ latent_factors=state_factors,
+ transition_info=transition_info,
+ measurements_at_period={},
+ controls=controls_names,
+ endogenous_factors=endogenous_factors,
+ observed_factors=observed_factors,
+ shock_factors=shock_factors,
+ measurement_index_tuples=compile_target_measurement_index(
+ layout, controls_names
+ ),
+ )
+ params_template = create_af_params_template(
+ params_index, {}, period=period, bounds_distance=bounds_distance
+ )
+ params_template = _apply_layout_normalizations(
+ params_template, model_spec, target_terms
+ )
+ return measurements, all_measures, loading_mask, params_template, step_arrays
+
+
+def _factor_infos_from_spec(
+ model_spec: ModelSpec,
+ endogenous_factors: tuple[str, ...],
+) -> list[AFFactorInfo]:
+ """Build per-factor AF role + measurement info for the layout compiler.
+
+ The source-investment calendar is defined only for a *reconstructed* endogenous
+ factor -- one with `has_initial_distribution=False`. An endogenous factor that still
+ carries an initial distribution is not a reconstructed investment, so it cannot take
+ the `ENDOGENOUS` source-period role; the adapter does not support it and raises.
+ """
+ infos: list[AFFactorInfo] = []
+ for name, spec in model_spec.factors.items():
+ if name in endogenous_factors:
+ if spec.has_initial_distribution:
+ msg = (
+ f"AF calendar adapter: endogenous factor {name!r} has "
+ "has_initial_distribution=True. The source-investment calendar is "
+ "only defined for reconstructed endogenous factors "
+ "(has_initial_distribution=False); a carried endogenous factor is "
+ "not supported on the adapter path."
+ )
+ raise ValueError(msg)
+ role = AFFactorRole.ENDOGENOUS
+ elif spec.af_state_role == "static_persistent":
+ role = AFFactorRole.STATIC_PERSISTENT
+ else:
+ role = AFFactorRole.DYNAMIC
+ infos.append(
+ AFFactorInfo(name=name, role=role, measurements_by_period=spec.measurements)
+ )
+ return infos
+
+
+def _fail_if_endogenous_precedes_state(
+ latent_factors: tuple[str, ...],
+ endogenous_factors: tuple[str, ...],
+) -> None:
+ """Reject a public factor order that interleaves endogenous before state factors.
+
+ The shared integrand assembles the latent vector as all dynamic-state factors
+ followed by all reconstructed-endogenous factors, while the loading mask columns
+ follow `latent_factors` (public insertion order). If an endogenous factor appears
+ before a state factor in public order, those two orderings disagree and a loading
+ row would score the wrong latent. Until the two representations are unified, require
+ every state factor to precede every endogenous factor and raise otherwise.
+ """
+ endo = set(endogenous_factors)
+ seen_endogenous = False
+ for name in latent_factors:
+ if name in endo:
+ seen_endogenous = True
+ elif seen_endogenous:
+ msg = (
+ "AF calendar adapter: all dynamic-state factors must precede every "
+ f"endogenous factor in the ModelSpec, but state factor {name!r} "
+ "follows an endogenous factor. Reorder so state factors come first."
+ )
+ raise ValueError(msg)
+
+
+def _fail_if_unsupported_adapter_measurements(model_spec: ModelSpec) -> None:
+ """Reject cross-loaded or non-Gaussian measurements on the calendar-adapter path.
+
+ The compiler emits one single-factor density term per measurement declaration and
+ does not plumb measurement-family metadata through the adapter, so a cross-loaded
+ measurement (declared under more than one factor) would be double-counted as two
+ rows, and a non-Gaussian (probit/Tobit) measurement would be silently scored as
+ Gaussian. Both are rejected until row-merging and family plumbing land.
+ """
+ owner: dict[str, str] = {}
+ for factor_name, spec in model_spec.factors.items():
+ for period_measures in spec.measurements:
+ for measure in period_measures:
+ if measure in owner and owner[measure] != factor_name:
+ msg = (
+ f"AF calendar adapter: measurement {measure!r} is "
+ f"cross-loaded on factors {owner[measure]!r} and "
+ f"{factor_name!r}; cross-loaded measurements are not supported."
+ )
+ raise ValueError(msg)
+ owner[measure] = factor_name
+ for measure in owner:
+ model = model_spec.measurement_models.get(measure)
+ if model is not None and not isinstance(model, GaussianMeasurement):
+ msg = (
+ f"AF calendar adapter: measurement {measure!r} has a non-Gaussian "
+ f"family ({type(model).__name__}); the adapter supports only Gaussian "
+ "measurements."
+ )
+ raise ValueError(msg)
+
+
+def _apply_layout_normalizations(
+ params_template: pd.DataFrame,
+ model_spec: ModelSpec,
+ target_terms: tuple,
+) -> pd.DataFrame:
+ """Pin loading/intercept normalizations at each target term's true param period.
+
+ The mixed-calendar target indexes destination skills at `d` and source investment
+ at `s`, so normalizations must be applied per row's `param_period` rather than at a
+ single period (which `create_af_params_template` assumes).
+ """
+ params = params_template
+ for term in target_terms:
+ period = term.param_period
+ for factor in term.factor_loadings:
+ norms = model_spec.factors[factor].normalizations
+ if norms is None:
+ continue
+ if norms.loadings is not None and period < len(norms.loadings):
+ val = norms.loadings[period].get(term.measurement)
+ loc = ("loadings", period, term.measurement, factor)
+ if val is not None and loc in params.index:
+ params.loc[loc, ["value", "lower_bound", "upper_bound"]] = val
+ if norms.intercepts is not None and period < len(norms.intercepts):
+ val = norms.intercepts[period].get(term.measurement)
+ loc = ("controls", period, term.measurement, "constant")
+ if val is not None and loc in params.index:
+ params.loc[loc, ["value", "lower_bound", "upper_bound"]] = val
+ return params
+
+
def _run_transition_optimization(
*,
params_template: pd.DataFrame,
@@ -398,6 +643,7 @@ def _run_transition_optimization(
transition_constraints: list[om.constraints.Constraint],
fixed_params: pd.DataFrame | None,
user_constraints: list[om.constraints.Constraint] | None = None,
+ importance: AFStepArrays | None = None,
) -> tuple[pd.DataFrame, om.OptimizeResult]:
"""Build likelihood, run the optimizer, and return updated params.
@@ -413,12 +659,30 @@ def _run_transition_optimization(
params_template, fixed_params
)
- prev_meas_info = _extract_prev_measurement_params(
- prev_period_params,
- model_spec,
- factors,
- period - 1,
- )
+ # Importance/previous block. @pro: with the calendar adapter, the importance block
+ # is the layout's fixed source-skills (+ static-persistent period-0) rows from
+ # history -- it EXCLUDES the source investment (now the free target), avoiding the
+ # double-count, and RE-INCLUDES the static factors' period-0 density at every step
+ # (the dropped-MC/MN fix). Without the adapter, fall back to extracting all
+ # previous-period measurement params (the legacy single-calendar path).
+ if importance is not None:
+ prev_measurements = jnp.asarray(importance.importance_measurements)
+ # `assemble_step_arrays` returns host (numpy) arrays; move them on-device so
+ # they satisfy the `jax.Array` kwargs contract under the beartype claw (mirrors
+ # the target `loading_mask` wrap below). Numerically a no-op.
+ prev_meas_info = {
+ "loading_mask": jnp.asarray(importance.importance_loading_mask),
+ "control_params": jnp.asarray(importance.importance_control_params),
+ "loadings_flat": jnp.asarray(importance.importance_loadings_flat),
+ "meas_sds": jnp.asarray(importance.importance_meas_sds),
+ }
+ else:
+ prev_meas_info = _extract_prev_measurement_params(
+ prev_period_params,
+ model_spec,
+ factors,
+ period - 1,
+ )
n_obs_per_batch = af_options.n_obs_per_batch
if n_obs_per_batch is None:
@@ -460,6 +724,16 @@ def _run_transition_optimization(
"stability_floor": af_options.stability_floor,
"n_obs_per_batch": n_obs_per_batch,
}
+ if importance is not None:
+ # Layout path: feed the per-row target control data and the precompiled
+ # importance control contribution, each sourced at its own control_period, so
+ # the kernel stops applying one shared period's controls to a mixed block.
+ loglike_kwargs["target_control_tensor"] = jnp.asarray(
+ importance.target_controls
+ )
+ loglike_kwargs["prev_control_contrib"] = jnp.asarray(
+ importance.importance_control_contrib
+ )
loglike_and_grad = create_loglike_and_gradient(
af_loglike_transition,
@@ -834,11 +1108,18 @@ def _initialize_transition_params(
# MLE neighborhood; for sigma_inv_0 specifically, this is the difference
# between converging at truth and drifting to the lower bound along
# the sigma_inv / sigma_meas constant-Var ridge.
+ # The Spearman moment override discovers measurements at the destination period and
+ # writes loading/SD rows there, so it is not aware of the mixed-calendar target
+ # (destination skills at d + source investment at s). On the calendar-adapter path
+ # it would mis-seed the source-investment block, so skip it and keep the constant
+ # defaults (which converge for the adapter); calendar-aware moment seeding is future
+ # work. The point estimate is unaffected -- only start values differ.
if (
af_options is not None
and af_options.start_params_strategy == "spearman"
and model_spec is not None
and period is not None
+ and not model_uses_calendar_adapter(model_spec)
):
params = _apply_moment_based_overrides_transition(
params,
diff --git a/src/skillmodels/af/types.py b/src/skillmodels/af/types.py
index 4b3be319..37160560 100644
--- a/src/skillmodels/af/types.py
+++ b/src/skillmodels/af/types.py
@@ -100,6 +100,15 @@ class AFEstimationOptions:
posterior-state summary precision matters for downstream analysis.
"""
+ bounds_distance: float
+ """Minimum distance from zero for SD parameters (their lower bound).
+
+ Also sets the corresponding distance from 1 for the upper bound of CES
+ share parameters. Defaults to 0.001. The Antweiler-Freyberger Monte Carlo
+ uses 0.01; raising it stops a weakly-identified SD (e.g. the investment
+ shock SD) from collapsing onto the floor as far.
+ """
+
def __init__( # noqa: D107
self,
n_halton_points: int = 50,
@@ -114,6 +123,7 @@ def __init__( # noqa: D107
start_params_strategy: Literal["none", "constant", "spearman", "amn"] = "amn",
keep_conditional_distributions: bool = True,
n_halton_points_posterior_summary: int = 256,
+ bounds_distance: float = 0.001,
) -> None:
if n_halton_points_posterior_summary < 1:
msg = (
@@ -142,6 +152,7 @@ def __init__( # noqa: D107
"n_halton_points_posterior_summary",
n_halton_points_posterior_summary,
)
+ object.__setattr__(self, "bounds_distance", bounds_distance)
@dataclass(frozen=True)
@@ -306,7 +317,11 @@ class AFPeriodResult:
"""Estimated parameters with 4-level MultiIndex (category, period, name1, name2)."""
loglikelihood: float
- """Log-likelihood value at the optimum."""
+ """Per-observation **mean** log-likelihood criterion at the optimum.
+
+ This is the negated optimiser objective `-mean(neg_log_like)`, i.e. an average over
+ observations, not a summed log-likelihood. Multiply by the step's sample size to
+ recover a total."""
success: bool
"""Whether optimization converged."""
@@ -336,8 +351,24 @@ class AFEstimationResult:
`skillmodels.common.estimation.CommonEstimationResult`."""
loglikelihood: float
- """Sum of the per-period log-likelihoods at the optimum (AF maximises a
- sequence of per-period likelihoods)."""
+ """Alias of `sequential_criterion`, retained for `CommonEstimationResult`
+ conformance and back-compat.
+
+ It is the sum of the per-period **mean** log-likelihood criteria, which is a
+ sequential/composite criterion, **not** a joint sample log-likelihood: each term is
+ a per-observation average and the static-factor (e.g. MC/MN) densities are
+ deliberately re-applied across independently optimised steps. It is therefore not
+ valid for AIC/BIC or likelihood-ratio comparisons; use it only as the optimiser
+ objective scale. See `period_mean_criteria` for the per-step breakdown."""
+
+ sequential_criterion: float = float("nan")
+ """The summed per-period mean log-likelihood criteria (same value as
+ `loglikelihood`, under its honest name). A sequential criterion, not a joint
+ log-likelihood — see `loglikelihood`."""
+
+ period_mean_criteria: tuple[float, ...] = ()
+ """Per-period mean log-likelihood criteria, ordered by period (each is the
+ corresponding `AFPeriodResult.loglikelihood`)."""
md_criterion: float | None = None
"""Always `None` for AF; present to satisfy the common result Protocol."""
diff --git a/src/skillmodels/af/validate.py b/src/skillmodels/af/validate.py
index f396d1b3..c328d366 100644
--- a/src/skillmodels/af/validate.py
+++ b/src/skillmodels/af/validate.py
@@ -81,13 +81,13 @@ def validate_af_model(
`estimate_af`) supply the alternative anchors. They default to None so the
measurement-system and transition checks can be run on a bare ModelSpec.
- Also emit a `UserWarning` (audit F8) when an endogenous factor (investment)
- carries measurements: AF reconstructs it from the previous period's skills, so
- its period-t indicators score the investment generated from period t-1
- (`I_{t-1}`), whereas CHS/AMN read the same period-t indicators as the
- contemporaneous `I_t`. The same `ModelSpec` therefore denotes different
- investment calendars across estimators; model investment as a standard
- (non-endogenous) factor to score the contemporaneous value under AF too.
+ A reconstructed endogenous factor (investment, `is_endogenous=True` and
+ `has_initial_distribution=False`) is contemporaneous: a period-c investment
+ indicator measures `I_c`, the same calendar CHS / AMN read. The calendar
+ adapter makes the public `ModelSpec` contemporaneous and the AF step
+ assembler re-times the source investment internally, so no calendar warning
+ is emitted for such factors -- a user who shifted their data to satisfy one
+ would reintroduce the original off-by-one bug.
Also emit a loud `UserWarning` (not an error) when a built-in production
transition function would silently absorb observed factors (income).
@@ -128,7 +128,6 @@ def validate_af_model(
warn_if_overrestricted(model_spec, fixed_params, constraints)
_warn_on_observed_factor_leakage(model_spec)
- _warn_on_endogenous_investment_calendar(model_spec)
if errors:
msg = "ModelSpec is not compatible with AF estimation:\n" + "\n".join(
@@ -204,21 +203,12 @@ def _validate_factor(factor_name: str, factor_spec: FactorSpec) -> list[str]:
f"supported for endogenous factors (set is_endogenous=True)."
)
- # Factors without an initial distribution must also not be measured at
- # period 0: their value at period 0 is not drawn from any mixture, so a
- # measurement density there would have no latent value to hit.
- if (
- not factor_spec.has_initial_distribution
- and len(factor_spec.measurements) > 0
- and len(factor_spec.measurements[0]) > 0
- ):
- errors.append(
- f"Factor '{factor_name}': has_initial_distribution=False requires "
- f"empty measurements at period 0 (got "
- f"{factor_spec.measurements[0]!r}). Drop them from the FactorSpec; "
- f"their contribution would typically be absorbed into the "
- f"transition step 0->1 in a MATLAB-style reproduction."
- )
+ # A reconstructed factor (has_initial_distribution=False) MAY declare period-0
+ # measurements under the contemporaneous convention: they measure the period-0
+ # reconstruction I_0 = g(theta_0). They are excluded from the initial step (which
+ # has no latent value for them) and sourced by the source/destination calendar
+ # adapter as the free target of the 0->1 step, scored against I_0. So period-0
+ # reconstructed-factor measurements are allowed and not an error here.
return errors
@@ -262,45 +252,6 @@ def _warn_on_observed_factor_leakage(model_spec: ModelSpec) -> None:
)
-def _warn_on_endogenous_investment_calendar(model_spec: ModelSpec) -> None:
- """Warn that an endogenous factor's measurements attach to the SOURCE period.
-
- AF reconstructs an endogenous factor (investment) from the PREVIOUS period's
- latent skills via the investment equation: at the (t-1)->t step the generated
- investment `I` is a function of theta_{t-1} (and period-(t-1) income), and it
- is that same `I` that the period-t measurement block scores. So measurements
- declared at period t for an endogenous factor measure the investment GENERATED
- FROM period t-1 (`I_{t-1}`), not the contemporaneous `I_t`. This is the
- MATLAB-faithful AF calendar (the investment generated alongside theta_{t-1}).
-
- CHS / AMN read the identical period-t measurements as the contemporaneous
- `I_t` -- a standard latent factor with its own initial distribution -- so the
- same `ModelSpec` denotes different investment calendars across estimators.
- Warn (not error) so cross-estimator comparison is an explicit choice: to score
- the contemporaneous investment under AF too, model it as a standard
- (non-endogenous) latent factor with its own initial distribution instead of an
- endogenous one.
- """
- for name, spec in model_spec.factors.items():
- # The calendar offset is the canonical AF endogenous reconstruction: the
- # factor has no initial distribution and is rebuilt each step from the
- # prior period's skills via the investment equation.
- if not (spec.is_endogenous and not spec.has_initial_distribution):
- continue
- if any(len(measures) > 0 for measures in spec.measurements):
- warnings.warn(
- f"Factor '{name}' is endogenous: AF generates it from the previous "
- f"period's skills, so its period-t measurements score the investment "
- f"generated from period t-1 (I_{{t-1}}), not the contemporaneous "
- f"I_t. CHS/AMN read the same measurements as I_t, so this ModelSpec "
- f"denotes different investment calendars across estimators. This is "
- f"the MATLAB-faithful AF convention; to score the contemporaneous "
- f"investment under AF as well, model it as a standard "
- f"(non-endogenous) latent factor with its own initial distribution.",
- stacklevel=3,
- )
-
-
def fail_if_unsupported_kappa_params(
start_params: pd.DataFrame | None,
fixed_params: pd.DataFrame | None,
diff --git a/src/skillmodels/amn/estimate.py b/src/skillmodels/amn/estimate.py
index 46a3fb9e..7805191f 100644
--- a/src/skillmodels/amn/estimate.py
+++ b/src/skillmodels/amn/estimate.py
@@ -37,6 +37,7 @@
fail_if_initial_state_unanchored,
warn_if_overrestricted,
)
+from skillmodels.common.measurement_models import GaussianMeasurement
from skillmodels.common.model_spec import ModelSpec
from skillmodels.common.process_model import process_model
from skillmodels.common.types import ProcessedModel
@@ -166,6 +167,32 @@ def _fit_stage1_mixture(
)
+def _fail_if_non_gaussian_measurements(model_spec: ModelSpec) -> None:
+ """Reject probit/Tobit measures: AMN's moment map assumes continuous Gaussians.
+
+ AMN recovers loadings and measurement SDs from the cross-covariance of a
+ factor's multiple indicators (mixture EM + minimum distance), which treats every
+ measure as a continuous Gaussian signal. A probit/Tobit measure routed through
+ that map would be silently mis-estimated, so it is refused here. The AMN->CHS
+ seeding path supplies a working-linear (Gaussian) spec, so it is unaffected.
+ """
+ non_gaussian = sorted(
+ name
+ for name, model in model_spec.measurement_models.items()
+ if not isinstance(model, GaussianMeasurement)
+ )
+ if non_gaussian:
+ msg = (
+ f"estimate_amn supports only Gaussian measurements, but "
+ f"{non_gaussian} are non-Gaussian (probit/Tobit). AMN's mixture-EM and "
+ "minimum-distance stages recover loadings and SDs from multi-indicator "
+ "cross-covariances, which assume continuous Gaussian measures. Estimate "
+ "such models with estimate_af, or translate the measures to a working "
+ "linear system before seeding."
+ )
+ raise NotImplementedError(msg)
+
+
@beartype(conf=ESTIMATION_CONF)
def estimate_amn(
model_spec: ModelSpec,
@@ -228,6 +255,7 @@ def estimate_amn(
params DataFrame.
"""
+ _fail_if_non_gaussian_measurements(model_spec)
if require_initial_anchors and not for_start_values:
fail_if_initial_state_unanchored(model_spec, fixed_params, constraints)
warn_if_overrestricted(model_spec, fixed_params, constraints)
diff --git a/src/skillmodels/common/constraints.py b/src/skillmodels/common/constraints.py
index 19f9a06d..285aac29 100644
--- a/src/skillmodels/common/constraints.py
+++ b/src/skillmodels/common/constraints.py
@@ -325,7 +325,9 @@ def get_constraints(
n_mixtures=dimensions.n_mixtures,
factors=labels.latent_factors,
)
- constraints += _get_transition_constraints(labels=labels)
+ constraints += _get_transition_constraints(
+ labels=labels, endogenous_factors_info=endogenous_factors_info
+ )
constraints += _get_anchoring_constraints(
update_info=update_info,
controls=labels.controls,
@@ -539,21 +541,28 @@ def _get_initial_states_constraints(
def _get_transition_constraints(
labels: Labels,
+ endogenous_factors_info: EndogenousFactorsInfo,
) -> list[om.constraints.Constraint]:
"""Collect possible constraints on transition parameters.
Args:
labels: Dict of lists with labels for the model quantities like
factors, periods, controls, stagemap and stages. See :ref:`labels`
+ endogenous_factors_info: Information about endogenous factors in the model.
Returns:
List of constraint objects.
"""
+ # Transition params (and therefore their simplex/probability folds) live on
+ # `aug_periods[:-2]` when endogenous factors split the calendar into augmented
+ # periods, else `aug_periods[:-1]`. Mirror `get_transition_index_tuples` exactly
+ # so no constraint targets a transition row absent from the params index.
+ end = -2 if endogenous_factors_info.has_endogenous_factors else -1
constraints: list[om.constraints.Constraint] = []
for f, factor in enumerate(labels.latent_factors):
tname = labels.transition_names[f]
- for aug_period in labels.aug_periods[:-1]:
+ for aug_period in labels.aug_periods[:end]:
funcname = f"constraints_{tname}"
if func := getattr(t_f_module, funcname, False):
constraints.append(
diff --git a/src/skillmodels/common/measurement_models.py b/src/skillmodels/common/measurement_models.py
new file mode 100644
index 00000000..9a500c95
--- /dev/null
+++ b/src/skillmodels/common/measurement_models.py
@@ -0,0 +1,191 @@
+"""Shared measurement-family log-likelihood kernel.
+
+One JAX-compatible kernel evaluates the log density/probability contribution of a
+single measurement given its linear predictor ``eta = c + x'beta + lambda'theta``,
+its scale, and a `MeasurementFamily` code. AF estimation, CHS, simulation and
+posterior-state reweighting all route their measurement contributions through this
+kernel, so the measurement system is defined in exactly one place.
+
+The three families:
+
+- ``GAUSSIAN`` -- ``log N(y; eta, sigma)``.
+- ``PROBIT`` -- ``log Phi((2y-1) eta)`` for ``y in {0, 1}``; the latent error SD is
+ fixed to 1, so ``sigma`` is ignored. Evaluated through `jax.scipy.special.log_ndtr`
+ (never ``log(1 - ndtr)``) so the tails stay finite and stable.
+- ``TOBIT`` -- censored normal with known bounds: the interior normal log density on
+ ``L < y < U``, the left tail mass ``Phi((L-eta)/sigma)`` at ``y = L``, and the
+ right tail mass ``Phi((eta-U)/sigma)`` at ``y = U``. A non-finite bound disables
+ that side (no censoring there).
+"""
+
+import enum
+import math
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax import Array
+from jax.scipy.special import log_ndtr
+from jax.scipy.stats import norm
+
+
+@dataclass(frozen=True)
+class GaussianMeasurement:
+ """Continuous measure with a free normal measurement-error SD (the default)."""
+
+
+@dataclass(frozen=True)
+class ProbitMeasurement:
+ """Binary 0/1 measure with a standard-normal latent error (scale fixed to 1)."""
+
+
+@dataclass(frozen=True)
+class TobitMeasurement:
+ """Censored-normal measure (not truncated, selected, or zero-inflated).
+
+ `lower` / `upper` are the known censoring bounds; `None` disables that side.
+ At least one bound must be finite. Observations equal to a bound are treated as
+ censored (tail mass); interior observations use the normal density.
+ """
+
+ lower: float | None = 0.0
+ """Lower censoring bound, or `None` for no lower censoring."""
+ upper: float | None = None
+ """Upper censoring bound, or `None` for no upper censoring."""
+
+ def __post_init__(self) -> None: # noqa: D105
+ if self.lower is None and self.upper is None:
+ msg = "TobitMeasurement needs at least one finite censoring bound."
+ raise ValueError(msg)
+ if (
+ self.lower is not None
+ and self.upper is not None
+ and self.lower >= self.upper
+ ):
+ msg = (
+ f"TobitMeasurement lower bound ({self.lower}) must be strictly below "
+ f"the upper bound ({self.upper})."
+ )
+ raise ValueError(msg)
+
+
+# A measurement's observation model, attached per variable on `ModelSpec`.
+MeasurementModel = GaussianMeasurement | ProbitMeasurement | TobitMeasurement
+
+
+class MeasurementFamily(enum.IntEnum):
+ """Observation model attached to a measurement variable.
+
+ Integer-valued so the code can be stored in a JAX array aligned with the
+ measurement system and compared inside traced/jitted code. A distinct name from
+ `MeasurementType` (which marks state vs endogenous-factor augmented periods).
+ """
+
+ GAUSSIAN = 0
+ PROBIT = 1
+ TOBIT = 2
+
+
+def resolve_measurement_family(
+ model: MeasurementModel,
+) -> tuple[MeasurementFamily, float, float]:
+ """Map a public `MeasurementModel` to its internal `(family, lower, upper)`.
+
+ `lower` / `upper` are the censoring bounds threaded to `measurement_loglik`;
+ non-Tobit families and open Tobit sides use `-inf` / `+inf`.
+ """
+ if isinstance(model, ProbitMeasurement):
+ return MeasurementFamily.PROBIT, -math.inf, math.inf
+ if isinstance(model, TobitMeasurement):
+ lower = -math.inf if model.lower is None else float(model.lower)
+ upper = math.inf if model.upper is None else float(model.upper)
+ return MeasurementFamily.TOBIT, lower, upper
+ return MeasurementFamily.GAUSSIAN, -math.inf, math.inf
+
+
+def measurement_family_arrays(
+ measurement_models: Mapping[str, MeasurementModel],
+ measure_names: Sequence[str],
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Return `(family_codes, lowers, uppers)` aligned to `measure_names`.
+
+ A name absent from `measurement_models` resolves to Gaussian. The arrays line
+ up row-for-row with a measurement system's loadings / SD arrays so the shared
+ kernel can be vmapped over them. Each code is a `MeasurementFamily` integer.
+ """
+ resolved = [
+ resolve_measurement_family(measurement_models.get(name, GaussianMeasurement()))
+ for name in measure_names
+ ]
+ codes = np.array([int(family) for family, _lo, _hi in resolved], dtype=np.int64)
+ lowers = np.array([lo for _f, lo, _hi in resolved], dtype=np.float64)
+ uppers = np.array([hi for _f, _lo, hi in resolved], dtype=np.float64)
+ return codes, lowers, uppers
+
+
+def measurement_loglik(
+ y: Array,
+ eta: Array,
+ sigma: Array,
+ family: Array,
+ lower: Array,
+ upper: Array,
+) -> Array:
+ """Return the log contribution of one measurement under its family.
+
+ Args:
+ y: Observed measurement value (0/1 for probit; a censored value for Tobit).
+ eta: Linear predictor `c + x'beta + lambda'theta`.
+ sigma: Measurement scale (ignored for probit, which fixes it to 1).
+ family: A `MeasurementFamily` integer code.
+ lower: Tobit lower censoring bound (`-inf` to disable).
+ upper: Tobit upper censoring bound (`+inf` to disable).
+
+ Return:
+ The scalar log density (Gaussian/Tobit interior) or log probability
+ (probit, Tobit tail) of the measurement.
+
+ """
+ gaussian = norm.logpdf(y, loc=eta, scale=sigma)
+ probit = log_ndtr((2.0 * y - 1.0) * eta)
+ tobit = _tobit_loglik(y, eta, sigma, lower, upper)
+ return jnp.where(
+ family == int(MeasurementFamily.PROBIT),
+ probit,
+ jnp.where(family == int(MeasurementFamily.TOBIT), tobit, gaussian),
+ )
+
+
+def _tobit_loglik(
+ y: Array, eta: Array, sigma: Array, lower: Array, upper: Array
+) -> Array:
+ """Censored-normal log contribution with safe-gradient handling of inf bounds.
+
+ A non-finite bound is replaced by a finite sentinel before any arithmetic so the
+ masked-out tail term carries no `inf`/`NaN` into `jnp.where` (whose gradient
+ would otherwise be poisoned), then masked out by the finiteness check.
+ """
+ interior = norm.logpdf(y, loc=eta, scale=sigma)
+
+ lower_finite = jnp.isfinite(lower)
+ upper_finite = jnp.isfinite(upper)
+ safe_lower = jnp.where(lower_finite, lower, 0.0)
+ safe_upper = jnp.where(upper_finite, upper, 0.0)
+
+ left = log_ndtr((safe_lower - eta) / sigma)
+ right = log_ndtr((eta - safe_upper) / sigma)
+
+ at_lower = lower_finite & (y <= lower)
+ at_upper = upper_finite & (y >= upper)
+ return jnp.where(at_lower, left, jnp.where(at_upper, right, interior))
+
+
+# Vmapped over a 1-D measurement vector (all six arguments share axis 0). Use this
+# form INSIDE another jitted/vmapped function (e.g. the AF integrand); it is not
+# itself jitted, so it composes without nesting `jax.jit`.
+measurement_loglik_vec = jax.vmap(measurement_loglik, in_axes=(0, 0, 0, 0, 0, 0))
+
+# Standalone pre-jitted batch form for callers not already inside a jit.
+measurement_loglik_batch = jax.jit(measurement_loglik_vec)
diff --git a/src/skillmodels/common/model_spec.py b/src/skillmodels/common/model_spec.py
index 2b20354b..b9f57d9a 100644
--- a/src/skillmodels/common/model_spec.py
+++ b/src/skillmodels/common/model_spec.py
@@ -9,15 +9,50 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field, replace
from types import MappingProxyType
-from typing import Any, Self
+from typing import Any, Literal, Self
from skillmodels._beartype_conf import MODEL_SPEC_CONF, beartype_init
+from skillmodels.common.measurement_models import (
+ GaussianMeasurement,
+ MeasurementModel,
+ ProbitMeasurement,
+ TobitMeasurement,
+)
from skillmodels.common.types import (
Normalizations,
ensure_containers_are_immutable,
)
+def _parse_measurement_models(
+ spec: Mapping[str, Mapping[str, Any]] | None,
+) -> dict[str, MeasurementModel]:
+ """Build `MeasurementModel` instances from a `from_dict` measurement_models map.
+
+ Each entry is `{family: "gaussian" | "probit" | "tobit", lower?, upper?}`.
+ """
+ if not spec:
+ return {}
+ models: dict[str, MeasurementModel] = {}
+ for name, cfg in spec.items():
+ family = cfg.get("family", "gaussian")
+ if family == "gaussian":
+ models[name] = GaussianMeasurement()
+ elif family == "probit":
+ models[name] = ProbitMeasurement()
+ elif family == "tobit":
+ models[name] = TobitMeasurement(
+ lower=cfg.get("lower", 0.0), upper=cfg.get("upper")
+ )
+ else:
+ msg = (
+ f"measurement_models['{name}'] has unknown family '{family}'; "
+ "expected 'gaussian', 'probit', or 'tobit'."
+ )
+ raise ValueError(msg)
+ return models
+
+
@beartype_init(MODEL_SPEC_CONF)
@dataclass(frozen=True)
class CorrectionSpec:
@@ -110,6 +145,16 @@ class FactorSpec:
of the initial step. The transition function must not depend on the
factor's own lag.
"""
+ af_state_role: Literal["dynamic", "static_persistent"] = "dynamic"
+ """AF calendar role of a non-endogenous state factor.
+
+ `"static_persistent"` marks a time-invariant factor (e.g. MC/MN) whose
+ period-0 measurement density is re-applied as an AF importance factor at
+ every transition step, because the same latent value re-enters the
+ investment/production equations at every period. Declared explicitly
+ rather than inferred from an identity transition (which may be pinned via
+ `fixed_params`). Ignored by CHS/AMN and for endogenous factors.
+ """
def with_transition_function(self, func: str | Callable) -> Self:
"""Return a new FactorSpec with the given transition function."""
@@ -161,6 +206,11 @@ class ModelSpec:
"""Anchoring specification."""
n_mixtures: int = 1
"""Number of Gaussian-mixture components in the latent-factor distribution."""
+ measurement_models: MappingProxyType[str, MeasurementModel] = field(
+ default_factory=lambda: MappingProxyType({})
+ )
+ """Per-variable observation model (probit / Tobit / Gaussian). A measurement
+ omitted from this mapping is Gaussian, so existing specs are unchanged."""
def __init__(
self,
@@ -170,6 +220,7 @@ def __init__(
stagemap: tuple[int, ...] | None = None,
anchoring: AnchoringSpec | None = None,
n_mixtures: int = 1,
+ measurement_models: Mapping[str, MeasurementModel] | None = None,
) -> None:
"""Create ModelSpec, wrapping factors dict in MappingProxyType."""
object.__setattr__(self, "_factors", ensure_containers_are_immutable(factors))
@@ -178,6 +229,37 @@ def __init__(
object.__setattr__(self, "stagemap", stagemap)
object.__setattr__(self, "anchoring", anchoring)
object.__setattr__(self, "n_mixtures", n_mixtures)
+ self._set_measurement_models(measurement_models)
+
+ def _set_measurement_models(
+ self, measurement_models: Mapping[str, MeasurementModel] | None
+ ) -> None:
+ """Validate measurement-model keys against the model's measures and store."""
+ models = dict(measurement_models or {})
+ known = self._measurement_names()
+ unknown = sorted(name for name in models if name not in known)
+ if unknown:
+ msg = (
+ f"measurement_models references {unknown}, which is not a measurement "
+ f"variable in the model. Known measurements: {sorted(known)}."
+ )
+ raise ValueError(msg)
+ object.__setattr__(
+ self, "measurement_models", ensure_containers_are_immutable(models)
+ )
+
+ def _measurement_names(self) -> set[str]:
+ """Return every measurement-variable name across all factors and periods."""
+ return {
+ meas
+ for spec in self._factors.values()
+ for period in spec.measurements
+ for meas in period
+ }
+
+ def measurement_model(self, name: str) -> MeasurementModel:
+ """Return the observation model for a measurement (Gaussian if unlisted)."""
+ return self.measurement_models.get(name, GaussianMeasurement())
@classmethod
def from_dict(cls, d: dict[str, Any]) -> Self:
@@ -227,6 +309,7 @@ def from_dict(cls, d: dict[str, Any]) -> Self:
transition_function=spec.get("transition_function"),
has_production_shock=spec.get("has_production_shock", True),
has_initial_distribution=spec.get("has_initial_distribution", True),
+ af_state_role=spec.get("af_state_role", "dynamic"),
correction=correction,
)
@@ -240,6 +323,8 @@ def from_dict(cls, d: dict[str, Any]) -> Self:
observed = tuple(d.get("observed_factors", []))
observed += tuple(i for i in auto_instruments if i not in observed)
+ measurement_models = _parse_measurement_models(d.get("measurement_models"))
+
return cls(
factors=factors,
observed_factors=observed,
@@ -247,6 +332,7 @@ def from_dict(cls, d: dict[str, Any]) -> Self:
stagemap=tuple(stagemap) if stagemap is not None else None,
anchoring=anchoring,
n_mixtures=d.get("n_mixtures", 1),
+ measurement_models=measurement_models,
)
@property
@@ -263,6 +349,9 @@ def _replace(self, **changes: Any) -> Self: # noqa: ANN401
stagemap=changes.get("stagemap", self.stagemap),
anchoring=changes.get("anchoring", self.anchoring),
n_mixtures=changes.get("n_mixtures", self.n_mixtures),
+ measurement_models=changes.get(
+ "measurement_models", self.measurement_models
+ ),
)
def with_transition_functions(
diff --git a/src/skillmodels/common/simulate_data.py b/src/skillmodels/common/simulate_data.py
index 67060b7b..5890ccbb 100644
--- a/src/skillmodels/common/simulate_data.py
+++ b/src/skillmodels/common/simulate_data.py
@@ -12,6 +12,10 @@
from skillmodels._beartype_conf import SIMULATION_CONF
from skillmodels.common.anchoring import anchor_states_df
+from skillmodels.common.measurement_models import (
+ MeasurementFamily,
+ resolve_measurement_family,
+)
from skillmodels.common.model_spec import ModelSpec
from skillmodels.common.params_index import get_params_index
from skillmodels.common.parse_params import create_parsing_info, parse_params
@@ -150,6 +154,10 @@ def simulate_dataset(
policies=policies,
transition_info=processed_model.transition_info,
rng=rng,
+ measurement_families={
+ name: resolve_measurement_family(model)
+ for name, model in model_spec.measurement_models.items()
+ },
)
# Create collapsed versions with user-facing periods
@@ -209,6 +217,7 @@ def _simulate_dataset(
policies: list[dict] | None,
transition_info: TransitionInfo,
rng: np.random.Generator,
+ measurement_families: Mapping[str, tuple[int, float, float]] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Simulate datasets generated by a latent factor model.
@@ -229,6 +238,9 @@ def _simulate_dataset(
policies: List of policy dictionaries specifying stochastic shocks.
transition_info: Information about transition functions.
rng: NumPy random number generator.
+ measurement_families: Optional `name -> (family_code, lower, upper)` map for
+ non-Gaussian measures. Omitted or all-Gaussian periods draw exactly as
+ before; listed probit/Tobit measures are drawn by their family.
Returns:
observed_data: DataFrame with simulated measurements.
@@ -329,6 +341,10 @@ def _simulate_dataset(
observed_data_by_period = []
for t in range(n_aug_periods):
+ measure_names = list(loadings_df.loc[t].index)
+ families, lowers, uppers = _period_family_arrays(
+ measure_names, measurement_families
+ )
meas = pd.DataFrame(
data=measurements_from_states(
rng=rng,
@@ -337,6 +353,9 @@ def _simulate_dataset(
loadings=loadings_df.loc[t].to_numpy(),
control_params=control_params_df.loc[t].to_numpy(),
sds=meas_sds.loc[t].to_numpy().flatten(),
+ families=families,
+ lowers=lowers,
+ uppers=uppers,
),
columns=loadings_df.loc[t].index,
)
@@ -497,6 +516,9 @@ def measurements_from_states(
loadings: NDArray[np.floating] | Array,
control_params: NDArray[np.floating] | Array,
sds: NDArray[np.floating] | Array,
+ families: NDArray[np.integer] | None = None,
+ lowers: NDArray[np.floating] | None = None,
+ uppers: NDArray[np.floating] | None = None,
) -> NDArray[np.floating] | Array:
"""Generate the variables that would be observed in practice.
@@ -513,6 +535,13 @@ def measurements_from_states(
sds: numpy array of size (n_meas) with the standard deviations
of the measurements. Measurement error is assumed to be independent
across measurements.
+ families: optional `(n_meas,)` `MeasurementFamily` codes. `None` (the
+ default) draws every measurement as Gaussian, byte-identical to the
+ original simulation. When given, each measurement is drawn by its
+ family: Gaussian adds normal noise, probit returns
+ `1{eta + N(0,1) >= 0}`, Tobit clips `eta + N(0, sigma^2)` to its bounds.
+ lowers: optional `(n_meas,)` Tobit lower bounds (`-inf` to disable).
+ uppers: optional `(n_meas,)` Tobit upper bounds (`+inf` to disable).
Returns:
measurements: array of shape (n_obs, n_meas) with measurements.
@@ -520,10 +549,65 @@ def measurements_from_states(
"""
n_meas = loadings.shape[0]
n_obs = len(states)
- epsilon = rng.multivariate_normal([0] * n_meas, np.diag(sds**2), n_obs)
- states_part = np.dot(states, loadings.T)
- control_part = np.dot(controls, control_params.T)
- return states_part + control_part + epsilon
+ eta = np.dot(states, loadings.T) + np.dot(controls, control_params.T)
+ if families is None:
+ epsilon = rng.multivariate_normal([0] * n_meas, np.diag(sds**2), n_obs)
+ return eta + epsilon
+ return _draw_by_family(rng, eta, np.asarray(sds), families, lowers, uppers)
+
+
+def _period_family_arrays(
+ measure_names: list[str],
+ measurement_families: Mapping[str, tuple[int, float, float]] | None,
+) -> tuple[
+ NDArray[np.integer] | None,
+ NDArray[np.floating] | None,
+ NDArray[np.floating] | None,
+]:
+ """Build aligned family/bound arrays for a period, or `None` if all Gaussian.
+
+ Returning `None` for an all-Gaussian period keeps `measurements_from_states` on
+ its original byte-identical Gaussian path.
+ """
+ if not measurement_families:
+ return None, None, None
+ gaussian = int(MeasurementFamily.GAUSSIAN)
+ resolved = [
+ measurement_families.get(name, (gaussian, -np.inf, np.inf))
+ for name in measure_names
+ ]
+ if all(code == gaussian for code, _lo, _hi in resolved):
+ return None, None, None
+ families = np.array([code for code, _lo, _hi in resolved], dtype=int)
+ lowers = np.array([lo for _code, lo, _hi in resolved], dtype=float)
+ uppers = np.array([hi for _code, _lo, hi in resolved], dtype=float)
+ return families, lowers, uppers
+
+
+def _draw_by_family(
+ rng: np.random.Generator,
+ eta: NDArray[np.floating],
+ sds: NDArray[np.floating],
+ families: NDArray[np.integer],
+ lowers: NDArray[np.floating] | None,
+ uppers: NDArray[np.floating] | None,
+) -> NDArray[np.floating]:
+ """Draw each measurement column according to its `MeasurementFamily`."""
+ n_obs, n_meas = eta.shape
+ out = np.empty_like(eta)
+ for j in range(n_meas):
+ family = int(families[j])
+ if family == int(MeasurementFamily.PROBIT):
+ latent = eta[:, j] + rng.standard_normal(n_obs)
+ out[:, j] = (latent >= 0.0).astype(float)
+ elif family == int(MeasurementFamily.TOBIT):
+ latent = eta[:, j] + sds[j] * rng.standard_normal(n_obs)
+ lower = -np.inf if lowers is None else lowers[j]
+ upper = np.inf if uppers is None else uppers[j]
+ out[:, j] = np.clip(latent, lower, upper)
+ else:
+ out[:, j] = eta[:, j] + sds[j] * rng.standard_normal(n_obs)
+ return out
@beartype(conf=SIMULATION_CONF)
diff --git a/tests/test_af_adapter_alignment.py b/tests/test_af_adapter_alignment.py
new file mode 100644
index 00000000..1f73976c
--- /dev/null
+++ b/tests/test_af_adapter_alignment.py
@@ -0,0 +1,51 @@
+"""The AF calendar adapter aligns the whole per-observation payload by individual ID.
+
+The adapter sources mixed-calendar measurement columns on a sorted individual-ID
+intersection, while controls, observed factors, the period-0 conditional distribution,
+and the chain-link payloads are read positionally in input-row order. Unless every
+per-observation array shares one canonical order, a measurement row for one person is
+paired with another person's controls or latent payload (a silent point-estimate change
+on a reordered or unbalanced panel).
+
+`_align_adapter_panel` canonicalises the panel to one ID-sorted order on the adapter
+path, and requires the panel to be balanced (AF aligns periods positionally), raising
+otherwise.
+"""
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from skillmodels.af.estimate import _align_adapter_panel
+
+
+def _panel(ids_per_period: dict[int, list[int]]) -> pd.DataFrame:
+ rows = []
+ for period, ids in ids_per_period.items():
+ for i in ids:
+ rows.append({"id": i, "period": period, "y": float(10 * period + i)})
+ return pd.DataFrame(rows).set_index(["id", "period"])
+
+
+def test_align_adapter_panel_canonicalises_shuffled_order() -> None:
+ rng = np.random.default_rng(0)
+ balanced = _panel({0: [1, 2, 3], 1: [1, 2, 3]})
+ shuffled = balanced.sample(frac=1.0, random_state=np.random.RandomState(1))
+ assert shuffled.index.tolist() != balanced.sort_index().index.tolist()
+
+ aligned = _align_adapter_panel(shuffled, n_periods=2)
+
+ # Canonical (id, period) order regardless of input row order.
+ assert aligned.index.tolist() == balanced.sort_index().index.tolist()
+ # Within each period the rows are id-ascending, so the period slices align.
+ for period in (0, 1):
+ slice_ids = aligned.xs(period, level="period").index.tolist()
+ assert slice_ids == sorted(slice_ids)
+ _ = rng
+
+
+def test_align_adapter_panel_rejects_unbalanced() -> None:
+ # Individual 3 is missing from period 1: AF cannot positionally align the periods.
+ unbalanced = _panel({0: [1, 2, 3], 1: [1, 2]})
+ with pytest.raises(ValueError, match="balanced"):
+ _align_adapter_panel(unbalanced, n_periods=2)
diff --git a/tests/test_af_adapter_inference_guard.py b/tests/test_af_adapter_inference_guard.py
new file mode 100644
index 00000000..28063f12
--- /dev/null
+++ b/tests/test_af_adapter_inference_guard.py
@@ -0,0 +1,103 @@
+"""AF standard errors still reject calendar-adapter models for now.
+
+`compute_af_standard_errors` still reconstructs the single-period measurement layout,
+which is wrong for the source/destination calendar adapter (source-investment params
+live in the next step's result and the per-step parser counts differ). Until it consumes
+the compiled `AFStepLayout`, it must raise on an adapter model rather than return
+standard errors that differentiate the wrong objective. (`get_af_posterior_states` now
+supports adapter models: it scores each state factor against its own correctly-parsed
+indicators and drops the mis-sourced reconstructed-investment indicators -- see
+`test_af_posterior_states.py`.)
+"""
+
+import pandas as pd
+import pytest
+
+from skillmodels.af.inference import compute_af_standard_errors
+from skillmodels.af.step_layout import (
+ fail_if_calendar_adapter_unsupported,
+ fail_if_spearman_unsupported_on_adapter,
+ model_uses_calendar_adapter,
+)
+from skillmodels.af.types import AFEstimationResult
+from skillmodels.common.model_spec import FactorSpec, ModelSpec, Normalizations
+
+
+def _plain_model() -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1",), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ },
+ )
+
+
+def _adapter_model() -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1",), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ "investment": FactorSpec(
+ measurements=(("i1",), ("i1",)),
+ normalizations=Normalizations(
+ loadings=({"i1": 1}, {"i1": 1}), intercepts=({"i1": 0}, {"i1": 0})
+ ),
+ transition_function="linear",
+ is_endogenous=True,
+ has_initial_distribution=False,
+ ),
+ },
+ )
+
+
+def test_model_uses_calendar_adapter_detects_reconstructed_endogenous() -> None:
+ assert model_uses_calendar_adapter(_adapter_model()) is True
+ assert model_uses_calendar_adapter(_plain_model()) is False
+
+
+def test_raiser_is_noop_for_plain_model() -> None:
+ fail_if_calendar_adapter_unsupported(_plain_model(), "feature")
+
+
+def test_raiser_rejects_adapter_model() -> None:
+ with pytest.raises(NotImplementedError, match="calendar adapter"):
+ fail_if_calendar_adapter_unsupported(_adapter_model(), "feature")
+
+
+def _adapter_result() -> AFEstimationResult:
+ return AFEstimationResult(
+ period_results=(),
+ params=pd.DataFrame(),
+ model_spec=_adapter_model(),
+ conditional_distributions=(),
+ success=True,
+ loglikelihood=0.0,
+ )
+
+
+def test_standard_errors_reject_adapter_model() -> None:
+ with pytest.raises(NotImplementedError, match="calendar adapter"):
+ compute_af_standard_errors(_adapter_result(), pd.DataFrame())
+
+
+def test_spearman_guard_rejects_adapter_model() -> None:
+ with pytest.raises(NotImplementedError, match="spearman"):
+ fail_if_spearman_unsupported_on_adapter(_adapter_model(), "spearman")
+
+
+def test_spearman_guard_noop_for_constant_strategy_on_adapter() -> None:
+ fail_if_spearman_unsupported_on_adapter(_adapter_model(), "constant")
+
+
+def test_spearman_guard_noop_for_spearman_on_plain_model() -> None:
+ fail_if_spearman_unsupported_on_adapter(_plain_model(), "spearman")
diff --git a/tests/test_af_estimate.py b/tests/test_af_estimate.py
index 3f2be261..cd7b0a81 100644
--- a/tests/test_af_estimate.py
+++ b/tests/test_af_estimate.py
@@ -14,6 +14,8 @@
import pandas as pd
import pytest
+import skillmodels.af.estimate as est
+import skillmodels.af.initial_period as ip
from skillmodels.af import AFEstimationOptions, estimate_af
from skillmodels.af.likelihood import (
_rebuild_chain_at_period,
@@ -38,6 +40,81 @@
REGRESSION_VAULT = Path(__file__).parent / "regression_vault"
+def test_af_options_bounds_distance_field() -> None:
+ """`bounds_distance` is configurable and defaults to the 0.001 SD floor."""
+ assert AFEstimationOptions().bounds_distance == 0.001
+ assert AFEstimationOptions(bounds_distance=0.01).bounds_distance == 0.01
+
+
+class _StopForTest(Exception): # noqa: N818
+ """Sentinel raised by the spy to abort estimation before optimization."""
+
+
+def test_af_options_bounds_distance_threads_to_template(
+ monkeypatch, model2_af, model2_data
+) -> None:
+ """`AFEstimationOptions.bounds_distance` reaches `create_af_params_template`.
+
+ Spy on the template builder, capture the `bounds_distance` it is called with,
+ and raise before any optimization runs (the template is built first), so the
+ test stays cheap.
+ """
+ captured: dict[str, float] = {}
+
+ def _spy(*args, **kwargs):
+ captured["bounds_distance"] = kwargs["bounds_distance"]
+ raise _StopForTest
+
+ monkeypatch.setattr(ip, "create_af_params_template", _spy)
+ with pytest.raises(_StopForTest):
+ estimate_af(
+ model_spec=model2_af,
+ data=model2_data,
+ options=AFEstimationOptions(
+ n_halton_points=10,
+ n_halton_points_shock=10,
+ start_params_strategy="constant",
+ bounds_distance=0.01,
+ ),
+ )
+ assert captured["bounds_distance"] == 0.01
+
+
+def test_af_options_bounds_distance_survives_amn_start(
+ monkeypatch, model2_af, model2_data
+) -> None:
+ """`bounds_distance` survives the af_options rebuild on the AMN start path.
+
+ With `start_params_strategy="amn"`, `estimate_af` reconstructs the options
+ after running AMN; `bounds_distance` must carry through to the per-period
+ template builder, not silently revert to the default.
+ """
+
+ class _FakeAMN:
+ params = pd.DataFrame({"value": []})
+
+ monkeypatch.setattr(est, "estimate_amn", lambda **_kwargs: _FakeAMN())
+ captured: dict[str, float] = {}
+
+ def _spy(*args, **kwargs):
+ captured["bounds_distance"] = kwargs["bounds_distance"]
+ raise _StopForTest
+
+ monkeypatch.setattr(ip, "create_af_params_template", _spy)
+ with pytest.raises(_StopForTest):
+ estimate_af(
+ model_spec=model2_af,
+ data=model2_data,
+ options=AFEstimationOptions(
+ n_halton_points=10,
+ n_halton_points_shock=10,
+ start_params_strategy="amn",
+ bounds_distance=0.01,
+ ),
+ )
+ assert captured["bounds_distance"] == 0.01
+
+
@pytest.fixture
def model2_data():
"""Load the MODEL2 simulated dataset."""
diff --git a/tests/test_af_layout_guards.py b/tests/test_af_layout_guards.py
new file mode 100644
index 00000000..b2d846da
--- /dev/null
+++ b/tests/test_af_layout_guards.py
@@ -0,0 +1,131 @@
+"""Guards that make the AF calendar adapter fail loudly on unsupported models.
+
+The shared integrand assembles the latent vector as all dynamic-state factors followed
+by all reconstructed-endogenous factors, and the adapter is only defined for endogenous
+factors that are reconstructed (no initial distribution). Two configurations would
+otherwise be estimated silently against the wrong model:
+
+- an endogenous factor that still carries an initial distribution taking the
+ source-investment calendar (it is not a reconstructed investment);
+- a public factor order that interleaves an endogenous factor before a dynamic-state
+ factor, which transposes loading columns relative to the assembled latent vector.
+
+Both must raise rather than estimate the wrong model.
+"""
+
+import pytest
+
+from skillmodels.af.step_layout import AFFactorRole
+from skillmodels.af.transition_period import (
+ _factor_infos_from_spec,
+ _fail_if_endogenous_precedes_state,
+ _fail_if_unsupported_adapter_measurements,
+)
+from skillmodels.common.measurement_models import ProbitMeasurement
+from skillmodels.common.model_spec import FactorSpec, ModelSpec, Normalizations
+
+
+def _spec(*, has_initial_distribution: bool) -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1",), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ "investment": FactorSpec(
+ measurements=(("i1",), ("i1",)),
+ normalizations=Normalizations(
+ loadings=({"i1": 1}, {"i1": 1}), intercepts=({"i1": 0}, {"i1": 0})
+ ),
+ transition_function="linear",
+ is_endogenous=True,
+ has_initial_distribution=has_initial_distribution,
+ ),
+ },
+ )
+
+
+def test_factor_infos_marks_reconstructed_endogenous() -> None:
+ infos = _factor_infos_from_spec(
+ _spec(has_initial_distribution=False), endogenous_factors=("investment",)
+ )
+ roles = {info.name: info.role for info in infos}
+ assert roles["investment"] == AFFactorRole.ENDOGENOUS
+ assert roles["skills"] == AFFactorRole.DYNAMIC
+
+
+def test_factor_infos_rejects_endogenous_with_initial_distribution() -> None:
+ with pytest.raises(ValueError, match="has_initial_distribution"):
+ _factor_infos_from_spec(
+ _spec(has_initial_distribution=True), endogenous_factors=("investment",)
+ )
+
+
+def test_fail_if_endogenous_precedes_state_accepts_state_first() -> None:
+ # State before endogenous: matches the assembled [theta_states, inv_endog] order.
+ _fail_if_endogenous_precedes_state(("skills", "investment"), ("investment",))
+
+
+def test_fail_if_endogenous_precedes_state_rejects_interleaving() -> None:
+ with pytest.raises(ValueError, match="precede"):
+ _fail_if_endogenous_precedes_state(("investment", "skills"), ("investment",))
+
+
+def test_adapter_measurements_accept_plain_gaussian_non_crossloaded() -> None:
+ # The clean CNLSY-style model (Gaussian, one factor per measurement) is accepted.
+ _fail_if_unsupported_adapter_measurements(_spec(has_initial_distribution=False))
+
+
+def test_adapter_measurements_reject_cross_loaded_measurement() -> None:
+ # "shared" loads on both skills and investment -> not yet supported by the adapter.
+ model = ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1", "shared"), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ "investment": FactorSpec(
+ measurements=(("i1", "shared"), ("i1",)),
+ normalizations=Normalizations(
+ loadings=({"i1": 1}, {"i1": 1}), intercepts=({"i1": 0}, {"i1": 0})
+ ),
+ transition_function="linear",
+ is_endogenous=True,
+ has_initial_distribution=False,
+ ),
+ },
+ )
+ with pytest.raises(ValueError, match="cross-load"):
+ _fail_if_unsupported_adapter_measurements(model)
+
+
+def test_adapter_measurements_reject_non_gaussian_family() -> None:
+ model = ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1",), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ "investment": FactorSpec(
+ measurements=(("i1",), ("i1",)),
+ normalizations=Normalizations(
+ loadings=({"i1": 1}, {"i1": 1}), intercepts=({"i1": 0}, {"i1": 0})
+ ),
+ transition_function="linear",
+ is_endogenous=True,
+ has_initial_distribution=False,
+ ),
+ },
+ measurement_models={"i1": ProbitMeasurement()},
+ )
+ with pytest.raises(ValueError, match="Gaussian"):
+ _fail_if_unsupported_adapter_measurements(model)
diff --git a/tests/test_af_limited_measurements.py b/tests/test_af_limited_measurements.py
new file mode 100644
index 00000000..ea7a973b
--- /dev/null
+++ b/tests/test_af_limited_measurements.py
@@ -0,0 +1,142 @@
+"""AF probit/Tobit measurement support: hand-computed likelihood acceptance tests.
+
+With fixed Halton nodes, replacing a Gaussian measure by a probit (or Tobit)
+measure must equal a manually node-weighted likelihood -- there is no extra
+approximation beyond the existing Halton integration. These tests pin that against
+plain-numpy/scipy references.
+"""
+
+import jax.numpy as jnp
+import numpy as np
+from scipy.stats import norm
+
+from skillmodels.af.likelihood import af_per_obs_loglike_initial
+from skillmodels.common.measurement_models import MeasurementFamily
+
+_INF = float("inf")
+
+
+def test_initial_probit_measure_matches_node_weighted_reference() -> None:
+ """A binary probit measure at period 0 equals a hand node-weighted likelihood.
+
+ Single latent factor, single mixture component, no observed factors (so the
+ unconditional integrand runs). One Gaussian measure plus one probit measure;
+ the probit contributes ``log Phi((2y-1) eta)`` at each Halton node.
+ """
+ n_factors = 1
+ n_latent = 1
+ n_components = 1
+ n_measures = 2
+ n_controls = 1
+
+ mu = 0.5
+ chol = 1.1
+ control_params = [0.1, -0.2]
+ loadings = [1.0, 0.8]
+ meas_sds = [0.5, 1.0] # probit sd is ignored by the kernel
+
+ params = jnp.array(
+ [1.0, mu, chol, *control_params, *loadings, *meas_sds],
+ )
+ loading_mask = jnp.array([[True], [True]])
+
+ n_obs = 4
+ rng = np.random.default_rng(7)
+ y_gauss = rng.normal(0, 1, n_obs)
+ y_probit = (rng.random(n_obs) < 0.5).astype(float)
+ measurements = jnp.asarray(np.column_stack([y_gauss, y_probit]))
+ controls = jnp.asarray(rng.normal(0, 1, (n_obs, n_controls)))
+
+ raw_nodes = np.array([-1.5, -0.5, 0.5, 1.5])
+ node_w = np.exp(-0.5 * raw_nodes**2)
+ node_w = node_w / node_w.sum()
+ nodes = jnp.asarray(raw_nodes.reshape(-1, 1))
+ weights = jnp.asarray(node_w)
+
+ families = jnp.array(
+ [int(MeasurementFamily.GAUSSIAN), int(MeasurementFamily.PROBIT)]
+ )
+ lowers = jnp.array([-_INF, -_INF])
+ uppers = jnp.array([_INF, _INF])
+
+ per_obs = np.asarray(
+ af_per_obs_loglike_initial(
+ params,
+ n_factors=n_factors,
+ n_mixture_components=n_components,
+ n_measures=n_measures,
+ n_controls=n_controls,
+ measurements=measurements,
+ controls=controls,
+ loading_mask=loading_mask,
+ nodes=nodes,
+ weights=weights,
+ stability_floor=0.0,
+ n_latent_factors=n_latent,
+ measurement_families=families,
+ measurement_lowers=lowers,
+ measurement_uppers=uppers,
+ )
+ )
+
+ control_arr = np.array(control_params).reshape(n_measures, n_controls)
+ load = np.array(loadings)
+ raw = np.asarray(raw_nodes)
+ expected = np.empty(n_obs)
+ for i in range(n_obs):
+ ctrl_i = np.asarray(controls[i])
+ eta_const = control_arr @ ctrl_i # (n_measures,)
+ node_contrib = np.empty(len(raw))
+ for q, z in enumerate(raw):
+ theta = mu + chol * z
+ eta = eta_const + load * theta
+ logp_g = norm.logpdf(y_gauss[i], loc=eta[0], scale=meas_sds[0])
+ logp_p = norm.logcdf((2.0 * y_probit[i] - 1.0) * eta[1])
+ node_contrib[q] = np.exp(logp_g + logp_p)
+ expected[i] = np.log(np.dot(node_w, node_contrib))
+
+ np.testing.assert_allclose(per_obs, expected, rtol=1e-6, atol=1e-9)
+
+
+def test_initial_all_gaussian_families_match_default_path() -> None:
+ """Passing explicit all-Gaussian families equals omitting them (parity)."""
+ n_measures = 2
+ params = jnp.array([1.0, 0.3, 1.0, 0.0, 0.0, 1.0, 0.7, 0.5, 0.6])
+ loading_mask = jnp.array([[True], [True]])
+ rng = np.random.default_rng(1)
+ measurements = jnp.asarray(rng.normal(0, 1, (5, n_measures)))
+ controls = jnp.asarray(np.ones((5, 1)))
+ raw = np.array([-1.0, 0.0, 1.0])
+ w = np.exp(-0.5 * raw**2)
+ w = w / w.sum()
+ nodes = jnp.asarray(raw.reshape(-1, 1))
+ weights = jnp.asarray(w)
+
+ def _call(families, lowers, uppers):
+ return np.asarray(
+ af_per_obs_loglike_initial(
+ params,
+ n_factors=1,
+ n_mixture_components=1,
+ n_measures=n_measures,
+ n_controls=1,
+ measurements=measurements,
+ controls=controls,
+ loading_mask=loading_mask,
+ nodes=nodes,
+ weights=weights,
+ stability_floor=0.0,
+ n_latent_factors=1,
+ measurement_families=families,
+ measurement_lowers=lowers,
+ measurement_uppers=uppers,
+ )
+ )
+
+ default = _call(None, None, None)
+ explicit = _call(
+ jnp.array([int(MeasurementFamily.GAUSSIAN)] * n_measures),
+ jnp.array([-_INF] * n_measures),
+ jnp.array([_INF] * n_measures),
+ )
+ np.testing.assert_allclose(default, explicit, rtol=1e-12, atol=1e-12)
diff --git a/tests/test_af_posterior_states.py b/tests/test_af_posterior_states.py
index 3ac6d986..6f4ca971 100644
--- a/tests/test_af_posterior_states.py
+++ b/tests/test_af_posterior_states.py
@@ -9,17 +9,124 @@
import jax
import jax.numpy as jnp
import numpy as np
+import pandas as pd
import pytest
-from skillmodels.af.posterior_states import _compute_posterior_means
+from skillmodels.af import AFEstimationOptions, estimate_af
+from skillmodels.af.posterior_states import (
+ _compute_posterior_means,
+ get_af_posterior_states,
+)
from skillmodels.af.types import (
ConditionalDistribution,
MixtureComponent,
)
+from skillmodels.common.model_spec import FactorSpec, ModelSpec, Normalizations
jax.config.update("jax_enable_x64", True)
+def _estimate_reconstructed_endogenous_model() -> tuple:
+ """Estimate a small reconstructed-endogenous (calendar-adapter) AF model.
+
+ Investment is endogenous with `has_initial_distribution=False`, so the
+ source/destination calendar adapter is active and `get_af_posterior_states`
+ must consume the compiled layout rather than the single-period reconstruction.
+ Returns `(af_result, model, data)`.
+ """
+ rng = np.random.default_rng(20240617)
+ n_obs, n_periods = 250, 3
+ theta = np.zeros((n_obs, n_periods))
+ inv = np.zeros((n_obs, n_periods))
+ income = rng.normal(1.0, 0.5, n_obs)
+ theta[:, 0] = rng.normal(0, 1, n_obs)
+ inv[:, 0] = 0.5 * theta[:, 0] + 0.2 * income + rng.normal(0, 0.25, n_obs)
+ for t in range(n_periods - 1):
+ theta[:, t + 1] = (
+ 0.05 + 0.6 * theta[:, t] + 0.3 * inv[:, t] + rng.normal(0, 0.3, n_obs)
+ )
+ inv[:, t + 1] = (
+ 0.5 * theta[:, t + 1] + 0.2 * income + rng.normal(0, 0.25, n_obs)
+ )
+
+ rows = []
+ for i in range(n_obs):
+ for t in range(n_periods):
+ row = {
+ "caseid": i,
+ "period": t,
+ "s1": theta[i, t] + rng.normal(0, 0.3),
+ "s2": 0.3 + 0.8 * theta[i, t] + rng.normal(0, 0.35),
+ "s3": -0.1 + 1.1 * theta[i, t] + rng.normal(0, 0.4),
+ "income": income[i],
+ }
+ # Investment is measured only at the SOURCE periods (0, 1); it is
+ # reconstructed at the terminal period, so no terminal indicators.
+ if t < n_periods - 1:
+ row["i1"] = inv[i, t] + rng.normal(0, 0.3)
+ row["i2"] = 0.2 + 0.9 * inv[i, t] + rng.normal(0, 0.35)
+ row["i3"] = -0.1 + 1.2 * inv[i, t] + rng.normal(0, 0.4)
+ else:
+ row["i1"] = row["i2"] = row["i3"] = np.nan
+ rows.append(row)
+ data = pd.DataFrame(rows).set_index(["caseid", "period"])
+
+ inv_meas = (("i1", "i2", "i3"), ("i1", "i2", "i3"), ())
+ model = ModelSpec(
+ factors={
+ "skill": FactorSpec(
+ measurements=(("s1", "s2", "s3"),) * n_periods,
+ normalizations=Normalizations(
+ loadings=({"s1": 1},) * n_periods,
+ intercepts=({"s1": 0},) * n_periods,
+ ),
+ transition_function="linear",
+ ),
+ "investment": FactorSpec(
+ measurements=inv_meas,
+ normalizations=Normalizations(
+ loadings=({"i1": 1}, {"i1": 1}, {}),
+ intercepts=({"i1": 0}, {"i1": 0}, {}),
+ ),
+ transition_function="linear",
+ is_endogenous=True,
+ has_initial_distribution=False,
+ ),
+ },
+ observed_factors=("income",),
+ )
+ af_result = estimate_af(
+ model_spec=model,
+ data=data,
+ options=AFEstimationOptions(
+ n_halton_points=30,
+ n_halton_points_shock=15,
+ optimizer_algorithm="scipy_lbfgsb",
+ ),
+ )
+ return af_result, model, data
+
+
+def test_get_af_posterior_states_supports_calendar_adapter() -> None:
+ # An adapter model (reconstructed-endogenous investment) must yield posterior
+ # state means for the state factors, not raise the deferred-feature guard.
+ af_result, model, data = _estimate_reconstructed_endogenous_model()
+
+ out = get_af_posterior_states(af_result=af_result, model_spec=model, data=data)
+
+ states = out["unanchored_states"]["states"]
+ # State factors only (skill); the endogenous investment is not a state coordinate.
+ assert "skill" in states.columns
+ assert "investment" not in states.columns
+ assert np.isfinite(states["skill"].to_numpy()).all()
+
+ # The period-0 posterior skill mean must track the clean skill signal s1.
+ p0 = states[states["period"] == 0].set_index("caseid")["skill"]
+ s1_0 = data.xs(0, level="period")["s1"].reindex(p0.index)
+ corr = np.corrcoef(p0.to_numpy(), s1_0.to_numpy())[0, 1]
+ assert corr > 0.5, f"period-0 skill posterior should track s1 (corr={corr:.2f})"
+
+
def _placeholder_cond_dist(
*,
samples_per_component: tuple,
diff --git a/tests/test_af_sequential_criterion.py b/tests/test_af_sequential_criterion.py
new file mode 100644
index 00000000..edc694f4
--- /dev/null
+++ b/tests/test_af_sequential_criterion.py
@@ -0,0 +1,63 @@
+"""The AF result exposes its objective under an honest name.
+
+The per-period AF objective is a per-observation *mean* log-likelihood criterion, and
+the aggregate is their sum -- a sequential/composite criterion, not a joint sample
+log-likelihood. `AFEstimationResult` exposes that value as `sequential_criterion` with a
+per-period breakdown in `period_mean_criteria`, while `loglikelihood` is retained as an
+equal-valued alias for protocol conformance and back-compat.
+"""
+
+import pandas as pd
+
+from skillmodels.af.types import AFEstimationResult, AFPeriodResult
+from skillmodels.common.model_spec import FactorSpec, ModelSpec, Normalizations
+
+
+def _model() -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("s1",), ("s1",)),
+ normalizations=Normalizations(
+ loadings=({"s1": 1}, {"s1": 1}), intercepts=({"s1": 0}, {"s1": 0})
+ ),
+ transition_function="linear",
+ ),
+ },
+ )
+
+
+def test_sequential_criterion_aggregates_period_means() -> None:
+ period_results = (
+ AFPeriodResult(
+ period=0,
+ params=pd.DataFrame(),
+ loglikelihood=-1.5,
+ success=True,
+ optimize_result=None,
+ ),
+ AFPeriodResult(
+ period=1,
+ params=pd.DataFrame(),
+ loglikelihood=-2.0,
+ success=True,
+ optimize_result=None,
+ ),
+ )
+ means = tuple(pr.loglikelihood for pr in period_results)
+ total = sum(means)
+ result = AFEstimationResult(
+ period_results=period_results,
+ params=pd.DataFrame(),
+ model_spec=_model(),
+ conditional_distributions=(),
+ success=True,
+ loglikelihood=total,
+ sequential_criterion=total,
+ period_mean_criteria=means,
+ )
+
+ assert result.sequential_criterion == total
+ assert result.loglikelihood == result.sequential_criterion
+ assert result.period_mean_criteria == means
+ assert sum(result.period_mean_criteria) == result.sequential_criterion
diff --git a/tests/test_af_step_assembly.py b/tests/test_af_step_assembly.py
new file mode 100644
index 00000000..74c25d84
--- /dev/null
+++ b/tests/test_af_step_assembly.py
@@ -0,0 +1,195 @@
+"""Tests for assembling an AF step's target + importance arrays from a layout.
+
+This is the heart of the calendar fix: for step `s -> d`, the target block must source
+destination skills from period `d` and source investment from period `s` (so age-7
+indicators measure I_0), and the importance block must include every static-persistent
+factor's period-0 measurement rows at *every* step (the dropped-MC/MN fix), with their
+fixed params pulled from the cumulative `HistoricalParams` (period-0 result), not the
+immediately-previous step.
+"""
+
+import numpy as np
+import pandas as pd
+
+from skillmodels.af.step_assembly import assemble_step_arrays
+from skillmodels.af.step_layout import (
+ AFFactorInfo,
+ AFFactorRole,
+ HistoricalParams,
+ compile_af_step_layouts,
+)
+
+
+def _layouts():
+ return compile_af_step_layouts(
+ (
+ AFFactorInfo("skills", AFFactorRole.DYNAMIC, (("sk",), ("sk",), ("sk",))),
+ AFFactorInfo("MC", AFFactorRole.STATIC_PERSISTENT, (("mc",), (), ())),
+ AFFactorInfo(
+ "investment", AFFactorRole.ENDOGENOUS, (("inv",), ("inv",), ())
+ ),
+ ),
+ n_periods=3,
+ )
+
+
+def _frames():
+ # age-7 investment lives at period 0; age-9 at period 1. MC measured at period 0.
+ return {
+ 0: pd.DataFrame(
+ {"sk": [1.0, 2.0], "mc": [7.0, 8.0], "inv": [0.1, 0.2]}, index=[1, 2]
+ ),
+ 1: pd.DataFrame({"sk": [3.0, 4.0], "inv": [0.3, 0.4]}, index=[1, 2]),
+ 2: pd.DataFrame({"sk": [5.0, 6.0]}, index=[1, 2]),
+ }
+
+
+def _historical():
+ rows = [
+ ("loadings", 0, "mc", "MC", 1.5),
+ ("controls", 0, "mc", "constant", 0.0),
+ ("meas_sds", 0, "mc", "-", 0.4),
+ ("loadings", 0, "sk", "skills", 1.0),
+ ("controls", 0, "sk", "constant", 0.0),
+ ("meas_sds", 0, "sk", "-", 0.5),
+ ("loadings", 1, "sk", "skills", 1.0),
+ ("controls", 1, "sk", "constant", 0.0),
+ ("meas_sds", 1, "sk", "-", 0.5),
+ ]
+ idx = pd.MultiIndex.from_tuples(
+ [(r[0], r[1], r[2], r[3]) for r in rows],
+ names=["category", "period", "name1", "name2"],
+ )
+ return HistoricalParams(pd.DataFrame({"value": [r[4] for r in rows]}, index=idx))
+
+
+def test_step_0_to_1_target_sources_investment_from_period0() -> None:
+ arrays = assemble_step_arrays(
+ _layouts()[0],
+ _frames(),
+ ("skills", "MC", "investment"),
+ _historical(),
+ ("constant",),
+ )
+ # Target columns: skills@1 (=[3,4]) and investment@0 (age-7, =[0.1,0.2]).
+ cols = {name: arrays.target_measurements[:, j] for j, name in arrays.target_order}
+ np.testing.assert_array_equal(cols[("sk", 1)], [3.0, 4.0])
+ np.testing.assert_array_equal(cols[("inv", 0)], [0.1, 0.2])
+
+
+def test_step_1_to_2_importance_includes_static_mc_from_period0() -> None:
+ arrays = assemble_step_arrays(
+ _layouts()[1],
+ _frames(),
+ ("skills", "MC", "investment"),
+ _historical(),
+ ("constant",),
+ )
+ imp = {name: arrays.importance_measurements[:, j] for j, name in arrays.imp_order}
+ # F3: MC_0 data is present in the 1->2 importance block (and source skills@1).
+ np.testing.assert_array_equal(imp[("mc", 0)], [7.0, 8.0])
+ np.testing.assert_array_equal(imp[("sk", 1)], [3.0, 4.0])
+ # MC's fixed loading/SD come from the period-0 history.
+ mc_row = [n for _, n in arrays.imp_order].index(("mc", 0))
+ assert arrays.importance_loadings_flat[mc_row] == 1.5
+ assert arrays.importance_meas_sds[mc_row] == 0.4
+
+
+def _frames_with_control():
+ # A non-constant measurement control `x` whose value differs by period, so a row
+ # sourced from the wrong period is detectable.
+ return {
+ 0: pd.DataFrame(
+ {"sk": [1.0, 2.0], "mc": [7.0, 8.0], "inv": [0.1, 0.2], "x": [10.0, 20.0]},
+ index=[1, 2],
+ ),
+ 1: pd.DataFrame(
+ {"sk": [3.0, 4.0], "inv": [0.3, 0.4], "x": [100.0, 200.0]}, index=[1, 2]
+ ),
+ 2: pd.DataFrame({"sk": [5.0, 6.0], "x": [1000.0, 2000.0]}, index=[1, 2]),
+ }
+
+
+def _historical_with_control():
+ rows = [
+ ("loadings", 0, "mc", "MC", 1.5),
+ ("controls", 0, "mc", "x", 0.5),
+ ("meas_sds", 0, "mc", "-", 0.4),
+ ("loadings", 1, "sk", "skills", 1.0),
+ ("controls", 1, "sk", "x", 2.0),
+ ("meas_sds", 1, "sk", "-", 0.5),
+ ]
+ idx = pd.MultiIndex.from_tuples(
+ [(r[0], r[1], r[2], r[3]) for r in rows],
+ names=["category", "period", "name1", "name2"],
+ )
+ return HistoricalParams(pd.DataFrame({"value": [r[4] for r in rows]}, index=idx))
+
+
+def test_importance_control_contrib_sources_each_row_from_its_control_period() -> None:
+ # 1->2 importance block: MC@0 (control_period 0) + source skills@1 (period 1).
+ # Each row's control contribution must use that row's OWN period's control data and
+ # fixed params, not one shared source-period matrix.
+ arrays = assemble_step_arrays(
+ _layouts()[1],
+ _frames_with_control(),
+ ("skills", "MC", "investment"),
+ _historical_with_control(),
+ ("x",),
+ )
+ rows = [n for _, n in arrays.imp_order]
+ mc_row = rows.index(("mc", 0))
+ sk_row = rows.index(("sk", 1))
+ # MC_0: x@period0 ([10,20]) * 0.5; skills_1: x@period1 ([100,200]) * 2.0.
+ np.testing.assert_allclose(
+ arrays.importance_control_contrib[:, mc_row], [5.0, 10.0]
+ )
+ np.testing.assert_allclose(
+ arrays.importance_control_contrib[:, sk_row], [200.0, 400.0]
+ )
+
+
+def _frames_with_string_ids():
+ # Same data as `_frames`, but individuals keyed by non-integer (string) IDs.
+ return {
+ 0: pd.DataFrame(
+ {"sk": [1.0, 2.0], "mc": [7.0, 8.0], "inv": [0.1, 0.2]}, index=["a", "b"]
+ ),
+ 1: pd.DataFrame({"sk": [3.0, 4.0], "inv": [0.3, 0.4]}, index=["a", "b"]),
+ 2: pd.DataFrame({"sk": [5.0, 6.0]}, index=["a", "b"]),
+ }
+
+
+def test_assemble_preserves_non_integer_ids() -> None:
+ # The adapter's contract is ID-indexed alignment, not numeric IDs: string/UUID
+ # individual identifiers must align rather than be coerced or rejected.
+ arrays = assemble_step_arrays(
+ _layouts()[0],
+ _frames_with_string_ids(),
+ ("skills", "MC", "investment"),
+ _historical(),
+ ("constant",),
+ )
+ assert list(arrays.ids) == ["a", "b"]
+ cols = {name: arrays.target_measurements[:, j] for j, name in arrays.target_order}
+ np.testing.assert_array_equal(cols[("inv", 0)], [0.1, 0.2])
+
+
+def test_target_controls_source_each_row_from_its_control_period() -> None:
+ # 1->2 target block: skills@2 (control_period 2) + source investment@1 (period 1).
+ # The per-row control-data tensor must read each row's own period's control values.
+ arrays = assemble_step_arrays(
+ _layouts()[1],
+ _frames_with_control(),
+ ("skills", "MC", "investment"),
+ _historical_with_control(),
+ ("x",),
+ )
+ order = [n for _, n in arrays.target_order]
+ sk_col = order.index(("sk", 2))
+ inv_col = order.index(("inv", 1))
+ # target_controls has shape (n_ids, n_target, n_controls); single control `x`.
+ # skills@2 reads x from period 2 ([1000,2000]); investment@1 reads x from period 1
+ # ([100,200]) -- each row sourced from its own control_period.
+ np.testing.assert_allclose(arrays.target_controls[:, sk_col, 0], [1000.0, 2000.0])
+ np.testing.assert_allclose(arrays.target_controls[:, inv_col, 0], [100.0, 200.0])
diff --git a/tests/test_af_step_compile.py b/tests/test_af_step_compile.py
new file mode 100644
index 00000000..75a81c46
--- /dev/null
+++ b/tests/test_af_step_compile.py
@@ -0,0 +1,89 @@
+"""Tests for AF step param-index compilation and the cumulative param registry.
+
+A mixed-calendar target block (destination skills at `d`, source investment at `s`)
+must still be emitted in the flat parser's global category order -- all controls, then
+all loadings, then all measurement SDs -- while each row keeps its true `param_period`.
+Importance terms read fixed values for period-0 static factors (MC/MN) that the
+immediately-previous step result does not contain, so a cumulative `HistoricalParams`
+registry keyed by the full MultiIndex is required.
+"""
+
+import pandas as pd
+
+from skillmodels.af.step_layout import (
+ AFFactorInfo,
+ AFFactorRole,
+ HistoricalParams,
+ compile_af_step_layouts,
+ compile_target_measurement_index,
+)
+
+
+def _cnlsy_factor_infos() -> tuple[AFFactorInfo, ...]:
+ return (
+ AFFactorInfo("skills", AFFactorRole.DYNAMIC, (("sk_a", "sk_b"),) * 3),
+ AFFactorInfo("MC", AFFactorRole.STATIC_PERSISTENT, (("mc_1", "mc_2"), (), ())),
+ AFFactorInfo("MN", AFFactorRole.STATIC_PERSISTENT, (("mn_1",), (), ())),
+ AFFactorInfo(
+ "investment",
+ AFFactorRole.ENDOGENOUS,
+ (("inv_a", "inv_b"), ("inv_a", "inv_b"), ()),
+ ),
+ )
+
+
+def _params_df(rows: list[tuple[str, int, str, str, float]]) -> pd.DataFrame:
+ idx = pd.MultiIndex.from_tuples(
+ [(r[0], r[1], r[2], r[3]) for r in rows],
+ names=["category", "period", "name1", "name2"],
+ )
+ return pd.DataFrame({"value": [r[4] for r in rows]}, index=idx)
+
+
+def test_target_measurement_index_uses_global_category_order() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[0]
+ index = compile_target_measurement_index(layout, controls=("constant",))
+ categories = [tup[0] for tup in index]
+ # 4 target measures (2 skills + 2 inv): all controls, then loadings, then sds.
+ assert categories == ["controls"] * 4 + ["loadings"] * 4 + ["meas_sds"] * 4
+
+
+def test_target_measurement_index_keeps_true_param_periods() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[0]
+ index = set(compile_target_measurement_index(layout, controls=("constant",)))
+ # destination skills indexed at period 1; source investment stays at period 0.
+ assert ("loadings", 1, "sk_a", "skills") in index
+ assert ("loadings", 0, "inv_a", "investment") in index
+ assert ("controls", 0, "inv_a", "constant") in index
+ assert ("meas_sds", 1, "sk_b", "-") in index
+
+
+def test_historical_params_reaches_period0_static_factor() -> None:
+ initial = _params_df(
+ [
+ ("loadings", 0, "mc_1", "MC", 1.3),
+ ("loadings", 0, "sk_a", "skills", 1.0),
+ ]
+ )
+ step01 = _params_df(
+ [
+ ("loadings", 1, "sk_a", "skills", 0.9),
+ ("loadings", 0, "inv_a", "investment", 0.7),
+ ]
+ )
+ hist = HistoricalParams.from_param_frames([initial, step01])
+ # MC_0 is reachable from the cumulative registry...
+ assert hist.value("loadings", 0, "mc_1", "MC") == 1.3
+ # ...but NOT from the immediately-previous step result alone (the AF1 defect).
+ assert ("loadings", 0, "mc_1", "MC") not in step01.index
+
+
+def test_historical_params_rejects_duplicate_index() -> None:
+ initial = _params_df([("loadings", 0, "sk_a", "skills", 1.0)])
+ dup = _params_df([("loadings", 0, "sk_a", "skills", 2.0)])
+ try:
+ HistoricalParams.from_param_frames([initial, dup])
+ except ValueError:
+ return
+ msg = "Expected duplicate-index ValueError"
+ raise AssertionError(msg)
diff --git a/tests/test_af_step_data.py b/tests/test_af_step_data.py
new file mode 100644
index 00000000..34fa62ec
--- /dev/null
+++ b/tests/test_af_step_data.py
@@ -0,0 +1,107 @@
+"""Tests for ID-aligned assembly of an AF step's measurement arrays.
+
+A mixed-calendar step block sources columns from different calendar periods
+(destination skills at `d`, source investment at `s`). Those per-period frames may have
+different individual orders and samples, so the assembler must join on individual ID
+rather than concatenate positionally (the AF2 defect), and source each term's column
+from its own `data_period`.
+"""
+
+import numpy as np
+import pandas as pd
+
+from skillmodels.af.step_data import (
+ build_block_loading_mask,
+ build_step_measurement_array,
+)
+from skillmodels.af.step_layout import (
+ AFFactorInfo,
+ AFFactorRole,
+ compile_af_step_layouts,
+)
+
+
+def _step_0_to_1_targets():
+ layout = compile_af_step_layouts(
+ (
+ AFFactorInfo("skills", AFFactorRole.DYNAMIC, (("sk_a",), ("sk_a",))),
+ AFFactorInfo("investment", AFFactorRole.ENDOGENOUS, (("inv_a",), ())),
+ ),
+ n_periods=2,
+ )[0]
+ return layout.target_terms()
+
+
+def test_step_measurement_array_preserves_non_integer_ids() -> None:
+ # The assembler joins on individual ID, which must hold for string/UUID IDs, not
+ # only CNLSY-style numeric case IDs.
+ frame0 = pd.DataFrame({"sk_a": [10.0, 11.0], "inv_a": [1.0, 2.0]}, index=["b", "a"])
+ frame1 = pd.DataFrame({"sk_a": [100.0, 200.0]}, index=["a", "b"])
+ ids, values, term_order = build_step_measurement_array(
+ _step_0_to_1_targets(), {0: frame0, 1: frame1}
+ )
+ assert ids.tolist() == ["a", "b"]
+ inv_col = term_order.index(("inv_a", 0))
+ # investment from period 0, ID-aligned: a->2, b->1.
+ np.testing.assert_array_equal(values[:, inv_col], [2.0, 1.0])
+
+
+def test_step_measurement_array_sources_each_term_from_its_data_period() -> None:
+ # Period frames deliberately have DIFFERENT individual orders.
+ frame0 = pd.DataFrame(
+ {"sk_a": [10.0, 11.0, 12.0], "inv_a": [1.0, 2.0, 3.0]}, index=[2, 1, 3]
+ )
+ frame1 = pd.DataFrame({"sk_a": [100.0, 200.0, 300.0]}, index=[3, 1, 2])
+ ids, values, term_order = build_step_measurement_array(
+ _step_0_to_1_targets(), {0: frame0, 1: frame1}
+ )
+
+ assert ids.tolist() == [1, 2, 3]
+ # skills from period 1 (destination), ID-aligned: id1->200, id2->300, id3->100.
+ sk_col = term_order.index(("sk_a", 1))
+ np.testing.assert_array_equal(values[:, sk_col], [200.0, 300.0, 100.0])
+ # investment from period 0 (source), ID-aligned: id1->2, id2->1, id3->3.
+ inv_col = term_order.index(("inv_a", 0))
+ np.testing.assert_array_equal(values[:, inv_col], [2.0, 1.0, 3.0])
+
+
+def test_step_measurement_array_uses_id_intersection_under_attrition() -> None:
+ # Individual 3 is missing from period 1; the step sample is the intersection.
+ frame0 = pd.DataFrame(
+ {"sk_a": [10.0, 11.0, 12.0], "inv_a": [1.0, 2.0, 3.0]}, index=[1, 2, 3]
+ )
+ frame1 = pd.DataFrame({"sk_a": [100.0, 200.0]}, index=[1, 2])
+ ids, values, term_order = build_step_measurement_array(
+ _step_0_to_1_targets(), {0: frame0, 1: frame1}
+ )
+ assert ids.tolist() == [1, 2]
+ inv_col = term_order.index(("inv_a", 0))
+ np.testing.assert_array_equal(values[:, inv_col], [1.0, 2.0])
+
+
+def test_step_measurement_array_does_not_mispair_positionally() -> None:
+ # If assembled positionally, id-2's investment (period 0, row 0 = 99) would be
+ # paired with id-2's skill (period 1, row 0 = 100), which is wrong.
+ frame0 = pd.DataFrame({"sk_a": [0.0, 0.0], "inv_a": [99.0, 7.0]}, index=[2, 1])
+ frame1 = pd.DataFrame({"sk_a": [100.0, 500.0]}, index=[1, 2])
+ ids, values, term_order = build_step_measurement_array(
+ _step_0_to_1_targets(), {0: frame0, 1: frame1}
+ )
+ inv_col = term_order.index(("inv_a", 0))
+ sk_col = term_order.index(("sk_a", 1))
+ # id 1: inv=7 (period0), skill=100 (period1); id 2: inv=99, skill=500.
+ assert ids.tolist() == [1, 2]
+ np.testing.assert_array_equal(values[:, inv_col], [7.0, 99.0])
+ np.testing.assert_array_equal(values[:, sk_col], [100.0, 500.0])
+
+
+def test_block_loading_mask_marks_each_term_factor() -> None:
+ targets = _step_0_to_1_targets() # sk_a -> skills, inv_a -> investment
+ # Latent order is [skills, investment]; the integrand dots loadings against it.
+ mask = build_block_loading_mask(targets, latent_factors=("skills", "investment"))
+ term_order = [(t.measurement, t.eval_node.value) for t in targets]
+ sk_row = term_order.index(("sk_a", "theta_dest"))
+ inv_row = term_order.index(("inv_a", "inv_src"))
+ # skills row loads only on the skills column; investment row only on investment.
+ np.testing.assert_array_equal(mask[sk_row], [True, False])
+ np.testing.assert_array_equal(mask[inv_row], [False, True])
diff --git a/tests/test_af_step_layout.py b/tests/test_af_step_layout.py
new file mode 100644
index 00000000..407588bf
--- /dev/null
+++ b/tests/test_af_step_layout.py
@@ -0,0 +1,175 @@
+"""Tests for the AF source/destination step-layout compiler.
+
+The AF estimator is sequential: step `s -> d = s+1` estimates the transition, the
+source-period investment equation, and a measurement block. Under the contemporaneous
+public `ModelSpec`, an investment indicator declared at calendar period `c` measures
+`I_c`. The compiler re-times these calendar declarations onto AF's sequential steps so
+that, for step `s -> d`:
+
+- the FREE target block scores destination dynamic-state (skill) indicators on
+ `theta_d` and source endogenous (investment) indicators on `I_s`;
+- the FIXED importance block scores source dynamic-state indicators on `theta_s` and
+ every static-persistent factor's period-0 indicators on its time-invariant value.
+
+This reproduces MATLAB `likelihood_01`/`likelihood_12` (target = {skill_{s+1}, inv_s},
+importance = {skill_s, MC_0, MN_0}) while the public spec stays contemporaneous.
+"""
+
+from skillmodels.af.step_layout import (
+ AFEval,
+ AFFactorInfo,
+ AFFactorRole,
+ compile_af_step_layouts,
+)
+
+
+def _cnlsy_factor_infos() -> tuple[AFFactorInfo, ...]:
+ """A minimal 3-period CNLSY-shaped factor set (skills, MC, MN, investment)."""
+ return (
+ AFFactorInfo(
+ name="skills",
+ role=AFFactorRole.DYNAMIC,
+ measurements_by_period=(
+ ("sk_a", "sk_b"),
+ ("sk_a", "sk_b"),
+ ("sk_a", "sk_b"),
+ ),
+ ),
+ AFFactorInfo(
+ name="MC",
+ role=AFFactorRole.STATIC_PERSISTENT,
+ measurements_by_period=(("mc_1", "mc_2"), (), ()),
+ ),
+ AFFactorInfo(
+ name="MN",
+ role=AFFactorRole.STATIC_PERSISTENT,
+ measurements_by_period=(("mn_1",), (), ()),
+ ),
+ AFFactorInfo(
+ name="investment",
+ role=AFFactorRole.ENDOGENOUS,
+ measurements_by_period=(("inv_a", "inv_b"), ("inv_a", "inv_b"), ()),
+ ),
+ )
+
+
+def _term_keys(terms, role):
+ """Set of (measurement, data_period, param_period, eval, free) for one role."""
+ return {
+ (t.measurement, t.data_period, t.param_period, t.eval_node, t.free)
+ for t in terms
+ if t.role == role
+ }
+
+
+def test_compiles_one_layout_per_transition() -> None:
+ layouts = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)
+ assert len(layouts) == 2
+ assert (layouts[0].source_period, layouts[0].destination_period) == (0, 1)
+ assert (layouts[1].source_period, layouts[1].destination_period) == (1, 2)
+
+
+def test_step_0_to_1_target_block() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[0]
+ # Destination skills on theta_d, source investment (age-7, calendar 0) on I_0.
+ assert _term_keys(layout.terms, "target") == {
+ ("sk_a", 1, 1, AFEval.THETA_DEST, True),
+ ("sk_b", 1, 1, AFEval.THETA_DEST, True),
+ ("inv_a", 0, 0, AFEval.INV_SRC, True),
+ ("inv_b", 0, 0, AFEval.INV_SRC, True),
+ }
+
+
+def test_step_0_to_1_importance_block_includes_static_factors() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[0]
+ # Source skills on theta_s plus the static MC_0/MN_0 densities (the F3 fix).
+ assert _term_keys(layout.terms, "importance") == {
+ ("sk_a", 0, 0, AFEval.THETA_SRC, False),
+ ("sk_b", 0, 0, AFEval.THETA_SRC, False),
+ ("mc_1", 0, 0, AFEval.STATIC, False),
+ ("mc_2", 0, 0, AFEval.STATIC, False),
+ ("mn_1", 0, 0, AFEval.STATIC, False),
+ }
+
+
+def test_step_1_to_2_target_uses_age9_investment_on_i1() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[1]
+ # Destination skills on theta_2, source investment (age-9, calendar 1) on I_1.
+ assert _term_keys(layout.terms, "target") == {
+ ("sk_a", 2, 2, AFEval.THETA_DEST, True),
+ ("sk_b", 2, 2, AFEval.THETA_DEST, True),
+ ("inv_a", 1, 1, AFEval.INV_SRC, True),
+ ("inv_b", 1, 1, AFEval.INV_SRC, True),
+ }
+
+
+def test_step_1_to_2_importance_reapplies_period0_static_factors() -> None:
+ layout = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)[1]
+ # Source skills are now calendar 1; MC_0/MN_0 are re-applied unchanged.
+ assert _term_keys(layout.terms, "importance") == {
+ ("sk_a", 1, 1, AFEval.THETA_SRC, False),
+ ("sk_b", 1, 1, AFEval.THETA_SRC, False),
+ ("mc_1", 0, 0, AFEval.STATIC, False),
+ ("mc_2", 0, 0, AFEval.STATIC, False),
+ ("mn_1", 0, 0, AFEval.STATIC, False),
+ }
+
+
+def test_static_persistent_does_not_leak_future_periods() -> None:
+ """A static factor contributes only its period-0 rows to every importance block.
+
+ A static-persistent factor whose declarations span periods 0 and 2 re-applies its
+ period-0 measurement as a fixed importance term at every step, and never lets its
+ period-2 declaration leak forward into an earlier step's importance block.
+ """
+ factor_infos = (
+ AFFactorInfo(
+ name="skills",
+ role=AFFactorRole.DYNAMIC,
+ measurements_by_period=(("sk0",), ("sk1",), ("sk2",)),
+ ),
+ AFFactorInfo(
+ name="MC",
+ role=AFFactorRole.STATIC_PERSISTENT,
+ measurements_by_period=(("mc0",), (), ("mc2",)),
+ ),
+ AFFactorInfo(
+ name="MN",
+ role=AFFactorRole.STATIC_PERSISTENT,
+ measurements_by_period=(("mn0",), (), ()),
+ ),
+ )
+ layouts = compile_af_step_layouts(factor_infos, n_periods=3)
+ step_01_static = {
+ (t.measurement, t.data_period)
+ for t in layouts[0].terms
+ if t.role == "importance" and t.eval_node == AFEval.STATIC
+ }
+ assert ("mc0", 0) in step_01_static
+ assert ("mc2", 2) not in step_01_static
+ assert ("mn0", 0) in step_01_static
+
+ for layout in layouts:
+ step_static = {
+ (t.measurement, t.data_period)
+ for t in layout.terms
+ if t.role == "importance" and t.eval_node == AFEval.STATIC
+ }
+ assert ("mn0", 0) in step_static
+
+
+def test_each_investment_calendar_wave_is_a_target_exactly_once() -> None:
+ layouts = compile_af_step_layouts(_cnlsy_factor_infos(), n_periods=3)
+ inv_targets = [
+ (t.measurement, t.data_period)
+ for layout in layouts
+ for t in layout.terms
+ if t.role == "target" and t.eval_node == AFEval.INV_SRC
+ ]
+ # age-7 (calendar 0) and age-9 (calendar 1) each used once; no calendar-2 wave.
+ assert sorted(inv_targets) == [
+ ("inv_a", 0),
+ ("inv_a", 1),
+ ("inv_b", 0),
+ ("inv_b", 1),
+ ]
diff --git a/tests/test_af_validate_investment_calendar.py b/tests/test_af_validate_investment_calendar.py
index a8bbd335..0d3c301c 100644
--- a/tests/test_af_validate_investment_calendar.py
+++ b/tests/test_af_validate_investment_calendar.py
@@ -1,22 +1,17 @@
-"""Regression tests for the AF endogenous-investment measurement-calendar guard.
+"""Regression tests for the AF endogenous-investment measurement calendar.
-AF reconstructs an endogenous factor (investment) from the PREVIOUS period's
-latent skills: at the (t-1)->t step the generated investment I is a function of
-theta_{t-1}, and it is that same I which the period-t measurement block scores.
-So measurements declared at period t for an endogenous factor measure the
-investment generated from period t-1 (I_{t-1}), not the contemporaneous I_t.
-CHS / AMN read the identical period-t measurements as the contemporaneous I_t (a
-standard latent factor with its own initial distribution). The two are different
-calendars for one ModelSpec, so `validate_af_model` warns when an endogenous
-factor carries measurements (audit F8). Modelling investment as a standard
-non-endogenous factor instead measures the contemporaneous value and emits no
-calendar warning.
+On the calendar-adapter path a reconstructed endogenous factor (investment,
+`is_endogenous=True` and `has_initial_distribution=False`) is contemporaneous: a
+period-c investment indicator measures `I_c`, the same calendar CHS / AMN read.
+The adapter makes the public ModelSpec contemporaneous and the AF step assembler
+re-times the source investment internally, so `validate_af_model` must NOT warn
+that the indicators denote `I_{t-1}` or that estimators disagree on the calendar.
+A user who shifted their data to satisfy such a warning would reintroduce the
+original off-by-one bug.
"""
import warnings
-import pytest
-
from skillmodels.af.validate import validate_af_model
from skillmodels.common.model_spec import (
FactorSpec,
@@ -36,12 +31,22 @@ def _skills_factor() -> FactorSpec:
)
-def test_validate_af_model_warns_on_endogenous_factor_with_measurements() -> None:
- """An endogenous factor with measurements triggers the calendar warning.
+def _lagged_calendar_warned(records: list[warnings.WarningMessage]) -> bool:
+ """Return whether any recorded warning teaches the stale lagged calendar."""
+ return any(
+ "I_{t-1}" in str(r.message)
+ or "different investment calendars" in str(r.message)
+ or "generated from period" in str(r.message)
+ for r in records
+ )
+
+
+def test_no_lagged_calendar_warning_for_reconstructed_endogenous() -> None:
+ """A reconstructed endogenous factor with measurements emits no calendar warning.
- Its period-1 indicators score the investment generated from period-0 skills
- (I_0), not the contemporaneous I_1, so the user is warned that AF and CHS/AMN
- read this ModelSpec on different investment calendars.
+ Under the calendar adapter its period-c indicators score the contemporaneous
+ `I_c` (shared with CHS / AMN), so `validate_af_model` must not claim they
+ denote `I_{t-1}` or that estimators read different investment calendars.
"""
model = ModelSpec(
factors={
@@ -58,8 +63,10 @@ def test_validate_af_model_warns_on_endogenous_factor_with_measurements() -> Non
),
},
)
- with pytest.warns(UserWarning, match="generated from period"):
+ with warnings.catch_warnings(record=True) as records:
+ warnings.simplefilter("always")
validate_af_model(model)
+ assert not _lagged_calendar_warned(records)
def test_validate_af_model_no_calendar_warning_for_standard_investment() -> None:
@@ -82,9 +89,10 @@ def test_validate_af_model_no_calendar_warning_for_standard_investment() -> None
),
},
)
- with warnings.catch_warnings():
- warnings.simplefilter("error", UserWarning)
- assert validate_af_model(model) is None
+ with warnings.catch_warnings(record=True) as records:
+ warnings.simplefilter("always")
+ validate_af_model(model)
+ assert not _lagged_calendar_warned(records)
def test_validate_af_model_no_calendar_warning_for_measurementless_endogenous() -> None:
@@ -101,6 +109,7 @@ def test_validate_af_model_no_calendar_warning_for_measurementless_endogenous()
),
},
)
- with warnings.catch_warnings():
- warnings.simplefilter("error", UserWarning)
- assert validate_af_model(model) is None
+ with warnings.catch_warnings(record=True) as records:
+ warnings.simplefilter("always")
+ validate_af_model(model)
+ assert not _lagged_calendar_warned(records)
diff --git a/tests/test_amn_limited_measurement_rejection.py b/tests/test_amn_limited_measurement_rejection.py
new file mode 100644
index 00000000..cecf2e8f
--- /dev/null
+++ b/tests/test_amn_limited_measurement_rejection.py
@@ -0,0 +1,79 @@
+"""AMN must reject non-Gaussian measurements rather than treat them as continuous.
+
+AMN's mixture-EM and minimum-distance stages recover loadings/SDs from the
+cross-covariance of multi-indicator measurements -- a continuous-Gaussian moment
+map. A probit/Tobit measure routed through it would be silently treated as
+continuous, so `estimate_amn` rejects any non-Gaussian `measurement_models` entry.
+A future working-linear seeding path will translate such measures before AMN runs.
+"""
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from skillmodels.amn.estimate import estimate_amn
+from skillmodels.common.measurement_models import (
+ ProbitMeasurement,
+ TobitMeasurement,
+)
+from skillmodels.common.model_spec import (
+ FactorSpec,
+ ModelSpec,
+ Normalizations,
+)
+
+
+def _model(measurement_models) -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("y1", "y2", "y3"),) * 2,
+ normalizations=Normalizations(
+ loadings=({"y1": 1},) * 2,
+ intercepts=({"y1": 0},) * 2,
+ ),
+ transition_function="linear",
+ ),
+ },
+ measurement_models=measurement_models,
+ )
+
+
+def _tiny_data() -> pd.DataFrame:
+ rng = np.random.default_rng(0)
+ rows = []
+ for caseid in range(20):
+ for period in (0, 1):
+ rows.append(
+ {
+ "caseid": caseid,
+ "period": period,
+ "y1": rng.normal(),
+ "y2": rng.normal(),
+ "y3": rng.normal(),
+ }
+ )
+ return pd.DataFrame(rows).set_index(["caseid", "period"])
+
+
+@pytest.mark.parametrize(
+ "models",
+ [
+ {"y2": ProbitMeasurement()},
+ {"y3": TobitMeasurement(lower=0.0)},
+ ],
+)
+def test_estimate_amn_rejects_non_gaussian_measurements(models) -> None:
+ with pytest.raises(NotImplementedError, match="Gaussian"):
+ estimate_amn(_model(models), _tiny_data())
+
+
+def test_estimate_amn_rejects_non_gaussian_even_for_start_values() -> None:
+ # The seeding path must hand AMN a working-linear (Gaussian) spec; a raw
+ # probit/Tobit spec is rejected even with for_start_values=True.
+ with pytest.raises(NotImplementedError, match="Gaussian"):
+ estimate_amn(
+ _model({"y2": ProbitMeasurement()}),
+ _tiny_data(),
+ for_start_values=True,
+ )
diff --git a/tests/test_constraints.py b/tests/test_constraints.py
index f2961ae0..5fa52218 100644
--- a/tests/test_constraints.py
+++ b/tests/test_constraints.py
@@ -215,6 +215,66 @@ def test_constant_factor_shock_constraints_only_target_existing_rows() -> None:
assert c.loc in index, f"orphan shock constraint {c.loc} not in params index"
+def test_transition_simplex_constraints_only_target_existing_rows() -> None:
+ # A log_ces (CES) production factor combined with an endogenous factor must not
+ # emit a simplex (probability) constraint at the terminal augmented period where
+ # the factor does not transition. With endogenous factors the transition index
+ # stops at aug_periods[:-2], so a naive aug_periods[:-1] loop in
+ # _get_transition_constraints emits an orphan ProbabilityConstraint whose loc is
+ # absent from the params index, tripping the optimagic selector at maximize time.
+ from dataclasses import replace # noqa: PLC0415
+
+ from skillmodels.common.model_spec import CorrectionSpec # noqa: PLC0415
+ from skillmodels.common.params_index import get_params_index # noqa: PLC0415
+ from skillmodels.test_data.model2 import MODEL2 # noqa: PLC0415
+
+ fac3 = MODEL2.factors["fac3"]
+ new_factors = dict(MODEL2.factors) | {
+ "fac3": replace(
+ fac3, is_endogenous=True, correction=CorrectionSpec(instruments=("inv_z",))
+ ),
+ }
+ model = (
+ MODEL2._replace(factors=new_factors)
+ ._replace(stagemap=None)
+ ._replace(observed_factors=("inv_z",))
+ )
+ processed = process_model(model)
+ index = get_params_index(
+ update_info=processed.update_info,
+ labels=processed.labels,
+ dimensions=processed.dimensions,
+ transition_info=processed.transition_info,
+ endogenous_factors_info=processed.endogenous_factors_info,
+ )
+ constraints = get_constraints(
+ update_info=processed.update_info,
+ labels=processed.labels,
+ dimensions=processed.dimensions,
+ anchoring_info=processed.anchoring,
+ normalizations=processed.normalizations,
+ endogenous_factors_info=processed.endogenous_factors_info,
+ bounds_distance=1e-8,
+ )
+ simplex_locs = []
+ for c in constraints:
+ if not isinstance(c, om.ProbabilityConstraint):
+ continue
+ selector = c.selector
+ if not isinstance(selector, functools.partial):
+ continue
+ simplex_locs += [
+ tup
+ for tup in selector.keywords["loc"]
+ if isinstance(tup, tuple) and tup[0] == "transition"
+ ]
+ assert simplex_locs, "the log_ces factor must have its gammas on a simplex"
+ for loc in simplex_locs:
+ assert loc in index, (
+ f"orphan transition simplex constraint {loc} not in params index"
+ )
+
+
def test_get_constraints_pins_instrument_out_of_production() -> None:
# Built-in production transitions enumerate a free coefficient for every
# observed factor, including the excluded instrument; that coefficient must be
@@ -511,7 +571,13 @@ def test_trans_coeff_constraints() -> None:
"type": "probability",
},
]
- calculated = _get_transition_constraints(labels)
+ no_endogenous = EndogenousFactorsInfo(
+ has_endogenous_factors=False,
+ aug_periods_to_aug_period_meas_types=MappingProxyType({}),
+ aug_periods_from_period=lambda period: [period],
+ factor_info=MappingProxyType({}),
+ )
+ calculated = _get_transition_constraints(labels, no_endogenous)
as_dicts = [_to_dict(c) for c in calculated]
assert_list_equal_except_for_order(as_dicts, expected)
diff --git a/tests/test_docs_no_stale_api.py b/tests/test_docs_no_stale_api.py
new file mode 100644
index 00000000..8faef6e2
--- /dev/null
+++ b/tests/test_docs_no_stale_api.py
@@ -0,0 +1,74 @@
+"""Regression guard: documentation must not teach stale or false API/behavior.
+
+Each assertion encodes a fixed Pro-review finding (F1-F9): a phrase the docs must
+no longer contain, or a corrected phrase/field they must now contain. The audience is a
+user copying the docs; a stale claim makes them build the wrong model or call a missing
+API. Keep these checks textual and cheap so they run in any environment.
+"""
+
+from pathlib import Path
+
+DOCS = Path(__file__).resolve().parent.parent / "docs"
+
+
+def _read(rel: str) -> str:
+ return (DOCS / rel).read_text(encoding="utf-8")
+
+
+def test_model_specs_uses_correction_not_is_correction() -> None:
+ text = _read("how_to_guides/model_specs.md")
+ assert "is_correction" not in text
+ assert "correction" in text
+
+
+def test_names_and_concepts_uses_current_option_names() -> None:
+ text = _read("explanations/names_and_concepts.md")
+ assert "n_mixture_components" not in text
+ assert "initialization_strategy" not in text
+ assert "investment_endogeneity" not in text
+ assert "start_params_strategy" in text
+
+
+def test_amn_guide_does_not_claim_chs_is_one_component() -> None:
+ text = _read("how_to_guides/how_to_estimate_amn.md")
+ assert "one mixture component" not in text
+ assert "n_mixtures" in text
+
+
+def test_af_guide_does_not_overclaim_constraint_support() -> None:
+ text = _read("how_to_guides/how_to_estimate_af.md")
+ assert "All optimagic constraint kinds are supported" not in text
+ assert "select_by_loc" in text
+
+
+def test_compare_guide_does_not_claim_identical_point_estimates() -> None:
+ text = _read("how_to_guides/how_to_compare_estimators.md")
+ assert "same point estimate" not in text
+
+
+def test_transition_functions_does_not_say_amn_lacks_custom_transitions() -> None:
+ text = _read("reference_guides/transition_functions.md")
+ assert "not yet with AMN" not in text
+
+
+def test_tutorial_activates_the_af_model_it_describes() -> None:
+ text = _read("getting_started/tutorial.ipynb")
+ assert "af_state_role" in text
+ assert "has_initial_distribution=False" in text
+ assert "is_endogenous=True" in text
+
+
+def test_estimator_prerequisites_reference_page_exists() -> None:
+ text = _read("reference_guides/estimator_prerequisites.md")
+ assert "prerequisite" in text.lower()
+ # A table comparing the three estimators.
+ for estimator in ("CHS", "AF", "AMN"):
+ assert estimator in text
+ # AF non-Gaussian support is initial-period only — must be stated, not blanket.
+ assert "initial" in text.lower()
+
+
+def test_amn_guide_documents_missing_data_and_fixed_param_prerequisites() -> None:
+ text = _read("how_to_guides/how_to_estimate_amn.md")
+ assert "mixture_em_method" in text
+ assert "allow_never_observed_measurements" in text
diff --git a/tests/test_factor_spec_af_state_role.py b/tests/test_factor_spec_af_state_role.py
new file mode 100644
index 00000000..d1597ce7
--- /dev/null
+++ b/tests/test_factor_spec_af_state_role.py
@@ -0,0 +1,71 @@
+"""Tests for the explicit `af_state_role` metadata on `FactorSpec`.
+
+A time-invariant factor (e.g. MC/MN) whose period-0 measurement density must be
+re-applied as an AF importance factor at every step is declared explicitly via
+`af_state_role="static_persistent"`, rather than inferred from an identity transition
+(which may be supplied through `fixed_params` and is therefore not reliably detectable).
+Factors default to `"dynamic"`, so existing specs are unchanged.
+"""
+
+import pytest
+
+from skillmodels.common.model_spec import FactorSpec, ModelSpec, Normalizations
+
+
+def test_factor_spec_defaults_to_dynamic_af_state_role() -> None:
+ spec = FactorSpec(measurements=(("m",), ("m",)), transition_function="linear")
+ assert spec.af_state_role == "dynamic"
+
+
+def test_factor_spec_accepts_static_persistent_role() -> None:
+ spec = FactorSpec(
+ measurements=(("mc_1",), ()),
+ transition_function="linear",
+ af_state_role="static_persistent",
+ )
+ assert spec.af_state_role == "static_persistent"
+
+
+def test_factor_spec_rejects_unknown_af_state_role() -> None:
+ with pytest.raises((ValueError, TypeError)):
+ FactorSpec(
+ measurements=(("m",),),
+ transition_function="linear",
+ # Intentionally invalid: exercises beartype's runtime rejection.
+ af_state_role="bogus", # ty: ignore[invalid-argument-type]
+ )
+
+
+def test_from_dict_parses_af_state_role() -> None:
+ model = ModelSpec.from_dict(
+ {
+ "factors": {
+ "MC": {
+ "measurements": [["mc_1"], []],
+ "normalizations": {
+ "loadings": [{"mc_1": 1}, {}],
+ "intercepts": [{"mc_1": 0}, {}],
+ },
+ "transition_function": "linear",
+ "af_state_role": "static_persistent",
+ },
+ },
+ }
+ )
+ assert model.factors["MC"].af_state_role == "static_persistent"
+
+
+def test_from_dict_defaults_af_state_role_to_dynamic() -> None:
+ model = ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("test_score",), ("test_score",)),
+ normalizations=Normalizations(
+ loadings=({"test_score": 1},) * 2,
+ intercepts=({"test_score": 0},) * 2,
+ ),
+ transition_function="linear",
+ ),
+ },
+ )
+ assert model.factors["skills"].af_state_role == "dynamic"
diff --git a/tests/test_measurement_loglik.py b/tests/test_measurement_loglik.py
new file mode 100644
index 00000000..580fa334
--- /dev/null
+++ b/tests/test_measurement_loglik.py
@@ -0,0 +1,138 @@
+"""Tests for the shared measurement-family log-likelihood kernel.
+
+`measurement_loglik` returns one log density/probability contribution per
+measurement, dispatching on a `MeasurementFamily` code:
+
+- Gaussian: `log N(y; eta, sigma)`.
+- Probit: `log Phi((2y-1) eta)` for `y in {0, 1}` (standard-normal latent error).
+- Tobit: censored normal -- the interior normal density, or the tail mass
+ `Phi((L-eta)/sigma)` / `Phi((eta-U)/sigma)` at a censoring bound.
+
+The kernel is the single source of truth shared by AF estimation, CHS, simulation
+and posterior-state reweighting, so it is validated against SciPy closed forms and
+its JAX gradient against finite differences.
+"""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from scipy.stats import norm
+
+from skillmodels.common.measurement_models import (
+ MeasurementFamily,
+ measurement_loglik,
+)
+
+jax.config.update("jax_enable_x64", True)
+
+_INF = float("inf")
+
+
+def _ll(y, eta, sigma, family, lower=-_INF, upper=_INF):
+ return float(
+ measurement_loglik(
+ jnp.asarray(float(y)),
+ jnp.asarray(float(eta)),
+ jnp.asarray(float(sigma)),
+ jnp.asarray(int(family)),
+ jnp.asarray(float(lower)),
+ jnp.asarray(float(upper)),
+ )
+ )
+
+
+@pytest.mark.parametrize("eta", [-2.0, 0.0, 1.5])
+@pytest.mark.parametrize("y", [-1.0, 0.7, 3.0])
+def test_gaussian_kernel_matches_scipy(y: float, eta: float) -> None:
+ sigma = 0.8
+ got = _ll(y, eta, sigma, MeasurementFamily.GAUSSIAN)
+ assert got == pytest.approx(norm.logpdf(y, loc=eta, scale=sigma))
+
+
+@pytest.mark.parametrize("eta", [-3.0, -0.5, 0.0, 0.5, 3.0])
+@pytest.mark.parametrize("y", [0.0, 1.0])
+def test_probit_kernel_matches_scipy(y: float, eta: float) -> None:
+ got = _ll(y, eta, 1.0, MeasurementFamily.PROBIT)
+ expected = norm.logcdf((2.0 * y - 1.0) * eta)
+ assert got == pytest.approx(expected)
+
+
+@pytest.mark.parametrize("eta", [-40.0, 40.0])
+@pytest.mark.parametrize("y", [0.0, 1.0])
+def test_probit_kernel_finite_in_tails(y: float, eta: float) -> None:
+ got = _ll(y, eta, 1.0, MeasurementFamily.PROBIT)
+ assert np.isfinite(got)
+ assert got == pytest.approx(norm.logcdf((2.0 * y - 1.0) * eta), abs=1e-6)
+
+
+def test_tobit_interior_matches_scipy() -> None:
+ y, eta, sigma = 2.3, 1.0, 0.9
+ got = _ll(y, eta, sigma, MeasurementFamily.TOBIT, lower=0.0)
+ expected = norm.logpdf(y, loc=eta, scale=sigma)
+ assert got == pytest.approx(expected)
+
+
+def test_tobit_left_censored_matches_scipy() -> None:
+ lower, eta, sigma = 0.0, 1.0, 0.9
+ got = _ll(lower, eta, sigma, MeasurementFamily.TOBIT, lower=lower)
+ expected = norm.logcdf((lower - eta) / sigma)
+ assert got == pytest.approx(expected)
+
+
+def test_tobit_right_censored_matches_scipy() -> None:
+ upper, eta, sigma = 5.0, 4.0, 1.1
+ got = _ll(upper, eta, sigma, MeasurementFamily.TOBIT, lower=0.0, upper=upper)
+ expected = norm.logcdf((eta - upper) / sigma)
+ assert got == pytest.approx(expected)
+
+
+def test_tobit_one_sided_lower_does_not_censor_at_plus_inf() -> None:
+ # With no upper bound, a large interior value uses the normal density, never
+ # the (absent) upper tail mass.
+ y, eta, sigma = 100.0, 1.0, 0.9
+ got = _ll(y, eta, sigma, MeasurementFamily.TOBIT, lower=0.0, upper=_INF)
+ assert got == pytest.approx(norm.logpdf(y, loc=eta, scale=sigma))
+
+
+def test_probit_loading_gradient_matches_finite_difference() -> None:
+ # eta = c + lambda * theta; differentiate the probit contribution wrt lambda.
+ y, theta, sigma = 1.0, 0.6, 1.0
+
+ def ll_of_lambda(lam: jax.Array) -> jax.Array:
+ eta = 0.2 + lam * theta
+ return measurement_loglik(
+ jnp.asarray(y),
+ eta,
+ jnp.asarray(sigma),
+ jnp.asarray(int(MeasurementFamily.PROBIT)),
+ jnp.asarray(-_INF),
+ jnp.asarray(_INF),
+ )
+
+ lam0 = jnp.asarray(0.8)
+ grad = float(jax.grad(ll_of_lambda)(lam0))
+ h = 1e-6
+ fd = (float(ll_of_lambda(lam0 + h)) - float(ll_of_lambda(lam0 - h))) / (2 * h)
+ assert grad == pytest.approx(fd, rel=1e-5, abs=1e-7)
+
+
+def test_tobit_scale_gradient_matches_finite_difference() -> None:
+ # Differentiate a left-censored Tobit contribution wrt sigma.
+ lower, eta = 0.0, 1.0
+
+ def ll_of_sigma(sig: jax.Array) -> jax.Array:
+ return measurement_loglik(
+ jnp.asarray(lower),
+ jnp.asarray(eta),
+ sig,
+ jnp.asarray(int(MeasurementFamily.TOBIT)),
+ jnp.asarray(lower),
+ jnp.asarray(_INF),
+ )
+
+ sig0 = jnp.asarray(0.9)
+ grad = float(jax.grad(ll_of_sigma)(sig0))
+ h = 1e-6
+ fd = (float(ll_of_sigma(sig0 + h)) - float(ll_of_sigma(sig0 - h))) / (2 * h)
+ assert grad == pytest.approx(fd, rel=1e-5, abs=1e-7)
diff --git a/tests/test_measurement_models.py b/tests/test_measurement_models.py
new file mode 100644
index 00000000..cb89c9d7
--- /dev/null
+++ b/tests/test_measurement_models.py
@@ -0,0 +1,171 @@
+"""Tests for the public measurement-family interface on `ModelSpec`.
+
+A measurement variable's observation model (Gaussian / probit / Tobit) is a
+property of the variable, not of the factor it loads on, so it is configured via
+`ModelSpec.measurement_models` (a `name -> MeasurementModel` mapping). A variable
+omitted from the mapping is Gaussian, so existing specs are unchanged. The marker
+classes resolve to the internal `MeasurementFamily` code + censoring bounds that the
+shared likelihood kernel consumes.
+"""
+
+import math
+
+import numpy as np
+import pytest
+
+from skillmodels.common.measurement_models import (
+ GaussianMeasurement,
+ MeasurementFamily,
+ ProbitMeasurement,
+ TobitMeasurement,
+ measurement_family_arrays,
+ resolve_measurement_family,
+)
+from skillmodels.common.model_spec import (
+ FactorSpec,
+ ModelSpec,
+ Normalizations,
+)
+
+
+def _model(measurement_models=None) -> ModelSpec:
+ return ModelSpec(
+ factors={
+ "skills": FactorSpec(
+ measurements=(("test_score", "passed_grade"),) * 2,
+ normalizations=Normalizations(
+ loadings=({"test_score": 1},) * 2,
+ intercepts=({"test_score": 0},) * 2,
+ ),
+ transition_function="translog",
+ ),
+ },
+ measurement_models=measurement_models,
+ )
+
+
+def test_omitted_measurement_models_defaults_to_gaussian() -> None:
+ model = _model()
+ assert model.measurement_model("test_score") == GaussianMeasurement()
+ assert model.measurement_model("passed_grade") == GaussianMeasurement()
+
+
+def test_listed_measurement_models_are_returned() -> None:
+ model = _model({"passed_grade": ProbitMeasurement()})
+ assert model.measurement_model("passed_grade") == ProbitMeasurement()
+ # An unlisted measure is still Gaussian.
+ assert model.measurement_model("test_score") == GaussianMeasurement()
+
+
+def test_measurement_models_mapping_is_immutable() -> None:
+ model = _model({"passed_grade": ProbitMeasurement()})
+ with pytest.raises(TypeError):
+ # Intentional illegal mutation: the mapping must be read-only.
+ model.measurement_models["test_score"] = ProbitMeasurement() # ty: ignore[invalid-assignment]
+
+
+def test_unknown_measurement_name_is_rejected() -> None:
+ with pytest.raises(ValueError, match="not a measurement"):
+ _model({"not_a_real_measure": ProbitMeasurement()})
+
+
+def test_resolve_gaussian_family() -> None:
+ family, lower, upper = resolve_measurement_family(GaussianMeasurement())
+ assert family is MeasurementFamily.GAUSSIAN
+ assert lower == -math.inf
+ assert upper == math.inf
+
+
+def test_resolve_probit_family() -> None:
+ family, lower, upper = resolve_measurement_family(ProbitMeasurement())
+ assert family is MeasurementFamily.PROBIT
+ assert lower == -math.inf
+ assert upper == math.inf
+
+
+def test_resolve_one_sided_tobit_family() -> None:
+ family, lower, upper = resolve_measurement_family(TobitMeasurement(lower=0.0))
+ assert family is MeasurementFamily.TOBIT
+ assert lower == 0.0
+ assert upper == math.inf
+
+
+def test_resolve_double_censored_tobit_family() -> None:
+ family, lower, upper = resolve_measurement_family(
+ TobitMeasurement(lower=0.0, upper=10.0)
+ )
+ assert family is MeasurementFamily.TOBIT
+ assert (lower, upper) == (0.0, 10.0)
+
+
+def test_tobit_rejects_no_bounds() -> None:
+ with pytest.raises(ValueError, match="at least one"):
+ TobitMeasurement(lower=None, upper=None)
+
+
+def test_tobit_rejects_inverted_bounds() -> None:
+ with pytest.raises(ValueError, match=r"lower.*upper"):
+ TobitMeasurement(lower=5.0, upper=1.0)
+
+
+def test_from_dict_parses_measurement_models() -> None:
+ model = ModelSpec.from_dict(
+ {
+ "factors": {
+ "skills": {
+ "measurements": [
+ ["test_score", "passed_grade"],
+ ["test_score", "passed_grade"],
+ ],
+ "normalizations": {
+ "loadings": [{"test_score": 1}, {"test_score": 1}],
+ "intercepts": [{"test_score": 0}, {"test_score": 0}],
+ },
+ "transition_function": "translog",
+ },
+ },
+ "measurement_models": {
+ "passed_grade": {"family": "probit"},
+ "test_score": {"family": "tobit", "lower": 0.0},
+ },
+ }
+ )
+ assert model.measurement_model("passed_grade") == ProbitMeasurement()
+ assert model.measurement_model("test_score") == TobitMeasurement(lower=0.0)
+
+
+def test_from_dict_without_measurement_models_is_all_gaussian() -> None:
+ model = ModelSpec.from_dict(
+ {
+ "factors": {
+ "skills": {
+ "measurements": [["test_score"], ["test_score"]],
+ "normalizations": {
+ "loadings": [{"test_score": 1}, {"test_score": 1}],
+ "intercepts": [{"test_score": 0}, {"test_score": 0}],
+ },
+ "transition_function": "translog",
+ },
+ },
+ }
+ )
+ assert model.measurement_model("test_score") == GaussianMeasurement()
+
+
+def test_measurement_family_arrays_aligns_and_defaults_gaussian() -> None:
+ models = {
+ "passed_grade": ProbitMeasurement(),
+ "amount": TobitMeasurement(lower=0.0, upper=5.0),
+ }
+ names = ["test_score", "passed_grade", "amount"]
+ codes, lowers, uppers = measurement_family_arrays(models, names)
+ np.testing.assert_array_equal(
+ codes,
+ [
+ int(MeasurementFamily.GAUSSIAN),
+ int(MeasurementFamily.PROBIT),
+ int(MeasurementFamily.TOBIT),
+ ],
+ )
+ assert lowers.tolist() == [-math.inf, -math.inf, 0.0]
+ assert uppers.tolist() == [math.inf, math.inf, 5.0]
diff --git a/tests/test_simulate_limited_measurements.py b/tests/test_simulate_limited_measurements.py
new file mode 100644
index 00000000..4af3aab3
--- /dev/null
+++ b/tests/test_simulate_limited_measurements.py
@@ -0,0 +1,117 @@
+"""Tests for limited-dependent-variable measurement simulation.
+
+`measurements_from_states` builds the linear predictor `eta = controls @ b +
+states @ lambda'` and then draws each measurement according to its family:
+
+- Gaussian: `eta + N(0, sigma^2)`.
+- Probit: `1{eta + N(0,1) >= 0}`, i.e. Bernoulli(`Phi(eta)`).
+- Tobit: `clip(eta + N(0, sigma^2), lower, upper)`.
+
+Passing no family arrays keeps the original all-Gaussian path unchanged.
+"""
+
+import math
+
+import numpy as np
+import pytest
+from scipy.stats import norm
+
+from skillmodels.common.measurement_models import MeasurementFamily
+from skillmodels.common.simulate_data import measurements_from_states
+
+_INF = math.inf
+
+
+def _single_measure(loading: float, eta_const: float):
+ """One measurement loading 1:1 on one state plus a constant control."""
+ states = np.ones((1, 1)) # placeholder; overwritten per test
+ del states
+ loadings = np.array([[loading]])
+ control_params = np.array([[eta_const]])
+ return loadings, control_params
+
+
+def test_gaussian_path_unchanged_without_families() -> None:
+ rng = np.random.default_rng(0)
+ n_obs = 2000
+ states = rng.normal(size=(n_obs, 1))
+ controls = np.ones((n_obs, 1))
+ loadings, control_params = _single_measure(1.0, 0.5)
+ out = measurements_from_states(
+ rng, states, controls, loadings, control_params, np.array([0.3])
+ )
+ # eta = 0.5 + state; residual ~ N(0, 0.3): mean(out - eta) ~ 0.
+ eta = 0.5 + states[:, 0]
+ assert np.std(out[:, 0] - eta) == pytest.approx(0.3, abs=0.02)
+
+
+def test_probit_values_are_binary_and_match_frequency() -> None:
+ rng = np.random.default_rng(1)
+ n_obs = 200_000
+ states = np.full((n_obs, 1), 1.0)
+ controls = np.ones((n_obs, 1))
+ loadings, control_params = _single_measure(0.0, 0.7) # eta = 0.7 for everyone
+ out = measurements_from_states(
+ rng,
+ states,
+ controls,
+ loadings,
+ control_params,
+ np.array([1.0]),
+ families=np.array([int(MeasurementFamily.PROBIT)]),
+ lowers=np.array([-_INF]),
+ uppers=np.array([_INF]),
+ )
+ assert set(np.unique(out)) <= {0.0, 1.0}
+ assert out.mean() == pytest.approx(norm.cdf(0.7), abs=0.005)
+
+
+def test_tobit_left_censoring_mass_and_interior() -> None:
+ rng = np.random.default_rng(2)
+ n_obs = 200_000
+ states = np.full((n_obs, 1), 0.0)
+ controls = np.ones((n_obs, 1))
+ sigma = 1.0
+ loadings, control_params = _single_measure(0.0, 0.3) # eta = 0.3
+ out = measurements_from_states(
+ rng,
+ states,
+ controls,
+ loadings,
+ control_params,
+ np.array([sigma]),
+ families=np.array([int(MeasurementFamily.TOBIT)]),
+ lowers=np.array([0.0]),
+ uppers=np.array([_INF]),
+ )[:, 0]
+ # Nothing below the bound; a point mass at exactly 0.
+ assert out.min() >= 0.0
+ censored_frac = float(np.mean(out == 0.0))
+ assert censored_frac == pytest.approx(norm.cdf((0.0 - 0.3) / sigma), abs=0.005)
+ # Interior values are continuous (the censored mass aside).
+ assert float(np.mean(out > 0.0)) == pytest.approx(1.0 - censored_frac, abs=1e-9)
+
+
+def test_simulation_is_reproducible_with_seed() -> None:
+ loadings, control_params = _single_measure(1.0, 0.0)
+ states = np.random.default_rng(7).normal(size=(100, 1))
+ controls = np.ones((100, 1))
+ sds = np.array([1.0])
+ families = np.array([int(MeasurementFamily.PROBIT)])
+ lowers = np.array([-_INF])
+ uppers = np.array([_INF])
+
+ def _draw(seed: int):
+ return measurements_from_states(
+ np.random.default_rng(seed),
+ states,
+ controls,
+ loadings,
+ control_params,
+ sds,
+ families=families,
+ lowers=lowers,
+ uppers=uppers,
+ )
+
+ np.testing.assert_array_equal(_draw(3), _draw(3))