1717
1818torch = import_optional ("torch" )
1919ops_torch = import_optional ("pylibcugraphops.pytorch" )
20-
21-
22- class BaseConv (torch .nn .Module ):
23- r"""An abstract base class for cugraph-ops nn module."""
24-
25- def __init__ (self ):
26- super ().__init__ ()
27- self ._cached_offsets_fg = None
28-
29- def reset_parameters (self ):
30- r"""Resets all learnable parameters of the module."""
31- raise NotImplementedError
32-
33- def forward (self , * args ):
34- r"""Runs the forward pass of the module."""
35- raise NotImplementedError
36-
37- def pad_offsets (self , offsets : torch .Tensor , size : int ) -> torch .Tensor :
38- r"""Pad zero-in-degree nodes to the end of offsets to reach size. This
39- is used to augment offset tensors from DGL blocks (MFGs) to be
40- compatible with cugraph-ops full-graph primitives."""
41- if self ._cached_offsets_fg is None :
42- self ._cached_offsets_fg = torch .empty (
43- size , dtype = offsets .dtype , device = offsets .device
44- )
45- elif self ._cached_offsets_fg .numel () < size :
46- self ._cached_offsets_fg .resize_ (size )
47-
48- self ._cached_offsets_fg [: offsets .numel ()] = offsets
49- self ._cached_offsets_fg [offsets .numel () : size ] = offsets [- 1 ]
50-
51- return self ._cached_offsets_fg [:size ]
20+ dgl = import_optional ("dgl" )
5221
5322
5423def compress_ids (ids : torch .Tensor , size : int ) -> torch .Tensor :
@@ -63,8 +32,9 @@ def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:
6332
6433
6534class SparseGraph (object ):
66- r"""A god-class to store different sparse formats needed by cugraph-ops
67- and facilitate sparse format conversions.
35+ r"""A class to create and store different sparse formats needed by
36+ cugraph-ops. It always creates a CSC representation and can provide COO- or
37+ CSR-format if needed.
6838
6939 Parameters
7040 ----------
@@ -89,25 +59,43 @@ class SparseGraph(object):
8959 consists of the sources between `src_indices[cdst_indices[k]]` and
9060 `src_indices[cdst_indices[k+1]]`.
9161
92- dst_ids_is_sorted: bool
93- Whether `dst_ids` has been sorted in an ascending order. When sorted,
94- creating CSC layout is much faster.
62+ values: torch.Tensor, optional
63+ Values on the edges.
64+
65+ is_sorted: bool
66+ Whether the COO inputs (src_ids, dst_ids, values) have been sorted by
67+ `dst_ids` in an ascending order. CSC layout creation is much faster
68+ when sorted.
9569
9670 formats: str or tuple of str, optional
97- The desired sparse formats to create for the graph.
71+ The desired sparse formats to create for the graph. The formats tuple
72+ must include "csc". Default: "csc".
9873
9974 reduce_memory: bool, optional
10075 When set, the tensors are not required by the desired formats will be
101- set to `None`.
76+ set to `None`. Default: True.
10277
10378 Notes
10479 -----
10580 For MFGs (sampled graphs), the node ids must have been renumbered.
10681 """
10782
108- supported_formats = {"coo" : ("src_ids" , "dst_ids" ), "csc" : ("cdst_ids" , "src_ids" )}
109-
110- all_tensors = set (["src_ids" , "dst_ids" , "csrc_ids" , "cdst_ids" ])
83+ supported_formats = {
84+ "coo" : ("_src_ids" , "_dst_ids" ),
85+ "csc" : ("_cdst_ids" , "_src_ids" ),
86+ "csr" : ("_csrc_ids" , "_dst_ids" , "_perm_csc2csr" ),
87+ }
88+
89+ all_tensors = set (
90+ [
91+ "_src_ids" ,
92+ "_dst_ids" ,
93+ "_csrc_ids" ,
94+ "_cdst_ids" ,
95+ "_perm_coo2csc" ,
96+ "_perm_csc2csr" ,
97+ ]
98+ )
11199
112100 def __init__ (
113101 self ,
@@ -116,15 +104,19 @@ def __init__(
116104 dst_ids : Optional [torch .Tensor ] = None ,
117105 csrc_ids : Optional [torch .Tensor ] = None ,
118106 cdst_ids : Optional [torch .Tensor ] = None ,
119- dst_ids_is_sorted : bool = False ,
120- formats : Optional [Union [str , Tuple [str ]]] = None ,
107+ values : Optional [torch .Tensor ] = None ,
108+ is_sorted : bool = False ,
109+ formats : Union [str , Tuple [str ]] = "csc" ,
121110 reduce_memory : bool = True ,
122111 ):
123112 self ._num_src_nodes , self ._num_dst_nodes = size
124- self ._dst_ids_is_sorted = dst_ids_is_sorted
113+ self ._is_sorted = is_sorted
125114
126115 if dst_ids is None and cdst_ids is None :
127- raise ValueError ("One of 'dst_ids' and 'cdst_ids' must be given." )
116+ raise ValueError (
117+ "One of 'dst_ids' and 'cdst_ids' must be given "
118+ "to create a SparseGraph."
119+ )
128120
129121 if src_ids is not None :
130122 src_ids = src_ids .contiguous ()
@@ -148,30 +140,47 @@ def __init__(
148140 )
149141 cdst_ids = cdst_ids .contiguous ()
150142
143+ if values is not None :
144+ values = values .contiguous ()
145+
151146 self ._src_ids = src_ids
152147 self ._dst_ids = dst_ids
153148 self ._csrc_ids = csrc_ids
154149 self ._cdst_ids = cdst_ids
155- self ._perm = None
150+ self ._values = values
151+ self ._perm_coo2csc = None
152+ self ._perm_csc2csr = None
156153
157154 if isinstance (formats , str ):
158155 formats = (formats ,)
159-
160- if formats is not None :
161- for format_ in formats :
162- assert format_ in SparseGraph .supported_formats
163- self .__getattribute__ (f"_create_{ format_ } " )()
164156 self ._formats = formats
165157
158+ if "csc" not in formats :
159+ raise ValueError (
160+ f"{ self .__class__ .__name__ } .formats must contain "
161+ f"'csc', but got { formats } ."
162+ )
163+
164+ # always create csc first
165+ if self ._cdst_ids is None :
166+ if not self ._is_sorted :
167+ self ._dst_ids , self ._perm_coo2csc = torch .sort (self ._dst_ids )
168+ self ._src_ids = self ._src_ids [self ._perm_coo2csc ]
169+ if self ._values is not None :
170+ self ._values = self ._values [self ._perm_coo2csc ]
171+ self ._cdst_ids = compress_ids (self ._dst_ids , self ._num_dst_nodes )
172+
173+ for format_ in formats :
174+ assert format_ in SparseGraph .supported_formats
175+ self .__getattribute__ (f"{ format_ } " )()
176+
166177 self ._reduce_memory = reduce_memory
167178 if reduce_memory :
168179 self .reduce_memory ()
169180
170181 def reduce_memory (self ):
171182 """Remove the tensors that are not necessary to create the desired sparse
172183 formats to reduce memory footprint."""
173-
174- self ._perm = None
175184 if self ._formats is None :
176185 return
177186
@@ -181,38 +190,157 @@ def reduce_memory(self):
181190 for t in SparseGraph .all_tensors .difference (set (tensors_needed )):
182191 self .__dict__ [t ] = None
183192
184- def _create_coo (self ):
193+ def src_ids (self ) -> torch .Tensor :
194+ return self ._src_ids
195+
196+ def cdst_ids (self ) -> torch .Tensor :
197+ return self ._cdst_ids
198+
199+ def dst_ids (self ) -> torch .Tensor :
185200 if self ._dst_ids is None :
186201 self ._dst_ids = decompress_ids (self ._cdst_ids )
202+ return self ._dst_ids
187203
188- def _create_csc (self ):
189- if self ._cdst_ids is None :
190- if not self ._dst_ids_is_sorted :
191- self ._dst_ids , self ._perm = torch .sort (self ._dst_ids )
192- self ._src_ids = self ._src_ids [self ._perm ]
193- self ._cdst_ids = compress_ids (self ._dst_ids , self ._num_dst_nodes )
204+ def csrc_ids (self ) -> torch .Tensor :
205+ if self ._csrc_ids is None :
206+ src_ids , self ._perm_csc2csr = torch .sort (self ._src_ids )
207+ self ._csrc_ids = compress_ids (src_ids , self ._num_src_nodes )
208+ return self ._csrc_ids
194209
195210 def num_src_nodes (self ):
196211 return self ._num_src_nodes
197212
198213 def num_dst_nodes (self ):
199214 return self ._num_dst_nodes
200215
216+ def values (self ):
217+ return self ._values
218+
201219 def formats (self ):
202220 return self ._formats
203221
204- def coo (self ) -> Tuple [torch .Tensor , torch .Tensor ]:
222+ def coo (self ) -> Tuple [torch .Tensor , torch .Tensor , Optional [ torch . Tensor ] ]:
205223 if "coo" not in self .formats ():
206224 raise RuntimeError (
207225 "The SparseGraph did not create a COO layout. "
208- "Set 'formats' to include 'coo' when creating the graph."
226+ "Set 'formats' list to include 'coo' when creating the graph."
209227 )
210- return ( self ._src_ids , self ._dst_ids )
228+ return self .src_ids () , self .dst_ids (), self . _values
211229
212- def csc (self ) -> Tuple [torch .Tensor , torch .Tensor ]:
230+ def csc (self ) -> Tuple [torch .Tensor , torch .Tensor , Optional [ torch . Tensor ] ]:
213231 if "csc" not in self .formats ():
214232 raise RuntimeError (
215233 "The SparseGraph did not create a CSC layout. "
216- "Set 'formats' to include 'csc' when creating the graph."
234+ "Set 'formats' list to include 'csc' when creating the graph."
235+ )
236+ return self .cdst_ids (), self .src_ids (), self ._values
237+
238+ def csr (self ) -> Tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
239+ if "csr" not in self .formats ():
240+ raise RuntimeError (
241+ "The SparseGraph did not create a CSR layout. "
242+ "Set 'formats' list to include 'csr' when creating the graph."
243+ )
244+ csrc_ids = self .csrc_ids ()
245+ dst_ids = self .dst_ids ()[self ._perm_csc2csr ]
246+ value = self ._values
247+ if value is not None :
248+ value = value [self ._perm_csc2csr ]
249+ return csrc_ids , dst_ids , value
250+
251+
252+ class BaseConv (torch .nn .Module ):
253+ r"""An abstract base class for cugraph-ops nn module."""
254+
255+ def __init__ (self ):
256+ super ().__init__ ()
257+
258+ def reset_parameters (self ):
259+ r"""Resets all learnable parameters of the module."""
260+ raise NotImplementedError
261+
262+ def forward (self , * args ):
263+ r"""Runs the forward pass of the module."""
264+ raise NotImplementedError
265+
266+ def get_cugraph_ops_CSC (
267+ self ,
268+ g : Union [SparseGraph , dgl .DGLHeteroGraph ],
269+ is_bipartite : bool = False ,
270+ max_in_degree : Optional [int ] = None ,
271+ ) -> ops_torch .CSC :
272+ """Create CSC structure needed by cugraph-ops."""
273+
274+ if not isinstance (g , (SparseGraph , dgl .DGLHeteroGraph )):
275+ raise TypeError (
276+ f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
277+ f"'dgl.DGLHeteroGraph', but got '{ type (g )} '."
217278 )
218- return (self ._cdst_ids , self ._src_ids )
279+
280+ # TODO: max_in_degree should default to None in pylibcugraphops
281+ if max_in_degree is None :
282+ max_in_degree = - 1
283+
284+ if isinstance (g , SparseGraph ):
285+ offsets , indices , _ = g .csc ()
286+ else :
287+ offsets , indices , _ = g .adj_tensors ("csc" )
288+
289+ graph = ops_torch .CSC (
290+ offsets = offsets ,
291+ indices = indices ,
292+ num_src_nodes = g .num_src_nodes (),
293+ dst_max_in_degree = max_in_degree ,
294+ is_bipartite = is_bipartite ,
295+ )
296+
297+ return graph
298+
299+ def get_cugraph_ops_HeteroCSC (
300+ self ,
301+ g : Union [SparseGraph , dgl .DGLHeteroGraph ],
302+ num_edge_types : int ,
303+ etypes : Optional [torch .Tensor ] = None ,
304+ is_bipartite : bool = False ,
305+ max_in_degree : Optional [int ] = None ,
306+ ) -> ops_torch .HeteroCSC :
307+ """Create HeteroCSC structure needed by cugraph-ops."""
308+
309+ if not isinstance (g , (SparseGraph , dgl .DGLHeteroGraph )):
310+ raise TypeError (
311+ f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
312+ f"'dgl.DGLHeteroGraph', but got '{ type (g )} '."
313+ )
314+
315+ # TODO: max_in_degree should default to None in pylibcugraphops
316+ if max_in_degree is None :
317+ max_in_degree = - 1
318+
319+ if isinstance (g , SparseGraph ):
320+ offsets , indices , etypes = g .csc ()
321+ if etypes is None :
322+ raise ValueError (
323+ "SparseGraph must have 'values' to create HeteroCSC. "
324+ "Pass in edge types as 'values' when creating the SparseGraph."
325+ )
326+ etypes = etypes .int ()
327+ else :
328+ if etypes is None :
329+ raise ValueError (
330+ "'etypes' is required when creating HeteroCSC "
331+ "from dgl.DGLHeteroGraph."
332+ )
333+ offsets , indices , perm = g .adj_tensors ("csc" )
334+ etypes = etypes [perm ].int ()
335+
336+ graph = ops_torch .HeteroCSC (
337+ offsets = offsets ,
338+ indices = indices ,
339+ edge_types = etypes ,
340+ num_src_nodes = g .num_src_nodes (),
341+ num_edge_types = num_edge_types ,
342+ dst_max_in_degree = max_in_degree ,
343+ is_bipartite = is_bipartite ,
344+ )
345+
346+ return graph
0 commit comments