"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""

import functools
from typing import Any, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import lax
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Returns the contraction of mat1 with mat2 over the given axes.

    For 2-D inputs with axes=((1,), (1,)), this computes mat1 @ mat2^T; the
    inputs may also be higher-rank, in which case `axes` selects which
    dimensions are contracted.

    Args:
      mat1: First matrix.
      mat2: Second matrix.
      axes: The axes over which to apply the product.
      precision: JAX precision to use for the multiplication.
    """
    contracted = jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
    return contracted
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the blocked slices representing a symmetric contraction.

    Specifically, the output is a contraction of the input mat with itself, in
    the specified axes.

    Args:
      mat: The matrix for which we will compute a contraction with itself.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Raises:
      ValueError: Raised when `axes` is out of bounds for the input rank, when
        the contraction does not leave exactly one free axis, or when the
        specified block size does not evenly divide the number of rows of the
        input mat.
    """
    rank = len(mat.shape)

    def _make_axis_positive(ax):
        # `axes` and the input rank are static under jit, so plain Python
        # validation is fine here. Raise (not assert) so the check survives -O.
        if not -rank <= ax < rank:
            raise ValueError(f"Axis {ax} is out of bounds for rank {rank}.")
        return ax + rank if ax < 0 else ax

    positive_axes = [_make_axis_positive(ax) for ax in axes]
    remaining_axes = set(range(rank)) - set(positive_axes)
    # Exactly one free axis must remain: it is the "row" axis we slice into
    # blocks. Duplicate or over-complete `axes` also fail this check.
    if len(remaining_axes) != 1:
        raise ValueError(
            "The contraction must leave exactly one remaining axis. "
            f"Instead got axes={axes} for an input of rank {rank}."
        )
    remaining_ax = remaining_axes.pop()

    num_rows = mat.shape[remaining_ax]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    block_rows = []
    for i in range(num_rows // block_size):
        # Slice out the i-th block of rows along the remaining (free) axis.
        start_indices = [0] * rank
        start_indices[remaining_ax] = i * block_size

        slice_sizes = list(mat.shape)
        slice_sizes[remaining_ax] = block_size

        # The second operand spans all rows up to and including block i, so the
        # product is the i-th lower-triangular block row of mat contracted with
        # itself.
        slice_sizes_full = list(mat.shape)
        slice_sizes_full[remaining_ax] = (i + 1) * block_size

        block_rows.append(
            product_with_transpose(
                lax.dynamic_slice(
                    mat, start_indices=start_indices, slice_sizes=slice_sizes
                ),
                lax.dynamic_slice(
                    mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
                ),
                axes=(axes, axes),
                precision=precision,
            )
        )

    return SlicedSymmetricMatrix(block_rows=block_rows)
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the concatenated slices representing mat*mat^T.

    Args:
      mat: The matrix for which we will compute mat*mat^T. It does not need to
        be square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    # Delegate the blocked contraction, then flatten the block rows into a
    # single array along the last axis.
    sliced = sliced_transposed_product(
        mat=mat,
        block_size=block_size,
        axes=axes,
        precision=precision,
    )
    return jnp.concatenate(sliced.block_rows, axis=-1)
@@ -179,12 +217,13 @@ def materialize_matrix_from_concat(
179217 return materialize_matrix (SlicedSymmetricMatrix (block_rows = block_rows ))
180218
181219
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Implements the blocked equivalent of SYRK.

    Args:
      symmetric_matrix: The symmetric matrix to update, in blocked form.
      mat: The matrix to multiply by its own transpose. Its blocking should
        match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
    """
    # Block size is read off the existing representation (rows per block).
    block_size = symmetric_matrix.block_rows[0].shape[-2]
    update_blocks = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes
    ).block_rows

    weighted_rows = []
    for new_block, old_block in zip(update_blocks, symmetric_matrix.block_rows):
        weighted_rows.append(new_block * alpha + old_block * beta)

    return SlicedSymmetricMatrix(block_rows=weighted_rows)
def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.

    For example, an input with dimensions [256, 2560] represents 10 square
    blocks, which matches 4 lower-triangular block rows (1+2+3+4). So this
    function will return 4.

    Use ordinary numpy functions here so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array.

    Returns:
      The number of block rows, as a static numpy int32.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a lower
        triangular block representation.
    """
    # The row dimension must evenly divide the column dimension for the input
    # to be a concatenation of square blocks at all. Checking this up front
    # lets us use exact integer arithmetic below instead of a fragile float
    # equality comparison.
    if block_rows_concat.shape[-1] % block_rows_concat.shape[-2] != 0:
        raise ValueError(
            "Could not determine an appropriate number of blocks for "
            "the concatenated matrix."
        )
    # Number of square blocks used to represent the matrix.
    total_blocks = block_rows_concat.shape[-1] // block_rows_concat.shape[-2]
    # Determine the number of block rows by inverting y = x*(x+1)/2.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    # Verify (exactly, in integers) that total_blocks is a triangular number.
    if num_blocks * (num_blocks + 1) // 2 != total_blocks:
        raise ValueError(
            "Could not determine an appropriate number of blocks for "
            "the concatenated matrix."
        )
    return num_blocks
0 commit comments