Skip to content

Commit 005b3ea

Browse files
author
Yancey1989
committed
cluster train data
1 parent e3e106d commit 005b3ea

4 files changed

Lines changed: 119 additions & 0 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.vscode/

doc/fluid/api/api_guides/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ API使用指南
77

88
high_low_level_api.md
99
low_level/layers/index.rst
10+
low_level/cluster/index.rst
1011
low_level/executor.rst
1112
low_level/optimizer.rst
1213
low_level/metrics.rst
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
.. _api_guide_cluster_train_data:
2+
3+
####################
4+
分布式训练数据准备
5+
####################
6+
7+
一个数据并行的分布式训练任务通常会含有多个训练节点,每个训练节点负责训练整个数据集种的一部分。所以在
8+
启动分布式训练任务之前需要将训练数据切分成多个小文件,再实现一个多机训练的reader函数根据当前节点的
9+
唯一序号(trainer_id)以及当前训练任务中训练节点的总数(trainers)决定读取哪一部分训练数据。
10+
11+
准备文本格式的分布式训练数据集
12+
------------------------------
13+
14+
训练数据切分
15+
~~~~~~~~~~~~
16+
17+
简单的,对于文本类训练数据来说,我们可以使用 split 命令将训练数据切分成多个小文件,例如:
18+
19+
.. code-block:: bash
20+
$ split -d -a 4 -d -l 100 housing.data cluster/housing.data.
21+
$ find ./cluster
22+
cluster/
23+
cluster/housing.data.0002
24+
cluster/housing.data.0003
25+
cluster/housing.data.0004
26+
cluster/housing.data.0000
27+
cluster/housing.data.0001
28+
cluster/housing.data.0005
29+
30+
读取分布式训练数据集
31+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32+
33+
在数据并行场景下,我们需要将训练数据平均分配给每个训练节点,通常的方法是实现一个函数,使之能够
34+
根据当前任务的训练节点数量以及当前节点的唯一序号决定需要读取哪些文件,例如:
35+
36+
.. code-block:: python
37+
38+
def gen_train_list(file_pattern, trainers, trainer_id):
39+
file_list = glob.glob(file_pattern)
40+
ret_list = []
41+
for idx, f in enumerate(file_list):
42+
if (idx + trainers) % trainers == trainer_id:
43+
ret_list.append(f)
44+
return ret_list
45+
46+
- file_pattern: 训练数据文件目录目录,上述例子可以是 `cluster/housing.data.*`
47+
- trainers: 当前任务的训练节点数。
48+
- trainer_id: 当前训练节点的唯一序号。
49+
50+
准备 RecordIO 格式的分布式训练数据集
51+
-------------------------------------
52+
53+
对于非文本类数据,可以预先将训练数据转换为 RecordIO 格式再进行训练, 并且转换成 RecordIO 格式
54+
的另一个好处是可以提升 IO 效率,从而提升分布式训练任务的运行效率。
55+
56+
57+
生成 RecordIO 格式数据集
58+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
59+
60+
Fluid 提供了 `fluid.recordio_writer.convert_reader_to_recordio_files` API, 可以将训练数据转换成
61+
RecordIO 格式, 样例代码如下
62+
63+
.. code-block:: python
64+
65+
reader = paddle.batch(mnist.train(), batch_size=1)
66+
feeder = fluid.DataFeeder(
67+
feed_list=[ # order is image and label
68+
fluid.layers.data(
69+
name='image', shape=[784]),
70+
fluid.layers.data(
71+
name='label', shape=[1], dtype='int64'),
72+
],
73+
place=fluid.CPUPlace())
74+
fluid.recordio_writer.convert_reader_to_recordio_files(
75+
filename_suffix='./mnist.recordio', batch_per_file=100, reader, feeder)
76+
77+
运行上述代码将会生成以下文件:
78+
79+
.. code-block:: bash
80+
81+
.
82+
\_mnist-00000.recordio
83+
|-mnist-00001.recordio
84+
|-mnist-00002.recordio
85+
|-mnist-00003.recordio
86+
|-mnist-00004.recordio
87+
88+
API Reference 请参考::ref:`api_fluid_recordio_writer_convert_reader_to_recordio_file`
89+
90+
读取 RecordIO 训练数据
91+
~~~~~~~~~~~~~~~~~~~~~~~~
92+
93+
Fluid 种提供了 `fluid.layers.io.open_files` API 来读取 RecordIO 格式的训练数据,在以下样例代码
94+
中复用了上面例子中 `gen_train_list` 函数来决定当前节点应该读取哪一部分训练数据:
95+
96+
.. code-block:: python
97+
98+
trainers = int(os.getenv("PADDLE_TRAINERS"))
99+
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
100+
data_file = fluid.layers.io.open_files(
101+
filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0),
102+
thread_num=1,
103+
shapes=[(-1, 784),(-1, 1)],
104+
lod_levels=[0, 0],
105+
dtypes=["float32", "int32"])
106+
img, label = fluid.layers.io.read_file(data_files)
107+
108+
API Reference 请参考: :ref:`api_fluid_layers_open_files`
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
==========
2+
多机训练
3+
==========
4+
5+
.. toctree::
6+
:maxdepth: 1
7+
8+
cluster_train_data_cn.rst
9+

0 commit comments

Comments
 (0)