Skip to content
Open
11 changes: 8 additions & 3 deletions src/aca_model/agent/labor_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import jax.numpy as jnp
import numpy as np
from lcm import categorical
from lcm.typing import (
ContinuousState,
Expand Down Expand Up @@ -39,12 +40,16 @@ class SpousalIncome:
married_has_inc: ScalarInt


HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])
# Host array, not a module-level JAX array: a device array here would
# reserve the GPU memory pool at import time in every process that imports
# the model. It is converted to a device array at each indexing site, where
# the value folds into the surrounding compiled function.
HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])


def working_hours_value(labor_supply: DiscreteAction) -> FloatND:
"""Map labor supply choice to annual hours worked."""
return HOURS_VALUES[labor_supply]
return jnp.asarray(HOURS_VALUES)[labor_supply]


def wage(
Expand Down Expand Up @@ -74,7 +79,7 @@ def income(

income = wage * hours^(1 + exp) * int^(-exp)
"""
hours = HOURS_VALUES[labor_supply]
hours = jnp.asarray(HOURS_VALUES)[labor_supply]
return jnp.where(
hours > 0.0,
wage
Expand Down
36 changes: 27 additions & 9 deletions src/aca_model/agent/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from lcm import categorical
from lcm.typing import (
Age,
BoolND,
ContinuousAction,
ContinuousState,
DiscreteState,
Expand All @@ -19,6 +18,11 @@

from aca_model.agent.labor_market import LaggedLaborSupply

# Width of the smooth leisure floor, as a fraction of the time endowment. Small enough
# that leisure equals `time_endowment - cost` wherever work costs sit well below the
# endowment; it only bends the map near and beyond the endowment.
_LEISURE_SMOOTHING_FRACTION = 0.01


@categorical(ordered=False)
class PrefType:
Expand All @@ -44,11 +48,6 @@ class BenchmarkPrefType:
type_1: ScalarInt


def positive_leisure(leisure: FloatND) -> BoolND:
"""Return True where leisure is strictly positive."""
return leisure > 0


def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND:
"""Return the equivalence scale for household size adjustment.

Expand All @@ -69,6 +68,22 @@ def fixed_cost_of_work(
)


def _smooth_leisure_floor(
leisure_available: FloatND, time_endowment: ScalarFloat
) -> FloatND:
"""Bend leisure to a strictly positive floor as work costs approach the endowment.

`softplus(x) = log(1 + e^x)` via `jnp.logaddexp(0, x)`, scaled by a small fraction
of the endowment. Where `leisure_available` is large relative to the smoothing width
the map reduces to `leisure_available` (bulk unchanged); as it falls to zero leisure
bends to `0⁺` — never negative, never a kinked clamp — so the CRRA aggregator never
receives a non-positive base. The smoothing width scales with the endowment, so the
map is scale-invariant.
"""
smoothing = _LEISURE_SMOOTHING_FRACTION * time_endowment
return smoothing * jnp.logaddexp(0.0, leisure_available / smoothing)


def leisure_canwork_retiree_or_nongroup(
working_hours_value: FloatND,
good_health: IntND,
Expand All @@ -94,7 +109,8 @@ def leisure_canwork_retiree_or_nongroup(
0.0,
)

return time_endowment - health_loss - work_loss
leisure_available = time_endowment - health_loss - work_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def leisure_canwork_tied(
Expand All @@ -112,7 +128,8 @@ def leisure_canwork_tied(
work_loss = jnp.where(
working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0
)
return time_endowment - health_loss - work_loss
leisure_available = time_endowment - health_loss - work_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def leisure_forcedout(
Expand All @@ -122,7 +139,8 @@ def leisure_forcedout(
) -> FloatND:
"""Compute leisure for forcedout regimes (no work)."""
health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health)
return time_endowment - health_loss
leisure_available = time_endowment - health_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def consumption_equiv(
Expand Down
24 changes: 20 additions & 4 deletions src/aca_model/baseline/regimes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from typing import Any

from lcm import DiscreteGrid, Regime
from lcm.solvers import DCEGM
from lcm.solvers import BQSEGM, DCEGM
from lcm.typing import UserParams

from aca_model.baseline.regimes import _nongroup as nongroup
from aca_model.baseline.regimes import _retiree as retiree
from aca_model.baseline.regimes import _tied as tied
from aca_model.baseline.regimes._bqsegm import build_bqsegm_solver
from aca_model.baseline.regimes._common import (
REGIME_SPECS,
Grids,
Expand All @@ -35,6 +36,11 @@
from aca_model.baseline.regimes._dcegm import build_dcegm_solver
from aca_model.config import GridConfig

# BQSEGM is a per-regime (not global) solver: it solves one 1-D
# consumption/savings regime with at most one discrete action, so it attaches
# only to the M1 vertical-slice regime.
_BQSEGM_REGIME = "nongroup_nomc_inelig_canwork"

__all__ = [
"REGIME_SPECS",
"RegimeId",
Expand All @@ -54,7 +60,11 @@


def build_regime(
name: str, grids: Grids, *, dcegm_solver: DCEGM | None = None
name: str,
grids: Grids,
*,
dcegm_solver: DCEGM | None = None,
bqsegm_solver: BQSEGM | None = None,
) -> Regime:
"""Build a single baseline Regime object for the given regime name."""
if name == "dead":
Expand All @@ -67,7 +77,7 @@ def build_regime(
if builder is None:
msg = f"Unknown HIS type: {spec['his']}"
raise ValueError(msg)
return builder(name, grids, dcegm_solver=dcegm_solver)
return builder(name, grids, dcegm_solver=dcegm_solver, bqsegm_solver=bqsegm_solver)


def build_all_regimes(
Expand Down Expand Up @@ -98,9 +108,15 @@ def build_all_regimes(
consumption_dollars_points=consumption_dollars_points,
)
dcegm_solver = build_dcegm_solver(grids) if solver == "dcegm" else None
bqsegm_solver = build_bqsegm_solver(grids) if solver == "bqsegm" else None
regimes = {}
for name in REGIME_SPECS:
regimes[name] = build_regime(name, grids, dcegm_solver=dcegm_solver)
regimes[name] = build_regime(
name,
grids,
dcegm_solver=dcegm_solver,
bqsegm_solver=bqsegm_solver if name == _BQSEGM_REGIME else None,
)
regimes["dead"] = build_dead_regime(solver=solver)
return regimes

Expand Down
56 changes: 56 additions & 0 deletions src/aca_model/baseline/regimes/_bqsegm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""BQSEGM solver configuration for the ACA M1 vertical-slice regime.

BQSEGM is the case-piece endogenous-grid solver for a single 1-D
consumption/savings regime whose budget is split by institutional breakpoints
on a derived monotone income quantity. It shares DC-EGM's post-decision
(savings) spec — consumption is recovered from `resources = max(cash_on_hand,
floor)`, the assets laws are in savings form, the borrowing constraint is the
savings grid's lower bound — but solves only one regime with at most one
discrete action, so it attaches per regime rather than globally. The
function-level rewiring is shared with DC-EGM (`build_dcegm_functions`, the
savings-form assets laws in `_common`); this module holds only the solver
config.
"""

from lcm import IrregSpacedGrid
from lcm.solvers import BQSEGM

from aca_model.baseline.regimes._common import Grids


def build_bqsegm_solver(grids: Grids) -> BQSEGM:
"""Build the per-regime BQSEGM configuration.

The savings grid mirrors DC-EGM's: lower bound 0 (the borrowing constraint
in post-decision form), upper bound the assets span, cubically clustered
toward the constraint. The budget node is `resources` (post-floor
cash-on-hand) and the post-decision function is `savings`, matching the
shared savings-form spec.
"""
n_points = grids.grid_config.n_savings_gridpoints
_fail_if_too_few_savings_gridpoints(n_points)
assets_points = grids.assets.to_jax()
savings_stop = float(assets_points[-1]) - float(assets_points[0])
savings_grid = IrregSpacedGrid(
points=tuple(savings_stop * (i / (n_points - 1)) ** 3 for i in range(n_points)),
batch_size=grids.grid_config.n_savings_batch_size,
)
return BQSEGM(
savings_grid=savings_grid,
continuous_state="assets",
budget_target="resources",
post_decision_function="savings",
# Splay the child stochastic-node expectation per the grid config: `0` (the
# default) reads the whole node mesh in one pass on a memory-rich device; a
# positive value loops it in blocks to fit a tighter budget (a CPU run).
stochastic_node_batch_size=grids.grid_config.n_bqsegm_stochastic_node_batch_size,
)


def _fail_if_too_few_savings_gridpoints(n_savings_gridpoints: int) -> None:
if n_savings_gridpoints < 2:
msg = (
f"n_savings_gridpoints must be >= 2 to form the cubically clustered "
f"BQSEGM savings grid, got {n_savings_gridpoints}."
)
raise ValueError(msg)
48 changes: 41 additions & 7 deletions src/aca_model/baseline/regimes/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from aca_model.environment import pensions, social_security, taxes
from aca_model.environment.social_security import ClaimedSS

SolverName = Literal["brute_force", "dcegm"]
SolverName = Literal["brute_force", "dcegm", "bqsegm"]


@categorical(ordered=False)
Expand Down Expand Up @@ -454,14 +454,27 @@ def build_states(spec: RegimeSpec, grids: Grids) -> dict:
return states


def build_actions(spec: RegimeSpec, grids: Grids) -> dict:
"""Build the action dict for a non-dead regime."""
def build_actions(
spec: RegimeSpec,
grids: Grids,
*,
drop_buy_private: bool = False,
drop_labor_supply: bool = False,
) -> dict:
"""Build the action dict for a non-dead regime.

The `drop_*` flags fix a discrete action to a single level for the BQSEGM
M1 vertical slice (its case-piece envelope handles at most one discrete
action). The dropped action's former consumers are rebound to the fixed
level at the regime builder, so removing it here is the action side of the
dags remove-and-fix.
"""
actions: dict = {}
if spec["ss"] == "choose":
actions["claim_ss"] = DiscreteGrid(ClaimedSS)
if spec["canwork"] == "canwork":
if spec["canwork"] == "canwork" and not drop_labor_supply:
actions["labor_supply"] = DiscreteGrid(LaborSupply)
if spec["his"] == "nongroup" and spec["mc"] == "nomc":
if spec["his"] == "nongroup" and spec["mc"] == "nomc" and not drop_buy_private:
actions["buy_private"] = DiscreteGrid(BuyPrivate)
actions["consumption_dollars"] = grids.consumption_dollars
return actions
Expand Down Expand Up @@ -583,6 +596,11 @@ def build_model_functions(*, solver: SolverName = "brute_force") -> dict:
"""
functions: dict = {}
if solver == "dcegm":
# DC-EGM solves every living regime, so the solver-contract functions are
# broadcast model-wide. BQSEGM solves only the M1 regime, which carries
# them at regime level (see `_nongroup.build_regime`); broadcasting them
# would force every brute regime to supply the solver's
# `marginal_continuation`.
functions |= build_dcegm_functions()
functions["total_health_costs"] = health_insurance.total_costs
functions["oop_costs"] = health_insurance.oop_with_medicaid
Expand Down Expand Up @@ -635,12 +653,28 @@ def build_dcegm_functions() -> dict:
}


def build_bqsegm_functions() -> dict:
"""Build the regime functions the BQSEGM solver-contract requires.

BQSEGM inverts the Euler equation internally (CRRA from the utility
parameters), so unlike DC-EGM it needs only the savings-form budget
(`resources`) and the post-decision savings node — not
`inverse_marginal_utility`.
"""
return {
"resources": assets_and_income.resources,
"savings": assets_and_income.savings,
}


def build_model_constraints(*, solver: SolverName = "brute_force") -> dict:
"""Build the model-level constraints broadcast into every regime.

`dead` masks the borrowing constraint — it has no consumption action.
Under DC-EGM there is no explicit borrowing constraint: the savings
grid's lower bound enforces it.
grid's lower bound enforces it. BQSEGM enforces it the same way, but only
in the M1 regime it solves, so the constraint stays broadcast for the brute
regimes and the M1 regime masks it (see `_nongroup.build_regime`).
"""
if solver == "dcegm":
return {}
Expand Down Expand Up @@ -973,7 +1007,7 @@ def _build_per_target_regime_assets(
targets use the full `next_assets` with the pension correction.
Under DC-EGM both laws take their post-decision (savings) form.
"""
if solver == "dcegm":
if solver in ("dcegm", "bqsegm"):
living_law = assets_and_income.next_assets_from_savings
dead_law = assets_and_income.next_assets_when_dead_from_savings
else:
Expand Down
Loading
Loading