Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions quantecon/markov/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
_fill_dense_Q, _s_wise_max_argmax, _s_wise_max, _find_indices,
_has_sorted_sa_indices, _generate_a_indptr
)
from numba import njit


class DiscreteDP:
Expand Down Expand Up @@ -370,19 +371,18 @@ def __init__(self, R, Q, beta, s_indices=None, a_indices=None):
def s_wise_max(vals, out=None, out_argmax=None):
"""
Return the vector max_a vals(s, a), where vals is represented
by a 1-dimensional ndarray of shape (self.num_sa_pairs,).
by a 2-dimensional ndarray of shape (n, m). Stored in out,
which must be of length self.num_states.
out and out_argmax must be of length self.num_states; dtype of
out_argmax must be int.

"""
if out is None:
out = np.empty(self.num_states)
if out_argmax is None:
_s_wise_max(self.a_indices, self.a_indptr, vals,
out_max=out)
_njit_s_wise_max_1d(self.a_indices, self.a_indptr, vals, out)
else:
_s_wise_max_argmax(self.a_indices, self.a_indptr, vals,
out_max=out, out_argmax=out_argmax)
_njit_s_wise_max_argmax_1d(self.a_indices, self.a_indptr, vals, out, out_argmax)
return out

self.s_wise_max = s_wise_max
Expand Down Expand Up @@ -411,11 +411,7 @@ def s_wise_max(vals, out=None, out_argmax=None):
"""
if out is None:
out = np.empty(self.num_states)
if out_argmax is None:
vals.max(axis=1, out=out)
else:
vals.argmax(axis=1, out=out_argmax)
out[:] = vals[np.arange(self.num_states), out_argmax]
_njit_s_wise_max_2d(vals, out, out_argmax)
return out

self.s_wise_max = s_wise_max
Expand Down Expand Up @@ -1078,3 +1074,37 @@ def backward_induction(ddp, T, v_term=None):
ddp.bellman_operator(vs[t, :], Tv=vs[t-1, :], sigma=sigmas[t-1, :])

return vs, sigmas


# 1D s_wise_max using Numba njit
@njit(cache=True)
def _njit_s_wise_max_1d(a_indices, a_indptr, vals, out_max):
_s_wise_max(a_indices, a_indptr, vals, out_max=out_max)

@njit(cache=True)
def _njit_s_wise_max_argmax_1d(a_indices, a_indptr, vals, out_max, out_argmax):
_s_wise_max_argmax(a_indices, a_indptr, vals, out_max=out_max, out_argmax=out_argmax)


# 2D s_wise_max using Numba njit
@njit(cache=True)
def _njit_s_wise_max_2d(vals, out, out_argmax=None):
n, m = vals.shape
if out_argmax is None:
for i in range(n):
maxval = vals[i, 0]
for j in range(1, m):
if vals[i, j] > maxval:
maxval = vals[i, j]
out[i] = maxval
else:
for i in range(n):
argmax = 0
maxval = vals[i, 0]
for j in range(1, m):
if vals[i, j] > maxval:
maxval = vals[i, j]
argmax = j
out_argmax[i] = argmax
out[i] = maxval
return out