Skip to content

Commit fb9dcc5

Browse files
authored
[Feature] Add API to convert graph to bidirected graph (dmlc#598)
* to_bidirected * to_bidirected * Fix style * Fix * Update * Fix * Fix * Update * Add examples
1 parent a1513f7 commit fb9dcc5

File tree

6 files changed

+175
-1
lines changed

6 files changed

+175
-1
lines changed

docs/source/api/python/transform.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ Transform -- Graph Transformation
1010

1111
line_graph
1212
reverse
13+
to_simple_graph
14+
to_bidirected

include/dgl/graph_op.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ class GraphOp {
144144
* \return a new immutable simple graph with no multi-edge.
145145
*/
146146
static ImmutableGraph ToSimpleGraph(const GraphInterface* graph);
147+
148+
/*!
149+
* \brief Convert the graph to a mutable bidirected graph.
150+
*
151+
* If the original graph has m edges for i -> j and n edges for
152+
* j -> i, the new graph will have max(m, n) edges for both
153+
* i -> j and j -> i.
154+
*
155+
* \param graph The input graph.
156+
* \return a new mutable bidirected graph.
157+
*/
158+
static Graph ToBidirectedMutableGraph(const GraphInterface* graph);
159+
160+
/*!
161+
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable.
162+
* \param graph The input graph.
163+
* \return a new immutable bidirected graph.
164+
*/
165+
static ImmutableGraph ToBidirectedImmutableGraph(const GraphInterface* graph);
147166
};
148167

149168
} // namespace dgl

python/dgl/transform.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .graph_index import GraphIndex
55
from .batched_graph import BatchedDGLGraph
66

7-
__all__ = ['line_graph', 'reverse', 'to_simple_graph']
7+
__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected']
88

99

1010
def line_graph(g, backtracking=True, shared=False):
@@ -124,4 +124,50 @@ def to_simple_graph(g):
124124
newgidx = GraphIndex(_CAPI_DGLToSimpleGraph(g._graph.handle))
125125
return DGLGraph(newgidx, readonly=True)
126126

127+
def to_bidirected(g, readonly=True):
128+
"""Convert the graph to a bidirected graph.
129+
130+
The function generates a new graph with no node/edge feature.
131+
If g has m edges for i->j and n edges for j->i, then the
132+
returned graph will have max(m, n) edges for both i->j and j->i.
133+
134+
Parameters
135+
----------
136+
g : DGLGraph
137+
The input graph.
138+
readonly : bool, default to be True
139+
Whether the returned bidirected graph is readonly or not.
140+
141+
Returns
142+
-------
143+
DGLGraph
144+
145+
Examples
146+
--------
147+
The following two examples use PyTorch backend, one for non-multi graph
148+
and one for multi-graph.
149+
150+
>>> # non-multi graph
151+
>>> g = dgl.DGLGraph()
152+
>>> g.add_nodes(2)
153+
>>> g.add_edges([0, 0], [0, 1])
154+
>>> bg1 = dgl.to_bidirected(g)
155+
>>> bg1.edges()
156+
(tensor([0, 1, 0]), tensor([0, 0, 1]))
157+
158+
>>> # multi-graph
159+
>>> g.add_edges([0, 1], [1, 0])
160+
>>> g.edges()
161+
(tensor([0, 0, 0, 1]), tensor([0, 1, 1, 0]))
162+
163+
>>> bg2 = dgl.to_bidirected(g)
164+
>>> bg2.edges()
165+
(tensor([0, 1, 1, 0, 0]), tensor([0, 0, 0, 1, 1]))
166+
"""
167+
if readonly:
168+
newgidx = GraphIndex(_CAPI_DGLToBidirectedImmutableGraph(g._graph.handle))
169+
else:
170+
newgidx = GraphIndex(_CAPI_DGLToBidirectedMutableGraph(g._graph.handle))
171+
return DGLGraph(newgidx)
172+
127173
_init_api("dgl.transform")

src/graph/graph_apis.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,4 +572,22 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
572572
*rv = ret;
573573
});
574574

