diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 1421e9e..5665d92 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import numpy as np from lcm import categorical from lcm.typing import ( ContinuousState, @@ -39,12 +40,16 @@ class SpousalIncome: married_has_inc: ScalarInt -HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) +# Host array, not a module-level JAX array: a device array here would +# reserve the GPU memory pool at import time in every process that imports +# the model. It is converted to a device array at each indexing site, where +# the value folds into the surrounding compiled function. +HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) def working_hours_value(labor_supply: DiscreteAction) -> FloatND: """Map labor supply choice to annual hours worked.""" - return HOURS_VALUES[labor_supply] + return jnp.asarray(HOURS_VALUES)[labor_supply] def wage( @@ -74,7 +79,7 @@ def income( income = wage * hours^(1 + exp) * int^(-exp) """ - hours = HOURS_VALUES[labor_supply] + hours = jnp.asarray(HOURS_VALUES)[labor_supply] return jnp.where( hours > 0.0, wage diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 381dec6..eeba5bd 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -7,7 +7,6 @@ from lcm import categorical from lcm.typing import ( Age, - BoolND, ContinuousAction, ContinuousState, DiscreteState, @@ -19,6 +18,11 @@ from aca_model.agent.labor_market import LaggedLaborSupply +# Width of the smooth leisure floor, as a fraction of the time endowment. Small enough +# that leisure equals `time_endowment - cost` wherever work costs sit well below the +# endowment; it only bends the map near and beyond the endowment. +_LEISURE_SMOOTHING_FRACTION = 0.01 + @categorical(ordered=False) class PrefType: @@ -44,11 +48,6 @@ class BenchmarkPrefType: type_1: ScalarInt -def positive_leisure(leisure: FloatND) -> BoolND: - """Return True where leisure is strictly positive.""" - return leisure > 0 - - def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND: """Return the equivalence scale for household size adjustment. @@ -69,6 +68,22 @@ def fixed_cost_of_work( ) +def _smooth_leisure_floor( + leisure_available: FloatND, time_endowment: ScalarFloat +) -> FloatND: + """Bend leisure to a strictly positive floor as work costs approach the endowment. + + `softplus(x) = log(1 + e^x)` via `jnp.logaddexp(0, x)`, scaled by a small fraction + of the endowment. Where `leisure_available` is large relative to the smoothing width + the map reduces to `leisure_available` (bulk unchanged); as it falls to zero leisure + bends to `0⁺` — never negative, never a kinked clamp — so the CRRA aggregator never + receives a non-positive base. The smoothing width scales with the endowment, so the + map is scale-invariant. + """ + smoothing = _LEISURE_SMOOTHING_FRACTION * time_endowment + return smoothing * jnp.logaddexp(0.0, leisure_available / smoothing) + + def leisure_canwork_retiree_or_nongroup( working_hours_value: FloatND, good_health: IntND, @@ -94,7 +109,8 @@ def leisure_canwork_retiree_or_nongroup( 0.0, ) - return time_endowment - health_loss - work_loss + leisure_available = time_endowment - health_loss - work_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def leisure_canwork_tied( @@ -112,7 +128,8 @@ def leisure_canwork_tied( work_loss = jnp.where( working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0 ) - return time_endowment - health_loss - work_loss + leisure_available = time_endowment - health_loss - work_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def leisure_forcedout( @@ -122,7 +139,8 @@ def leisure_forcedout( ) -> FloatND: """Compute leisure for forcedout regimes (no work).""" health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - return time_endowment - health_loss + leisure_available = time_endowment - health_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def consumption_equiv( diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index 411de1d..a8c3255 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -14,12 +14,13 @@ from typing import Any from lcm import DiscreteGrid, Regime -from lcm.solvers import DCEGM +from lcm.solvers import BQSEGM, DCEGM from lcm.typing import UserParams from aca_model.baseline.regimes import _nongroup as nongroup from aca_model.baseline.regimes import _retiree as retiree from aca_model.baseline.regimes import _tied as tied +from aca_model.baseline.regimes._bqsegm import build_bqsegm_solver from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, @@ -35,6 +36,11 @@ from aca_model.baseline.regimes._dcegm import build_dcegm_solver from aca_model.config import GridConfig +# BQSEGM is a per-regime (not global) solver: it solves one 1-D +# consumption/savings regime with at most one discrete action, so it attaches +# only to the M1 vertical-slice regime. +_BQSEGM_REGIME = "nongroup_nomc_inelig_canwork" + __all__ = [ "REGIME_SPECS", "RegimeId", @@ -54,7 +60,11 @@ def build_regime( - name: str, grids: Grids, *, dcegm_solver: DCEGM | None = None + name: str, + grids: Grids, + *, + dcegm_solver: DCEGM | None = None, + bqsegm_solver: BQSEGM | None = None, ) -> Regime: """Build a single baseline Regime object for the given regime name.""" if name == "dead": @@ -67,7 +77,7 @@ def build_regime( if builder is None: msg = f"Unknown HIS type: {spec['his']}" raise ValueError(msg) - return builder(name, grids, dcegm_solver=dcegm_solver) + return builder(name, grids, dcegm_solver=dcegm_solver, bqsegm_solver=bqsegm_solver) def build_all_regimes( @@ -98,9 +108,15 @@ def build_all_regimes( consumption_dollars_points=consumption_dollars_points, ) dcegm_solver = build_dcegm_solver(grids) if solver == "dcegm" else None + bqsegm_solver = build_bqsegm_solver(grids) if solver == "bqsegm" else None regimes = {} for name in REGIME_SPECS: - regimes[name] = build_regime(name, grids, dcegm_solver=dcegm_solver) + regimes[name] = build_regime( + name, + grids, + dcegm_solver=dcegm_solver, + bqsegm_solver=bqsegm_solver if name == _BQSEGM_REGIME else None, + ) regimes["dead"] = build_dead_regime(solver=solver) return regimes diff --git a/src/aca_model/baseline/regimes/_bqsegm.py b/src/aca_model/baseline/regimes/_bqsegm.py new file mode 100644 index 0000000..1adb8a1 --- /dev/null +++ b/src/aca_model/baseline/regimes/_bqsegm.py @@ -0,0 +1,56 @@ +"""BQSEGM solver configuration for the ACA M1 vertical-slice regime. + +BQSEGM is the case-piece endogenous-grid solver for a single 1-D +consumption/savings regime whose budget is split by institutional breakpoints +on a derived monotone income quantity. It shares DC-EGM's post-decision +(savings) spec — consumption is recovered from `resources = max(cash_on_hand, +floor)`, the assets laws are in savings form, the borrowing constraint is the +savings grid's lower bound — but solves only one regime with at most one +discrete action, so it attaches per regime rather than globally. The +function-level rewiring is shared with DC-EGM (`build_dcegm_functions`, the +savings-form assets laws in `_common`); this module holds only the solver +config. +""" + +from lcm import IrregSpacedGrid +from lcm.solvers import BQSEGM + +from aca_model.baseline.regimes._common import Grids + + +def build_bqsegm_solver(grids: Grids) -> BQSEGM: + """Build the per-regime BQSEGM configuration. + + The savings grid mirrors DC-EGM's: lower bound 0 (the borrowing constraint + in post-decision form), upper bound the assets span, cubically clustered + toward the constraint. The budget node is `resources` (post-floor + cash-on-hand) and the post-decision function is `savings`, matching the + shared savings-form spec. + """ + n_points = grids.grid_config.n_savings_gridpoints + _fail_if_too_few_savings_gridpoints(n_points) + assets_points = grids.assets.to_jax() + savings_stop = float(assets_points[-1]) - float(assets_points[0]) + savings_grid = IrregSpacedGrid( + points=tuple(savings_stop * (i / (n_points - 1)) ** 3 for i in range(n_points)), + batch_size=grids.grid_config.n_savings_batch_size, + ) + return BQSEGM( + savings_grid=savings_grid, + continuous_state="assets", + budget_target="resources", + post_decision_function="savings", + # Splay the child stochastic-node expectation per the grid config: `0` (the + # default) reads the whole node mesh in one pass on a memory-rich device; a + # positive value loops it in blocks to fit a tighter budget (a CPU run). + stochastic_node_batch_size=grids.grid_config.n_bqsegm_stochastic_node_batch_size, + ) + + +def _fail_if_too_few_savings_gridpoints(n_savings_gridpoints: int) -> None: + if n_savings_gridpoints < 2: + msg = ( + f"n_savings_gridpoints must be >= 2 to form the cubically clustered " + f"BQSEGM savings grid, got {n_savings_gridpoints}." + ) + raise ValueError(msg) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 847b789..74b3cc4 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -42,7 +42,7 @@ from aca_model.environment import pensions, social_security, taxes from aca_model.environment.social_security import ClaimedSS -SolverName = Literal["brute_force", "dcegm"] +SolverName = Literal["brute_force", "dcegm", "bqsegm"] @categorical(ordered=False) @@ -454,14 +454,27 @@ def build_states(spec: RegimeSpec, grids: Grids) -> dict: return states -def build_actions(spec: RegimeSpec, grids: Grids) -> dict: - """Build the action dict for a non-dead regime.""" +def build_actions( + spec: RegimeSpec, + grids: Grids, + *, + drop_buy_private: bool = False, + drop_labor_supply: bool = False, +) -> dict: + """Build the action dict for a non-dead regime. + + The `drop_*` flags fix a discrete action to a single level for the BQSEGM + M1 vertical slice (its case-piece envelope handles at most one discrete + action). The dropped action's former consumers are rebound to the fixed + level at the regime builder, so removing it here is the action side of the + dags remove-and-fix. + """ actions: dict = {} if spec["ss"] == "choose": actions["claim_ss"] = DiscreteGrid(ClaimedSS) - if spec["canwork"] == "canwork": + if spec["canwork"] == "canwork" and not drop_labor_supply: actions["labor_supply"] = DiscreteGrid(LaborSupply) - if spec["his"] == "nongroup" and spec["mc"] == "nomc": + if spec["his"] == "nongroup" and spec["mc"] == "nomc" and not drop_buy_private: actions["buy_private"] = DiscreteGrid(BuyPrivate) actions["consumption_dollars"] = grids.consumption_dollars return actions @@ -583,6 +596,11 @@ def build_model_functions(*, solver: SolverName = "brute_force") -> dict: """ functions: dict = {} if solver == "dcegm": + # DC-EGM solves every living regime, so the solver-contract functions are + # broadcast model-wide. BQSEGM solves only the M1 regime, which carries + # them at regime level (see `_nongroup.build_regime`); broadcasting them + # would force every brute regime to supply the solver's + # `marginal_continuation`. functions |= build_dcegm_functions() functions["total_health_costs"] = health_insurance.total_costs functions["oop_costs"] = health_insurance.oop_with_medicaid @@ -635,12 +653,28 @@ def build_dcegm_functions() -> dict: } +def build_bqsegm_functions() -> dict: + """Build the regime functions the BQSEGM solver-contract requires. + + BQSEGM inverts the Euler equation internally (CRRA from the utility + parameters), so unlike DC-EGM it needs only the savings-form budget + (`resources`) and the post-decision savings node — not + `inverse_marginal_utility`. + """ + return { + "resources": assets_and_income.resources, + "savings": assets_and_income.savings, + } + + def build_model_constraints(*, solver: SolverName = "brute_force") -> dict: """Build the model-level constraints broadcast into every regime. `dead` masks the borrowing constraint — it has no consumption action. Under DC-EGM there is no explicit borrowing constraint: the savings - grid's lower bound enforces it. + grid's lower bound enforces it. BQSEGM enforces it the same way, but only + in the M1 regime it solves, so the constraint stays broadcast for the brute + regimes and the M1 regime masks it (see `_nongroup.build_regime`). """ if solver == "dcegm": return {} @@ -973,7 +1007,7 @@ def _build_per_target_regime_assets( targets use the full `next_assets` with the pension correction. Under DC-EGM both laws take their post-decision (savings) form. """ - if solver == "dcegm": + if solver in ("dcegm", "bqsegm"): living_law = assets_and_income.next_assets_from_savings dead_law = assets_and_income.next_assets_when_dead_from_savings else: diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 151929c..63ad844 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -4,20 +4,22 @@ Already nongroup, so no SSI/Medicaid override needed for HIS transitions. """ +import functools from collections.abc import Callable from lcm import Regime -from lcm.solvers import DCEGM +from lcm.solvers import BQSEGM, DCEGM from lcm.typing import Age, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance +from aca_model.baseline.health_insurance import BuyPrivate from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, RegimeSpec, build_actions, + build_bqsegm_functions, build_common_functions, build_granular_regime_transition, build_pension_functions, @@ -76,11 +78,32 @@ def transition( return transition -def _build_functions(spec: RegimeSpec) -> dict: - """Build functions dict for a nongroup regime.""" +def _fixed_full_time_labor_supply() -> DiscreteAction: + """Labor supply fixed to full-time work for the BQSEGM M1 slice.""" + return LaborSupply.h2000 + + +def _build_functions( + spec: RegimeSpec, *, fix_buy_private: bool = False, fix_labor_supply: bool = False +) -> dict: + """Build functions dict for a nongroup regime. + + The BQSEGM M1 slice fixes both discrete actions to a single level so the + only choice is continuous consumption: + + - `fix_buy_private` binds `buy_private` to `BuyPrivate.yes` in its consumers + (premium, OOP) — the `buy_private == BuyPrivate.yes` arm — leaving the + remaining budget structure untouched. + - `fix_labor_supply` supplies `labor_supply` as a fixed full-time node read + by labor income, AIME accrual, and the lagged-supply transition (which + stays a state, so the cross-regime continuation space is unchanged). + """ can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) + if can_work and fix_labor_supply: + functions["labor_supply"] = _fixed_full_time_labor_supply + functions["ss_benefit"] = select_ss_benefit(spec) # his and crossed_oamc_threshold are fixed params (constants per regime), @@ -93,13 +116,26 @@ def _build_functions(spec: RegimeSpec) -> dict: functions["hic_premium"] = health_insurance.premium_insured else: functions["hic_premium"] = health_insurance.premium_retired + + if has_buy_private and fix_buy_private: + functions["hic_premium"] = functools.partial( + health_insurance.premium, buy_private=BuyPrivate.yes + ) + functions["primary_oop"] = functools.partial( + health_insurance.primary_oop, buy_private=BuyPrivate.yes + ) + functions.update(build_pension_functions(spec)) return functions def build_regime( - name: str, grids: Grids, *, dcegm_solver: DCEGM | None = None + name: str, + grids: Grids, + *, + dcegm_solver: DCEGM | None = None, + bqsegm_solver: BQSEGM | None = None, ) -> Regime: """Build a nongroup regime.""" spec = REGIME_SPECS[name] @@ -112,23 +148,42 @@ def build_regime( transition_func = _make_transition_forcedout(gets_mc, own) states = build_states(spec, grids) - # `borrowing_constraint` is broadcast from the model level. - constraints: dict = {} - if spec["canwork"] == "canwork": - constraints["positive_leisure"] = preferences.positive_leisure - solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver} + egm_solver = dcegm_solver if dcegm_solver is not None else bqsegm_solver + solver_kwargs: dict = {} if egm_solver is None else {"solver": egm_solver} + state_solver = ( + "brute_force" + if egm_solver is None + else ("bqsegm" if bqsegm_solver is not None else "dcegm") + ) + # The BQSEGM M1 slice fixes the discrete actions to a single level so the + # only choice is continuous consumption against the cliffed budget. + fix_for_bqsegm = bqsegm_solver is not None + functions = _build_functions( + spec, fix_buy_private=fix_for_bqsegm, fix_labor_supply=fix_for_bqsegm + ) + constraints: dict = {} + if fix_for_bqsegm: + # BQSEGM solves only this regime, so its solver-contract functions are + # regime-level here rather than broadcast model-wide, and the model-level + # borrowing constraint is masked — BQSEGM enforces it through the savings + # grid's lower bound. + functions = {**functions, **build_bqsegm_functions()} + constraints = {"borrowing_constraint": None} return Regime( transition=build_granular_regime_transition( transition_func=transition_func, target_ids=own.values() ), active=make_active_func(spec), states=states, - state_transitions=build_state_transitions( - spec, solver="brute_force" if dcegm_solver is None else "dcegm" + state_transitions=build_state_transitions(spec, solver=state_solver), + actions=build_actions( + spec, + grids, + drop_buy_private=fix_for_bqsegm, + drop_labor_supply=fix_for_bqsegm, ), - actions=build_actions(spec, grids), - functions=_build_functions(spec), + functions=functions, constraints=constraints, **solver_kwargs, ) diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index 1dd054b..bb9c3ce 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -8,10 +8,9 @@ import jax.numpy as jnp from lcm import Regime -from lcm.solvers import DCEGM +from lcm.solvers import BQSEGM, DCEGM from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance from aca_model.baseline.regimes._common import ( @@ -109,7 +108,11 @@ def _build_functions(spec: RegimeSpec) -> dict: def build_regime( - name: str, grids: Grids, *, dcegm_solver: DCEGM | None = None + name: str, + grids: Grids, + *, + dcegm_solver: DCEGM | None = None, + bqsegm_solver: BQSEGM | None = None, ) -> Regime: """Build a retiree regime.""" spec = REGIME_SPECS[name] @@ -122,23 +125,22 @@ def build_regime( transition_func = _make_transition_forcedout(gets_mc, own, ng) states = build_states(spec, grids) - # `borrowing_constraint` is broadcast from the model level. - constraints: dict = {} - if spec["canwork"] == "canwork": - constraints["positive_leisure"] = preferences.positive_leisure - solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver} + egm_solver = dcegm_solver if dcegm_solver is not None else bqsegm_solver + solver_kwargs: dict = {} if egm_solver is None else {"solver": egm_solver} + state_solver = ( + "brute_force" + if egm_solver is None + else ("bqsegm" if bqsegm_solver is not None else "dcegm") + ) return Regime( transition=build_granular_regime_transition( transition_func=transition_func, target_ids=(*own.values(), *ng.values()) ), active=make_active_func(spec), states=states, - state_transitions=build_state_transitions( - spec, solver="brute_force" if dcegm_solver is None else "dcegm" - ), + state_transitions=build_state_transitions(spec, solver=state_solver), actions=build_actions(spec, grids), functions=_build_functions(spec), - constraints=constraints, **solver_kwargs, ) diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index d7b66ad..9624ed6 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -9,10 +9,9 @@ import jax.numpy as jnp from lcm import Regime -from lcm.solvers import DCEGM +from lcm.solvers import BQSEGM, DCEGM from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance from aca_model.baseline.regimes._common import ( @@ -83,7 +82,11 @@ def _build_functions(spec: RegimeSpec) -> dict: def build_regime( - name: str, grids: Grids, *, dcegm_solver: DCEGM | None = None + name: str, + grids: Grids, + *, + dcegm_solver: DCEGM | None = None, + bqsegm_solver: BQSEGM | None = None, ) -> Regime: """Build a tied regime (all tied regimes are canwork).""" spec = REGIME_SPECS[name] @@ -93,19 +96,21 @@ def build_regime( transition_func = _make_transition_canwork(gets_mc, own, ng) states = build_states(spec, grids) - solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver} + egm_solver = dcegm_solver if dcegm_solver is not None else bqsegm_solver + solver_kwargs: dict = {} if egm_solver is None else {"solver": egm_solver} + state_solver = ( + "brute_force" + if egm_solver is None + else ("bqsegm" if bqsegm_solver is not None else "dcegm") + ) return Regime( transition=build_granular_regime_transition( transition_func=transition_func, target_ids=(*own.values(), *ng.values()) ), active=make_active_func(spec), states=states, - state_transitions=build_state_transitions( - spec, solver="brute_force" if dcegm_solver is None else "dcegm" - ), + state_transitions=build_state_transitions(spec, solver=state_solver), actions=build_actions(spec, grids), functions=_build_functions(spec), - # `borrowing_constraint` is broadcast from the model level. - constraints={"positive_leisure": preferences.positive_leisure}, **solver_kwargs, ) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 02e83ce..7338265 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -84,6 +84,14 @@ class GridConfig: # gathered grid (value function unchanged). `0` keeps the whole grid in one # kernel. Only consulted under `solver="dcegm"`. n_savings_batch_size: int = 0 + # Block size for splaying the BQSEGM continuation's child stochastic-node + # expectation (health, health-cost shocks, the wage residual). `0` reads the + # whole joint node mesh in one pass — fast, but its peak intermediate scales + # with the full ride-along × node × child-grid product. A positive value loops + # the mesh in blocks of that size, trading runtime for a much smaller peak; `1` + # (one node at a time) is the memory-minimal setting for a CPU validation grid. + # Only consulted under `solver="bqsegm"`. + n_bqsegm_stochastic_node_batch_size: int = 0 MODEL_CONFIG = ModelConfig() diff --git a/src/aca_model/environment/taxes.py b/src/aca_model/environment/taxes.py index bde97a6..30efdff 100644 --- a/src/aca_model/environment/taxes.py +++ b/src/aca_model/environment/taxes.py @@ -10,9 +10,17 @@ from typing import Any, cast import jax.numpy as jnp +import lcm from lcm.params import MappingLeaf from lcm.typing import DiscreteState, FloatND, IntND +# Federal income tax bracket edges declared for the BQSEGM M1 slice. The +# progressive tax is continuous and piecewise-affine in `gross_income`, kinking +# at each finite bracket edge `income_tax_schedule.brackets_upper[spousal_income, +# k]`; the final edge is the +inf top-bracket sentinel and is not a kink. The +# decorator is metadata-only — brute and DC-EGM solve identically. +_N_INCOME_TAX_KINKS = 7 + def gross_income( capital_income: FloatND, @@ -85,6 +93,19 @@ def taxable_ss_benefit( ) +@lcm.piecewise_affine( + "after_tax_income", + variable="gross_income", + breakpoints=tuple( + lcm.affine_breakpoint( + "income_tax_schedule.brackets_upper", + kind="continuous_kink", + indexed_by="spousal_income", + static_index=k, + ) + for k in range(_N_INCOME_TAX_KINKS) + ), +) def after_tax_income( gross_income: FloatND, ss_benefit: FloatND, diff --git a/tests/test_bqsegm_model_creation.py b/tests/test_bqsegm_model_creation.py new file mode 100644 index 0000000..635bfc4 --- /dev/null +++ b/tests/test_bqsegm_model_creation.py @@ -0,0 +1,164 @@ +"""BQSEGM solver wiring: `solver="bqsegm"` is a per-regime option. + +Unlike DC-EGM (a global Euler solver on every living regime), BQSEGM solves a +single 1-D consumption/savings regime with at most one discrete action, so it +attaches only to the M1 vertical-slice regime `nongroup_nomc_inelig_canwork`; +every other living regime keeps brute force. The savings-form spec is shared +with DC-EGM (BQSEGM's budget node is `resources`, the post-decision function is +`savings`). +""" + +from collections.abc import Mapping +from typing import cast + +from helpers.model import _DERIVED_CATEGORICALS # ty: ignore[unresolved-import] +from lcm import DiscreteGrid, Model, Regime +from lcm.solvers import BQSEGM, GridSearch + +from aca_model.agent import assets_and_income +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.baseline.model import create_model +from aca_model.baseline.regimes import ( + REGIME_SPECS, + SolverName, + build_all_regimes, +) +from aca_model.baseline.regimes._bqsegm import build_bqsegm_solver +from aca_model.baseline.regimes._common import Grids, build_grids +from aca_model.benchmark import get_benchmark_params +from aca_model.config import BENCHMARK_GRID_CONFIG + +_FIXED_PARAMS, _WAGE_PARAMS, _ = get_benchmark_params(model=None) + +_M1_REGIME = "nongroup_nomc_inelig_canwork" +_BRUTE_REGIME = "retiree_nomc_inelig_canwork" + + +def _build_regimes(solver: SolverName) -> dict[str, Regime]: + return build_all_regimes( + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + solver=solver, + ) + + +def _build_model(solver: SolverName) -> Model: + return create_model( + n_subjects=1, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + derived_categoricals=_DERIVED_CATEGORICALS, + grid_config=BENCHMARK_GRID_CONFIG, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + solver=solver, + ) + + +def _grids() -> Grids: + return build_grids( + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + + +def test_bqsegm_attaches_only_to_the_m1_regime() -> None: + """`solver="bqsegm"` gives the M1 slice regime a `BQSEGM` config and leaves + every other living regime on brute force.""" + regimes = _build_regimes("bqsegm") + assert isinstance(regimes[_M1_REGIME].solver, BQSEGM) + for name in REGIME_SPECS: + if name == _M1_REGIME: + continue + assert isinstance(regimes[name].solver, GridSearch), name + + +def test_build_bqsegm_solver_uses_the_savings_form_resources_budget() -> None: + """The BQSEGM config inverts against `resources` in post-decision savings + form, matching the DC-EGM contract the regime is rewired into.""" + solver = build_bqsegm_solver(_grids()) + assert isinstance(solver, BQSEGM) + assert solver.budget_target == "resources" + assert solver.post_decision_function == "savings" + + +def test_build_bqsegm_solver_names_assets_as_the_euler_axis() -> None: + """`assets` is the liquid (Euler) axis; `aime` and the stochastic shock grids + ride along, so the solver names the Euler axis explicitly.""" + solver = build_bqsegm_solver(_grids()) + assert solver.continuous_state == "assets" + + +def test_bqsegm_m1_regime_fixes_buy_private() -> None: + """The BQSEGM M1 slice drops `buy_private` as an action (fixed to purchase), + so the only choice is continuous consumption; the brute M1 regime keeps it.""" + bqsegm_m1 = _build_regimes("bqsegm")[_M1_REGIME] + brute_m1 = _build_regimes("brute_force")[_M1_REGIME] + assert "buy_private" not in bqsegm_m1.actions + assert "buy_private" in brute_m1.actions + + +def test_bqsegm_m1_regime_fixes_labor_supply() -> None: + """The BQSEGM M1 slice drops `labor_supply` as an action (fixed to full-time + work), so no discrete action remains and the only choice is continuous + consumption; the brute M1 regime keeps `labor_supply`.""" + bqsegm_m1 = _build_regimes("bqsegm")[_M1_REGIME] + brute_m1 = _build_regimes("brute_force")[_M1_REGIME] + assert "labor_supply" not in bqsegm_m1.actions + assert "labor_supply" in brute_m1.actions + + +def test_bqsegm_m1_regime_has_no_discrete_action() -> None: + """With both discrete actions fixed, the BQSEGM M1 slice leaves only the + continuous consumption choice — no `DiscreteGrid` action remains.""" + bqsegm_m1 = _build_regimes("bqsegm")[_M1_REGIME] + assert not any( + isinstance(grid, DiscreteGrid) for grid in bqsegm_m1.actions.values() + ) + + +def test_bqsegm_m1_regime_takes_the_savings_form_assets_laws() -> None: + """The M1 regime under BQSEGM consumes the post-decision assets laws, like + DC-EGM; the other (brute) regimes keep the cash-on-hand form.""" + regimes = _build_regimes("bqsegm") + assets_laws = cast( + "Mapping[str, object]", regimes[_M1_REGIME].state_transitions["assets"] + ) + for target_name, law in assets_laws.items(): + expected = ( + assets_and_income.next_assets_when_dead_from_savings + if target_name == "dead" + else assets_and_income.next_assets_from_savings + ) + assert law is expected, target_name + + +def test_bqsegm_savings_form_functions_are_scoped_to_the_m1_regime() -> None: + """Under BQSEGM only the M1 regime carries the savings-form budget functions + (`resources`, `savings`); brute regimes keep the cash-on-hand form and carry + neither.""" + model = _build_model("bqsegm") + m1_functions = model.user_regimes[_M1_REGIME].functions + assert "resources" in m1_functions + assert "savings" in m1_functions + assert "resources" not in model.user_regimes[_BRUTE_REGIME].functions + + +def test_bqsegm_m1_regime_does_not_carry_inverse_marginal_utility() -> None: + """BQSEGM inverts the Euler equation internally, so the M1 regime never + carries the DC-EGM `inverse_marginal_utility` function (whose + solver-supplied `marginal_continuation` would otherwise be a required + parameter).""" + model = _build_model("bqsegm") + assert "inverse_marginal_utility" not in model.user_regimes[_M1_REGIME].functions + + +def test_bqsegm_brute_regimes_keep_the_borrowing_constraint() -> None: + """BQSEGM enforces the borrowing limit through its savings grid's lower bound, + so the M1 regime drops the explicit constraint; every brute regime keeps it.""" + model = _build_model("bqsegm") + assert "borrowing_constraint" not in model.user_regimes[_M1_REGIME].constraints + assert "borrowing_constraint" in model.user_regimes[_BRUTE_REGIME].constraints diff --git a/tests/test_labor_market.py b/tests/test_labor_market.py index 18dcaa2..f26992a 100644 --- a/tests/test_labor_market.py +++ b/tests/test_labor_market.py @@ -2,11 +2,35 @@ import jax.numpy as jnp import numpy as np +import pytest from aca_model.agent import labor_market from aca_model.agent.labor_market import LaborSupply +def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None: + """`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array. + + A module-level JAX array materializes on the default device the moment the + module is imported, reserving the GPU memory pool in every process that + imports the model — including the estimation orchestrator, which only + launches GPU worker ranks and must leave the devices free for them. + """ + assert isinstance(labor_market.HOURS_VALUES, np.ndarray) + + +@pytest.mark.parametrize( + ("choice", "expected_hours"), + [(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)], +) +def test_working_hours_value_maps_choice_to_annual_hours( + choice: int, expected_hours: float +) -> None: + """Each labor-supply choice maps to its annual hours worked.""" + result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32)) + np.testing.assert_allclose(float(result), expected_hours) + + def test_wage_combines_age_health_profile_with_residual() -> None: """`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`.""" log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health] diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 733fecd..5caa981 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -67,6 +67,49 @@ def test_leisure_bad_health() -> None: assert jnp.isclose(result, 4500.0) +def test_leisure_canwork_stays_positive_when_work_cost_meets_endowment() -> None: + """Leisure bends to a strictly positive floor when work costs reach the endowment. + + Without a floor, leisure would be zero (or negative past the endowment) and feed a + non-positive base into the CRRA aggregator. The smooth floor keeps it clearly + positive. + """ + result = preferences.leisure_canwork_retiree_or_nongroup( + working_hours_value=jnp.array(4500.0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(1), + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(0.0), + fixed_cost_of_work=jnp.asarray(500.0), # 4500 + 500 == 5000 == endowment + labor_force_reentry_cost=jnp.asarray(0.0), + ) + assert result > 1.0 + + +def test_leisure_canwork_tied_decreases_smoothly_into_saturation() -> None: + """Past the endowment, more work cost still lowers leisure, never below zero. + + A flat clamp or a raw `endowment - cost` would either pin leisure or drive it + negative; the smooth floor keeps it strictly positive and strictly decreasing. + """ + common = { + "good_health": jnp.int32(1), + "time_endowment": jnp.asarray(5000.0), + "leisure_cost_of_bad_health": jnp.asarray(0.0), + } + at_endowment = preferences.leisure_canwork_tied( + working_hours_value=jnp.array(4500.0), + fixed_cost_of_work=jnp.asarray(500.0), # cost == endowment + **common, + ) + past_endowment = preferences.leisure_canwork_tied( + working_hours_value=jnp.array(4500.0), + fixed_cost_of_work=jnp.asarray(700.0), # cost exceeds endowment by 200 + **common, + ) + assert at_endowment > past_endowment > 0.0 + + def test_utility_positive_leisure() -> None: result = preferences.u_alive( consumption_equiv=jnp.array(10000.0),