Skip to content

Commit b2e85bf

Browse files
authored
Update cugraph-dgl conv layers to use improved graph class (#3849)
This PR: - Removes the usage of the deprecated `StaticCSC` and `SampledCSC` - Support creating CSR and storing edge information in SparseGraph - clean up unit tests - Adds GATv2Conv layer - Adds `pylibcugraphops` as a dependency of `cugraph-dgl` conda package Authors: - Tingyu Wang (https://github.com/tingyu66) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Vibhu Jawa (https://github.com/VibhuJawa) - Brad Rees (https://github.com/BradReesWork) URL: #3849
1 parent 5f76161 commit b2e85bf

File tree

17 files changed

+978
-345
lines changed

17 files changed

+978
-345
lines changed

conda/recipes/cugraph-dgl/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ requirements:
2626
- dgl >=1.1.0.cu*
2727
- numba >=0.57
2828
- numpy >=1.21
29+
- pylibcugraphops ={{ version }}
2930
- python
3031
- pytorch
3132

python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313

1414
from .base import SparseGraph
1515
from .gatconv import GATConv
16+
from .gatv2conv import GATv2Conv
1617
from .relgraphconv import RelGraphConv
1718
from .sageconv import SAGEConv
1819
from .transformerconv import TransformerConv
1920

2021
__all__ = [
2122
"SparseGraph",
2223
"GATConv",
24+
"GATv2Conv",
2325
"RelGraphConv",
2426
"SAGEConv",
2527
"TransformerConv",

python/cugraph-dgl/cugraph_dgl/nn/conv/base.py

Lines changed: 195 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,7 @@
1717

1818
torch = import_optional("torch")
1919
ops_torch = import_optional("pylibcugraphops.pytorch")
20-
21-
22-
class BaseConv(torch.nn.Module):
23-
r"""An abstract base class for cugraph-ops nn module."""
24-
25-
def __init__(self):
26-
super().__init__()
27-
self._cached_offsets_fg = None
28-
29-
def reset_parameters(self):
30-
r"""Resets all learnable parameters of the module."""
31-
raise NotImplementedError
32-
33-
def forward(self, *args):
34-
r"""Runs the forward pass of the module."""
35-
raise NotImplementedError
36-
37-
def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
38-
r"""Pad zero-in-degree nodes to the end of offsets to reach size. This
39-
is used to augment offset tensors from DGL blocks (MFGs) to be
40-
compatible with cugraph-ops full-graph primitives."""
41-
if self._cached_offsets_fg is None:
42-
self._cached_offsets_fg = torch.empty(
43-
size, dtype=offsets.dtype, device=offsets.device
44-
)
45-
elif self._cached_offsets_fg.numel() < size:
46-
self._cached_offsets_fg.resize_(size)
47-
48-
self._cached_offsets_fg[: offsets.numel()] = offsets
49-
self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]
50-
51-
return self._cached_offsets_fg[:size]
20+
dgl = import_optional("dgl")
5221

5322

5423
def compress_ids(ids: torch.Tensor, size: int) -> torch.Tensor:
@@ -63,8 +32,9 @@ def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:
6332

6433

6534
class SparseGraph(object):
66-
r"""A god-class to store different sparse formats needed by cugraph-ops
67-
and facilitate sparse format conversions.
35+
r"""A class to create and store different sparse formats needed by
36+
cugraph-ops. It always creates a CSC representation and can provide COO- or
37+
CSR-format if needed.
6838
6939
Parameters
7040
----------
@@ -89,25 +59,43 @@ class SparseGraph(object):
8959
consists of the sources between `src_indices[cdst_indices[k]]` and
9060
`src_indices[cdst_indices[k+1]]`.
9161
92-
dst_ids_is_sorted: bool
93-
Whether `dst_ids` has been sorted in an ascending order. When sorted,
94-
creating CSC layout is much faster.
62+
values: torch.Tensor, optional
63+
Values on the edges.
64+
65+
is_sorted: bool
66+
Whether the COO inputs (src_ids, dst_ids, values) have been sorted by
67+
`dst_ids` in an ascending order. CSC layout creation is much faster
68+
when sorted.
9569
9670
formats: str or tuple of str, optional
97-
The desired sparse formats to create for the graph.
71+
The desired sparse formats to create for the graph. The formats tuple
72+
must include "csc". Default: "csc".
9873
9974
reduce_memory: bool, optional
10075
When set, the tensors that are not required by the desired formats will be
101-
set to `None`.
76+
set to `None`. Default: True.
10277
10378
Notes
10479
-----
10580
For MFGs (sampled graphs), the node ids must have been renumbered.
10681
"""
10782

108-
supported_formats = {"coo": ("src_ids", "dst_ids"), "csc": ("cdst_ids", "src_ids")}
109-
110-
all_tensors = set(["src_ids", "dst_ids", "csrc_ids", "cdst_ids"])
83+
supported_formats = {
84+
"coo": ("_src_ids", "_dst_ids"),
85+
"csc": ("_cdst_ids", "_src_ids"),
86+
"csr": ("_csrc_ids", "_dst_ids", "_perm_csc2csr"),
87+
}
88+
89+
all_tensors = set(
90+
[
91+
"_src_ids",
92+
"_dst_ids",
93+
"_csrc_ids",
94+
"_cdst_ids",
95+
"_perm_coo2csc",
96+
"_perm_csc2csr",
97+
]
98+
)
11199

112100
def __init__(
113101
self,
@@ -116,15 +104,19 @@ def __init__(
116104
dst_ids: Optional[torch.Tensor] = None,
117105
csrc_ids: Optional[torch.Tensor] = None,
118106
cdst_ids: Optional[torch.Tensor] = None,
119-
dst_ids_is_sorted: bool = False,
120-
formats: Optional[Union[str, Tuple[str]]] = None,
107+
values: Optional[torch.Tensor] = None,
108+
is_sorted: bool = False,
109+
formats: Union[str, Tuple[str]] = "csc",
121110
reduce_memory: bool = True,
122111
):
123112
self._num_src_nodes, self._num_dst_nodes = size
124-
self._dst_ids_is_sorted = dst_ids_is_sorted
113+
self._is_sorted = is_sorted
125114

126115
if dst_ids is None and cdst_ids is None:
127-
raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.")
116+
raise ValueError(
117+
"One of 'dst_ids' and 'cdst_ids' must be given "
118+
"to create a SparseGraph."
119+
)
128120

129121
if src_ids is not None:
130122
src_ids = src_ids.contiguous()
@@ -148,30 +140,47 @@ def __init__(
148140
)
149141
cdst_ids = cdst_ids.contiguous()
150142

143+
if values is not None:
144+
values = values.contiguous()
145+
151146
self._src_ids = src_ids
152147
self._dst_ids = dst_ids
153148
self._csrc_ids = csrc_ids
154149
self._cdst_ids = cdst_ids
155-
self._perm = None
150+
self._values = values
151+
self._perm_coo2csc = None
152+
self._perm_csc2csr = None
156153

157154
if isinstance(formats, str):
158155
formats = (formats,)
159-
160-
if formats is not None:
161-
for format_ in formats:
162-
assert format_ in SparseGraph.supported_formats
163-
self.__getattribute__(f"_create_{format_}")()
164156
self._formats = formats
165157

158+
if "csc" not in formats:
159+
raise ValueError(
160+
f"{self.__class__.__name__}.formats must contain "
161+
f"'csc', but got {formats}."
162+
)
163+
164+
# always create csc first
165+
if self._cdst_ids is None:
166+
if not self._is_sorted:
167+
self._dst_ids, self._perm_coo2csc = torch.sort(self._dst_ids)
168+
self._src_ids = self._src_ids[self._perm_coo2csc]
169+
if self._values is not None:
170+
self._values = self._values[self._perm_coo2csc]
171+
self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)
172+
173+
for format_ in formats:
174+
assert format_ in SparseGraph.supported_formats
175+
self.__getattribute__(f"{format_}")()
176+
166177
self._reduce_memory = reduce_memory
167178
if reduce_memory:
168179
self.reduce_memory()
169180

170181
def reduce_memory(self):
171182
"""Remove the tensors that are not necessary to create the desired sparse
172183
formats to reduce memory footprint."""
173-
174-
self._perm = None
175184
if self._formats is None:
176185
return
177186

@@ -181,38 +190,157 @@ def reduce_memory(self):
181190
for t in SparseGraph.all_tensors.difference(set(tensors_needed)):
182191
self.__dict__[t] = None
183192

184-
def _create_coo(self):
193+
def src_ids(self) -> torch.Tensor:
194+
return self._src_ids
195+
196+
def cdst_ids(self) -> torch.Tensor:
197+
return self._cdst_ids
198+
199+
def dst_ids(self) -> torch.Tensor:
185200
if self._dst_ids is None:
186201
self._dst_ids = decompress_ids(self._cdst_ids)
202+
return self._dst_ids
187203

188-
def _create_csc(self):
189-
if self._cdst_ids is None:
190-
if not self._dst_ids_is_sorted:
191-
self._dst_ids, self._perm = torch.sort(self._dst_ids)
192-
self._src_ids = self._src_ids[self._perm]
193-
self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)
204+
def csrc_ids(self) -> torch.Tensor:
205+
if self._csrc_ids is None:
206+
src_ids, self._perm_csc2csr = torch.sort(self._src_ids)
207+
self._csrc_ids = compress_ids(src_ids, self._num_src_nodes)
208+
return self._csrc_ids
194209

195210
def num_src_nodes(self):
196211
return self._num_src_nodes
197212

198213
def num_dst_nodes(self):
199214
return self._num_dst_nodes
200215

216+
def values(self):
217+
return self._values
218+
201219
def formats(self):
202220
return self._formats
203221

204-
def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:
222+
def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
205223
if "coo" not in self.formats():
206224
raise RuntimeError(
207225
"The SparseGraph did not create a COO layout. "
208-
"Set 'formats' to include 'coo' when creating the graph."
226+
"Set 'formats' list to include 'coo' when creating the graph."
209227
)
210-
return (self._src_ids, self._dst_ids)
228+
return self.src_ids(), self.dst_ids(), self._values
211229

212-
def csc(self) -> Tuple[torch.Tensor, torch.Tensor]:
230+
def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
213231
if "csc" not in self.formats():
214232
raise RuntimeError(
215233
"The SparseGraph did not create a CSC layout. "
216-
"Set 'formats' to include 'csc' when creating the graph."
234+
"Set 'formats' list to include 'csc' when creating the graph."
235+
)
236+
return self.cdst_ids(), self.src_ids(), self._values
237+
238+
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
239+
if "csr" not in self.formats():
240+
raise RuntimeError(
241+
"The SparseGraph did not create a CSR layout. "
242+
"Set 'formats' list to include 'csr' when creating the graph."
243+
)
244+
csrc_ids = self.csrc_ids()
245+
dst_ids = self.dst_ids()[self._perm_csc2csr]
246+
value = self._values
247+
if value is not None:
248+
value = value[self._perm_csc2csr]
249+
return csrc_ids, dst_ids, value
250+
251+
252+
class BaseConv(torch.nn.Module):
253+
r"""An abstract base class for cugraph-ops nn module."""
254+
255+
def __init__(self):
256+
super().__init__()
257+
258+
def reset_parameters(self):
259+
r"""Resets all learnable parameters of the module."""
260+
raise NotImplementedError
261+
262+
def forward(self, *args):
263+
r"""Runs the forward pass of the module."""
264+
raise NotImplementedError
265+
266+
def get_cugraph_ops_CSC(
267+
self,
268+
g: Union[SparseGraph, dgl.DGLHeteroGraph],
269+
is_bipartite: bool = False,
270+
max_in_degree: Optional[int] = None,
271+
) -> ops_torch.CSC:
272+
"""Create CSC structure needed by cugraph-ops."""
273+
274+
if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)):
275+
raise TypeError(
276+
f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
277+
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
217278
)
218-
return (self._cdst_ids, self._src_ids)
279+
280+
# TODO: max_in_degree should default to None in pylibcugraphops
281+
if max_in_degree is None:
282+
max_in_degree = -1
283+
284+
if isinstance(g, SparseGraph):
285+
offsets, indices, _ = g.csc()
286+
else:
287+
offsets, indices, _ = g.adj_tensors("csc")
288+
289+
graph = ops_torch.CSC(
290+
offsets=offsets,
291+
indices=indices,
292+
num_src_nodes=g.num_src_nodes(),
293+
dst_max_in_degree=max_in_degree,
294+
is_bipartite=is_bipartite,
295+
)
296+
297+
return graph
298+
299+
def get_cugraph_ops_HeteroCSC(
300+
self,
301+
g: Union[SparseGraph, dgl.DGLHeteroGraph],
302+
num_edge_types: int,
303+
etypes: Optional[torch.Tensor] = None,
304+
is_bipartite: bool = False,
305+
max_in_degree: Optional[int] = None,
306+
) -> ops_torch.HeteroCSC:
307+
"""Create HeteroCSC structure needed by cugraph-ops."""
308+
309+
if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)):
310+
raise TypeError(
311+
f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
312+
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
313+
)
314+
315+
# TODO: max_in_degree should default to None in pylibcugraphops
316+
if max_in_degree is None:
317+
max_in_degree = -1
318+
319+
if isinstance(g, SparseGraph):
320+
offsets, indices, etypes = g.csc()
321+
if etypes is None:
322+
raise ValueError(
323+
"SparseGraph must have 'values' to create HeteroCSC. "
324+
"Pass in edge types as 'values' when creating the SparseGraph."
325+
)
326+
etypes = etypes.int()
327+
else:
328+
if etypes is None:
329+
raise ValueError(
330+
"'etypes' is required when creating HeteroCSC "
331+
"from dgl.DGLHeteroGraph."
332+
)
333+
offsets, indices, perm = g.adj_tensors("csc")
334+
etypes = etypes[perm].int()
335+
336+
graph = ops_torch.HeteroCSC(
337+
offsets=offsets,
338+
indices=indices,
339+
edge_types=etypes,
340+
num_src_nodes=g.num_src_nodes(),
341+
num_edge_types=num_edge_types,
342+
dst_max_in_degree=max_in_degree,
343+
is_bipartite=is_bipartite,
344+
)
345+
346+
return graph

0 commit comments

Comments
 (0)