diff --git a/quantecon/markov/ddp.py b/quantecon/markov/ddp.py index 25a99850..bdf0bf70 100644 --- a/quantecon/markov/ddp.py +++ b/quantecon/markov/ddp.py @@ -119,6 +119,7 @@ _fill_dense_Q, _s_wise_max_argmax, _s_wise_max, _find_indices, _has_sorted_sa_indices, _generate_a_indptr ) +from numba import njit class DiscreteDP: @@ -370,7 +371,8 @@ def __init__(self, R, Q, beta, s_indices=None, a_indices=None): def s_wise_max(vals, out=None, out_argmax=None): """ Return the vector max_a vals(s, a), where vals is represented - by a 1-dimensional ndarray of shape (self.num_sa_pairs,). + by a 2-dimensional ndarray of shape (n, m). Stored in out, + which must be of length self.num_states. out and out_argmax must be of length self.num_states; dtype of out_argmax must be int. @@ -378,11 +380,9 @@ def s_wise_max(vals, out=None, out_argmax=None): if out is None: out = np.empty(self.num_states) if out_argmax is None: - _s_wise_max(self.a_indices, self.a_indptr, vals, - out_max=out) + _njit_s_wise_max_1d(self.a_indices, self.a_indptr, vals, out) else: - _s_wise_max_argmax(self.a_indices, self.a_indptr, vals, - out_max=out, out_argmax=out_argmax) + _njit_s_wise_max_argmax_1d(self.a_indices, self.a_indptr, vals, out, out_argmax) return out self.s_wise_max = s_wise_max @@ -411,11 +411,7 @@ def s_wise_max(vals, out=None, out_argmax=None): """ if out is None: out = np.empty(self.num_states) - if out_argmax is None: - vals.max(axis=1, out=out) - else: - vals.argmax(axis=1, out=out_argmax) - out[:] = vals[np.arange(self.num_states), out_argmax] + _njit_s_wise_max_2d(vals, out, out_argmax) return out self.s_wise_max = s_wise_max @@ -1078,3 +1074,37 @@ def backward_induction(ddp, T, v_term=None): ddp.bellman_operator(vs[t, :], Tv=vs[t-1, :], sigma=sigmas[t-1, :]) return vs, sigmas + + +# 1D s_wise_max using Numba njit +@njit(cache=True) +def _njit_s_wise_max_1d(a_indices, a_indptr, vals, out_max): + _s_wise_max(a_indices, a_indptr, vals, out_max=out_max) + +@njit(cache=True) +def _njit_s_wise_max_argmax_1d(a_indices, a_indptr, vals, out_max, out_argmax): + _s_wise_max_argmax(a_indices, a_indptr, vals, out_max=out_max, out_argmax=out_argmax) + + +# 2D s_wise_max using Numba njit +@njit(cache=True) +def _njit_s_wise_max_2d(vals, out, out_argmax=None): + n, m = vals.shape + if out_argmax is None: + for i in range(n): + maxval = vals[i, 0] + for j in range(1, m): + if vals[i, j] > maxval: + maxval = vals[i, j] + out[i] = maxval + else: + for i in range(n): + argmax = 0 + maxval = vals[i, 0] + for j in range(1, m): + if vals[i, j] > maxval: + maxval = vals[i, j] + argmax = j + out_argmax[i] = argmax + out[i] = maxval + return out