Skip to content

Commit b7b619a

Browse files
authored
feat(train): log norm and histograms (borisdayma#143)
* feat(train): log norm and histograms
* feat: update shampoo
1 parent 7939874 commit b7b619a

File tree

3 files changed

+159
-28
lines changed

3 files changed

+159
-28
lines changed

tools/train/scalable_shampoo/distributed_shampoo.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,11 @@ def sharded_init_fn(params):
832832
if not _skip_preconditioning(param):
833833
sizes = [s[0] for s in shapes]
834834
shapes = preconditioner.shapes_for_preconditioners()
835-
statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
836-
preconditioners = [jnp.eye(max_size) for s in shapes]
835+
statistics = [
836+
matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
837+
for s in shapes
838+
]
839+
preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
837840
padded_statistics.extend(statistics)
838841
padded_preconditioners.extend(preconditioners)
839842
exponent = (
@@ -1244,8 +1247,10 @@ def _init(param):
12441247
preconditioners = []
12451248
if not _skip_preconditioning(param):
12461249
shapes = preconditioner.shapes_for_preconditioners()
1247-
statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
1248-
preconditioners = [jnp.eye(s[0]) for s in shapes]
1250+
statistics = [
1251+
matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
1252+
]
1253+
preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
12491254

12501255
diagonal_statistics = []
12511256
if _graft_type_has_diagonal_statistics():

tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
1717

1818
import functools
19-
from typing import List, Union
19+
from typing import Any, List, Sequence, Union
2020

2121
import jax
2222
import jax.numpy as jnp
23+
import numpy as np
2324
from flax import struct
2425
from jax import lax
2526

@@ -41,6 +42,7 @@ class SlicedSymmetricMatrix:
4142
def product_with_transpose(
4243
mat1,
4344
mat2,
45+
axes,
4446
precision=lax.Precision.DEFAULT,
4547
):
4648
"""Returns mat1 * mat2^T for two matrices (possibly batched).
@@ -50,50 +52,85 @@ def product_with_transpose(
5052
Args:
5153
mat1: First matrix.
5254
mat2: Second matrix.
55+
axes: The axes over which to apply the product.
5356
precision: JAX precision to use for the multiplication.
5457
"""
55-
return jnp.einsum("...ij,...kj->...ik", mat1, mat2, precision=precision)
58+
return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
5659

5760

58-
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
61+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
5962
def sliced_transposed_product(
6063
mat,
6164
block_size,
65+
axes=(-1,),
6266
precision=lax.Precision.DEFAULT,
6367
):
64-
"""Returns the blocked slices representing a symmetric matrix mat*mat^T.
68+
"""Returns the blocked slices representing a symmetric contraction.
69+
70+
Specifically, the output is a contraction of the input mat with itself, in the
71+
specified axes.
6572
6673
Args:
67-
mat: The matrix for which we will compute mat*mat^T. It does not need to be
68-
square, and may be batched.
74+
mat: The matrix for which we will compute a contraction with itself.
6975
block_size: The size of row blocks to compute.
76+
axes: Axes to use for the contraction.
7077
precision: The precision to use in each computation.
7178
7279
Raises:
7380
ValueError: Raised when the specified block size does not evenly divide
7481
the number of rows of the input mat.
7582
"""
76-
num_rows = mat.shape[-2]
83+
rank = len(mat.shape)
84+
85+
def _make_axis_positive(ax):
86+
assert -rank <= ax < rank
87+
return ax + rank if ax < 0 else ax
88+
89+
positive_axes = [_make_axis_positive(ax) for ax in axes]
90+
assert len(positive_axes) == len(axes)
91+
remaining_axes = set(range(rank)) - set(positive_axes)
92+
assert len(remaining_axes) == 1
93+
remaining_ax = remaining_axes.pop()
94+
95+
num_rows = mat.shape[remaining_ax]
7796
if num_rows % block_size != 0:
7897
raise ValueError(
7998
"The row dimension must be divisible by block_size. "
8099
f"Instead got row dimension={num_rows} and block_size={block_size}."
81100
)
82-
block_rows = [
83-
product_with_transpose(
84-
mat[Ellipsis, i * block_size : (i + 1) * block_size, :],
85-
mat[Ellipsis, 0 : (i + 1) * block_size, :],
86-
precision,
101+
102+
block_rows = []
103+
for i in range(num_rows // block_size):
104+
start_indices = [0] * rank
105+
start_indices[remaining_ax] = i * block_size
106+
107+
slice_sizes = list(mat.shape)
108+
slice_sizes[remaining_ax] = block_size
109+
110+
slice_sizes_full = list(mat.shape)
111+
slice_sizes_full[remaining_ax] = (i + 1) * block_size
112+
113+
block_rows.append(
114+
product_with_transpose(
115+
lax.dynamic_slice(
116+
mat, start_indices=start_indices, slice_sizes=slice_sizes
117+
),
118+
lax.dynamic_slice(
119+
mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
120+
),
121+
axes=(axes, axes),
122+
precision=precision,
123+
)
87124
)
88-
for i in range(num_rows // block_size)
89-
]
125+
90126
return SlicedSymmetricMatrix(block_rows=block_rows)
91127

92128

93-
@functools.partial(jax.jit, static_argnames=("block_size", "precision"))
129+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
94130
def sliced_transposed_product_concat(
95131
mat,
96132
block_size,
133+
axes=(-1,),
97134
precision=lax.Precision.DEFAULT,
98135
):
99136
"""Returns the concatenated slices representing mat*mat^T.
@@ -102,14 +139,15 @@ def sliced_transposed_product_concat(
102139
mat: The matrix for which we will compute mat*mat^T. It does not need to be
103140
square, and may be batched.
104141
block_size: The size of row blocks to compute.
142+
axes: Axes to use for the contraction.
105143
precision: The precision to use in each computation.
106144
107145
Raises:
108146
ValueError: Raised when the specified block size does not evenly divide
109147
the number of rows of the input mat.
110148
"""
111149
sliced_symmetric_matrix = sliced_transposed_product(
112-
mat=mat, block_size=block_size, precision=precision
150+
mat=mat, block_size=block_size, axes=axes, precision=precision
113151
)
114152
return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
115153

@@ -179,12 +217,13 @@ def materialize_matrix_from_concat(
179217
return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
180218

181219

182-
@functools.partial(jax.jit, static_argnames=("alpha", "beta"))
220+
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
183221
def update_sliced_rows(
184222
symmetric_matrix,
185223
mat,
186224
alpha,
187225
beta,
226+
axes=(-1,),
188227
):
189228
"""Implements the blocked equivalent of SYRK.
190229
@@ -197,15 +236,45 @@ def update_sliced_rows(
197236
should match that of symmetric_matrix.
198237
alpha: The weight for the update.
199238
beta: The weight for the original symmetric matrix.
239+
axes: Axes to use for the contraction of the update.
200240
201241
Returns:
202242
The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
203243
"""
204244
block_size = symmetric_matrix.block_rows[0].shape[-2]
205-
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
245+
sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
206246
return SlicedSymmetricMatrix(
207247
block_rows=[
208248
update * alpha + row * beta
209249
for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
210250
]
211251
)
252+
253+
254+
def find_num_blocks(block_rows_concat):
255+
"""Returns the number of (row) blocks representing the concatenated matrix.
256+
257+
For example, an input with dimensions [256, 2560] represents 10 square blocks,
258+
which matches 4 lower-triangular block rows (1+2+3+4). So this function will
259+
return 4.
260+
261+
Use ordinary numpy functions here so that the returned value is static.
262+
263+
Args:
264+
block_rows_concat: The concatenated block array.
265+
266+
Raises:
267+
ValueError: When the dimensions of the matrix do not correspond to a lower
268+
triangular block representation.
269+
"""
270+
# Compute the number of square blocks used to represent the matrix.
271+
total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
272+
# Determine the number of block rows by inverting y = x*(x+1)/2.
273+
num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
274+
if num_blocks * (num_blocks + 1) / 2 != total_blocks:
275+
raise ValueError(
276+
"Could not determine an appropriate number of blocks for "
277+
"the concatenated matrix."
278+
)
279+
else:
280+
return num_blocks

tools/train/train.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import transformers
3838
import wandb
3939
from datasets import Dataset
40-
from flax.core.frozen_dict import FrozenDict, freeze
40+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
4141
from flax.serialization import from_bytes, to_bytes
4242
from flax.training import train_state
4343
from flax.training.common_utils import onehot
@@ -405,6 +405,12 @@ class TrainingArguments:
405405
default=False,
406406
metadata={"help": "Log model to wandb at `save_steps` frequency."},
407407
)
408+
log_histograms: bool = field(
409+
default=False,
410+
metadata={
411+
"help": "Log parameters and gradients histograms. Slows down training."
412+
},
413+
)
408414

409415
seed_model: int = field(
410416
default=42,
@@ -514,10 +520,22 @@ def update_state_metrics(self, state):
514520

515521
def log(self, metrics, prefix=None):
516522
if jax.process_index() == 0:
517-
log_metrics = {
518-
f"{prefix}/{k}" if prefix is not None else k: v
519-
for k, v in metrics.items()
520-
}
523+
log_metrics = {}
524+
for k, v in metrics.items():
525+
if prefix is not None:
526+
k = f"{prefix}/{k}"
527+
if "_norm" in k:
528+
log_metrics[f"{k}/"] = unfreeze(v)
529+
elif "_hist" in k:
530+
v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
531+
v = jax.tree_map(
532+
lambda x: wandb.Histogram(np_histogram=x),
533+
v,
534+
is_leaf=lambda x: isinstance(x, tuple),
535+
)
536+
log_metrics[f"{k}/"] = v
537+
else:
538+
log_metrics[k] = v
521539
wandb.log({**log_metrics, **self.state_dict})
522540

523541

@@ -1024,20 +1042,59 @@ def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
10241042
lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
10251043
)
10261044

1027-
# update state
10281045
grads = with_sharding_constraint(grads, param_spec)
1046+
1047+
# update state
10291048
state = state.apply_gradients(
10301049
grads=grads,
10311050
dropout_rng=dropout_rng,
10321051
train_time=state.train_time + delta_time,
10331052
train_samples=state.train_samples + batch_size_per_step,
10341053
)
10351054

1055+
# get norm and histogram of grads and params
1056+
zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
1057+
1058+
def maybe_fn(fn, val, zeros):
1059+
"""Call fn only if it is a logging step"""
1060+
return jax.lax.cond(
1061+
state.step % training_args.logging_steps == 0,
1062+
fn,
1063+
lambda _: zeros,
1064+
val,
1065+
)
1066+
1067+
def norm(val):
1068+
return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
1069+
1070+
gradients_norm = maybe_fn(norm, grads, zeros_norm)
1071+
params_norm = maybe_fn(norm, state.params, zeros_norm)
1072+
10361073
metrics = {
10371074
"loss": loss,
10381075
"learning_rate": learning_rate_fn(state.step),
1076+
"gradients_norm": gradients_norm,
1077+
"params_norm": params_norm,
10391078
}
10401079

1080+
if training_args.log_histograms:
1081+
zeros_hist = jax.tree_map(
1082+
lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
1083+
)
1084+
1085+
def histogram(val):
1086+
return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
1087+
1088+
gradients_hist = maybe_fn(histogram, grads, zeros_hist)
1089+
params_hist = maybe_fn(histogram, state.params, zeros_hist)
1090+
1091+
metrics.update(
1092+
{
1093+
"params_hist": params_hist,
1094+
"gradients_hist": gradients_hist,
1095+
}
1096+
)
1097+
10411098
return state, metrics
10421099

10431100
# Define eval fn

0 commit comments

Comments (0)