575+
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
576+
.set_body([] (DGLArgs args, DGLRetValue* rv) {
577+
GraphHandle ghandle = args[0];
578+
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
579+
Graph* bgptr = new Graph();
580+
*bgptr = GraphOp::ToBidirectedMutableGraph(ptr);
581+
GraphHandle bghandle = bgptr;
582+
*rv = bghandle;
583+
});
584+
585+
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
586+
.set_body([] (DGLArgs args, DGLRetValue* rv) {
587+
GraphHandle ghandle = args[0];
588+
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
589+
GraphHandle bghandle = GraphOp::ToBidirectedImmutableGraph(ptr).Reset();
590+
*rv = bghandle;
591+
});
592+
575593
} // namespace dgl

src/graph/graph_op.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,4 +315,75 @@ ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) {
315315
return ImmutableGraph(csr);
316316
}
317317

318+
Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
319+
std::unordered_map<int, std::unordered_map<int, int>> n_e;
320+
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
321+
for (const dgl_id_t v : g->SuccVec(u)) {
322+
n_e[u][v]++;
323+
}
324+
}
325+
326+
Graph bg;
327+
bg.AddVertices(g->NumVertices());
328+
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
329+
for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
330+
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
331+
if (new_n_e > 0) {
332+
IdArray us = NewIdArray(new_n_e);
333+
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
334+
std::fill(us_data, us_data + new_n_e, u);
335+
if (u == v) {
336+
bg.AddEdges(us, us);
337+
} else {
338+
IdArray vs = NewIdArray(new_n_e);
339+
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
340+
std::fill(vs_data, vs_data + new_n_e, v);
341+
bg.AddEdges(us, vs);
342+
bg.AddEdges(vs, us);
343+
}
344+
}
345+
}
346+
}
347+
return bg;
348+
}
349+
350+
ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
351+
std::unordered_map<int, std::unordered_map<int, int>> n_e;
352+
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
353+
for (const dgl_id_t v : g->SuccVec(u)) {
354+
n_e[u][v]++;
355+
}
356+
}
357+
358+
std::vector<dgl_id_t> srcs, dsts;
359+
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
360+
std::unordered_set<dgl_id_t> hashmap;
361+
std::vector<dgl_id_t> nbrs;
362+
for (const dgl_id_t v : g->PredVec(u)) {
363+
if (!hashmap.count(v)) {
364+
nbrs.push_back(v);
365+
hashmap.insert(v);
366+
}
367+
}
368+
for (const dgl_id_t v : g->SuccVec(u)) {
369+
if (!hashmap.count(v)) {
370+
nbrs.push_back(v);
371+
hashmap.insert(v);
372+
}
373+
}
374+
for (const dgl_id_t v : nbrs) {
375+
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
376+
for (size_t i = 0; i < new_n_e; ++i) {
377+
srcs.push_back(v);
378+
dsts.push_back(u);
379+
}
380+
}
381+
}
382+
383+
IdArray srcs_array = VecToIdArray(srcs);
384+
IdArray dsts_array = VecToIdArray(dsts);
385+
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()));
386+
return ImmutableGraph(coo);
387+
}
388+
318389
} // namespace dgl

tests/compute/test_transform.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,27 @@ def test_simple_graph():
9595
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
9696
assert eset == set(elist)
9797

98+
def test_bidirected_graph():
99+
def _test(in_readonly, out_readonly):
100+
elist = [(0, 0), (0, 1), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2), (2, 2)]
101+
g = dgl.DGLGraph(elist, readonly=in_readonly)
102+
elist.append((1, 2))
103+
elist = set(elist)
104+
big = dgl.to_bidirected(g, out_readonly)
105+
assert big.number_of_edges() == 10
106+
src, dst = big.edges()
107+
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
108+
assert eset == set(elist)
109+
110+
_test(True, True)
111+
_test(True, False)
112+
_test(False, True)
113+
_test(False, False)
114+
98115
if __name__ == '__main__':
99116
test_line_graph()
100117
test_no_backtracking()
101118
test_reverse()
102119
test_reverse_shared_frames()
103120
test_simple_graph()
121+
test_bidirected_graph()

0 commit comments

Comments
 (0)