### Creating the replay buffer and corresponding observer

The TF-Agents library provides replay buffers implemented in pure Python (submodules beginning with `py_`) or in TF (submodules beginning with `tf_`).
We will use the `TFUniformReplayBuffer` because it provides a high-performance implementation of a replay buffer with uniform sampling.
Its parameters are discussed below the code.


```python
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=1000000
)
```

* `data_spec`: A `Trajectory` with a description of the various data types reported by the agent.
* `batch_size`: The number of trajectories that will be added at each step. For our purposes this will be 1, but a *batched environment* can group a series of steps into one batch.
* `max_length`: The maximum size of the buffer. This number was again taken from the 2015 DQN paper and is quite large.
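The bounded-storage and uniform-sampling behavior these parameters describe can be sketched with a toy pure-Python buffer. This is only an illustration of the concept, not the TF-Agents implementation; `ToyUniformReplayBuffer` is a made-up name.


```python
# Toy stand-in for a uniform replay buffer (NOT the TF-Agents implementation):
# a bounded store that evicts the oldest trajectories and samples uniformly.
import random
from collections import deque

class ToyUniformReplayBuffer:
    def __init__(self, max_length):
        # deque drops the oldest item once max_length is reached
        self.storage = deque(maxlen=max_length)

    def add_batch(self, trajectory):
        self.storage.append(trajectory)

    def sample(self, batch_size):
        # every stored trajectory is equally likely to be drawn
        return random.sample(list(self.storage), batch_size)

toy_buffer = ToyUniformReplayBuffer(max_length=3)
for step in range(5):              # add 5 trajectories to a buffer of size 3
    toy_buffer.add_batch(("obs", step))
print(len(toy_buffer.storage))     # 3: the two oldest entries were evicted
print(toy_buffer.storage[0])       # ('obs', 2)
```

With `max_length=1000000` and one trajectory added per step, the real buffer behaves the same way at a much larger scale: old experiences are eventually overwritten by new ones.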

Now we can create the observer that will write the trajectories to the replay buffer.
An observer is just a function (or callable object) that takes a trajectory argument.
Thus, we can use the replay buffer's `add_batch()` method as the observer.


```python
replay_buffer_observer = replay_buffer.add_batch
```
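To see why a bound method like `add_batch` qualifies, here is a small sketch (illustrative names only, not the TF-Agents API) showing that any one-argument callable can act as an observer:


```python
# Illustrative only: an observer is any callable taking one trajectory
# argument, so a plain function and a bound method work interchangeably.
seen_by_function = []

def logging_observer(trajectory):
    seen_by_function.append(trajectory)

class ToyBuffer:
    def __init__(self):
        self.data = []

    def add_batch(self, trajectory):
        self.data.append(trajectory)

toy_buffer = ToyBuffer()
observers = [logging_observer, toy_buffer.add_batch]

fake_trajectory = ("obs", "action", "reward")  # stand-in for a real Trajectory
for observer in observers:
    observer(fake_trajectory)   # broadcast the same trajectory to every observer

print(seen_by_function == toy_buffer.data)  # True
```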

### Creating training metrics

The TF-Agents library provides several RL metrics in the `tf_agents.metrics` module; again, some are implemented in pure Python and others using TF.
We will use a few of them to count the number of episodes, the number of steps taken, the average return per episode, and the average episode length.


```python
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]
```

We can get the value of each of these metrics using its `result()` method.
Alternatively, we can log all of the metrics by calling `log_metrics(train_metrics)`.


```python
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.basicConfig(level=logging.INFO)
log_metrics(train_metrics)
```
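As a rough illustration of this `result()` interface, here is a toy stand-in (not the `tf_metrics` implementation): a streaming metric accumulates observations as they arrive and reports a summary on demand.


```python
# Toy stand-in for a streaming metric (NOT tf_metrics): observations are
# accumulated as they arrive, and result() reports the running average.
class ToyAverageReturnMetric:
    def __init__(self):
        self.returns = []

    def __call__(self, episode_return):
        self.returns.append(episode_return)

    def result(self):
        if not self.returns:
            return 0.0
        return sum(self.returns) / len(self.returns)

toy_metric = ToyAverageReturnMetric()
for episode_return in [10.0, 20.0, 30.0]:
    toy_metric(episode_return)

print(toy_metric.result())  # 20.0
```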

### Creating the collect driver

The driver is the object that explores an environment using a policy, collects experiences, and broadcasts them to some observers.
At each step, the following happens:

1. The driver passes the current time step to the collect policy, which uses the time step to choose an action and returns an *action step* object containing the action to take.
2. The driver passes the action to the environment, which returns the next time step.
3. The driver creates a trajectory object to represent this transition and broadcasts it to all the observers.

There are two main driver classes available through TF-Agents: `DynamicStepDriver` and `DynamicEpisodeDriver`.
The former collects experiences for a given number of steps, while the latter collects experiences for a given number of episodes.
We will use the `DynamicStepDriver` to collect experiences for 4 steps per training iteration (as was done in the 2015 paper).
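The three-step loop above can be sketched with toy stand-ins for the environment, policy, and observers. None of these names are TF-Agents APIs; this is only the conceptual shape of what a step driver does.


```python
# Conceptual sketch of a step driver (illustrative only, not TF-Agents code).
def toy_policy(time_step):
    # 1. the policy maps the current time step to an action
    return 0 if time_step["obs"] % 2 else 1

def toy_env_step(action):
    # 2. the environment applies the action and returns the next time step
    return {"obs": action, "reward": 1.0}

trajectories = []
observers = [trajectories.append]

time_step = {"obs": 0, "reward": 0.0}
for _ in range(4):  # 4 steps, matching the text above
    action = toy_policy(time_step)
    next_time_step = toy_env_step(action)
    # 3. package the transition as a trajectory and broadcast it
    trajectory = (time_step, action, next_time_step)
    for observer in observers:
        observer(trajectory)
    time_step = next_time_step

print(len(trajectories))  # 4
```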


```python
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=update_period
)
```

We could now train the agent by calling the driver's `run()` method, but it's best to "warm up" the replay buffer using a random policy first.
For this, we will use the `RandomTFPolicy` class and create a second driver to run 20,000 steps.


```python
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(
    tf_env.time_step_spec(),
    tf_env.action_spec()
)

init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=20000
)

final_time_step, final_policy_state = init_driver.run()
```


```python
final_time_step
```




    TimeStep(step_type=<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>, reward=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, discount=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, observation=<tf.Tensor: shape=(1, 84, 84, 4), dtype=uint8, numpy=
    array([[[[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            ...,

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             ...,
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]]], dtype=uint8)>)




```python
final_policy_state
```




    ()

The final policy state is an empty tuple because the random collect policy is stateless.

```python
|