|
4 | 4 | from .graph_index import GraphIndex |
5 | 5 | from .batched_graph import BatchedDGLGraph |
6 | 6 |
|
7 | | -__all__ = ['line_graph', 'reverse', 'to_simple_graph'] |
| 7 | +__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected'] |
8 | 8 |
|
9 | 9 |
|
10 | 10 | def line_graph(g, backtracking=True, shared=False): |
@@ -124,4 +124,50 @@ def to_simple_graph(g): |
124 | 124 | newgidx = GraphIndex(_CAPI_DGLToSimpleGraph(g._graph.handle)) |
125 | 125 | return DGLGraph(newgidx, readonly=True) |
126 | 126 |
|
| 127 | +def to_bidirected(g, readonly=True): |
| 128 | + """Convert the graph to a bidirected graph. |
| 129 | +
|
| 130 | + The function generates a new graph with no node/edge feature. |
| 131 | + If g has m edges for i->j and n edges for j->i, then the |
| 132 | + returned graph will have max(m, n) edges for both i->j and j->i. |
| 133 | +
|
| 134 | + Parameters |
| 135 | + ---------- |
| 136 | + g : DGLGraph |
| 137 | + The input graph. |
| 138 | + readonly : bool, default to be True |
| 139 | + Whether the returned bidirected graph is readonly or not. |
| 140 | +
|
| 141 | + Returns |
| 142 | + ------- |
| 143 | + DGLGraph |
| 144 | +
|
| 145 | + Examples |
| 146 | + -------- |
| 147 | + The following two examples use PyTorch backend, one for non-multi graph |
| 148 | + and one for multi-graph. |
| 149 | +
|
| 150 | + >>> # non-multi graph |
| 151 | + >>> g = dgl.DGLGraph() |
| 152 | + >>> g.add_nodes(2) |
| 153 | + >>> g.add_edges([0, 0], [0, 1]) |
| 154 | + >>> bg1 = dgl.to_bidirected(g) |
| 155 | + >>> bg1.edges() |
| 156 | + (tensor([0, 1, 0]), tensor([0, 0, 1])) |
| 157 | +
|
| 158 | + >>> # multi-graph |
| 159 | + >>> g.add_edges([0, 1], [1, 0]) |
| 160 | + >>> g.edges() |
| 161 | + (tensor([0, 0, 0, 1]), tensor([0, 1, 1, 0])) |
| 162 | +
|
| 163 | + >>> bg2 = dgl.to_bidirected(g) |
| 164 | + >>> bg2.edges() |
| 165 | + (tensor([0, 1, 1, 0, 0]), tensor([0, 0, 0, 1, 1])) |
| 166 | + """ |
| 167 | + if readonly: |
| 168 | + newgidx = GraphIndex(_CAPI_DGLToBidirectedImmutableGraph(g._graph.handle)) |
| 169 | + else: |
| 170 | + newgidx = GraphIndex(_CAPI_DGLToBidirectedMutableGraph(g._graph.handle)) |
| 171 | + return DGLGraph(newgidx) |
| 172 | + |
127 | 173 | _init_api("dgl.transform") |
0 commit comments