Skip to content

Commit ebf0f28

Browse files
authored
force_atlas2 to support nx hypercube_graph (#1779)
Allow `force_atlas2` to support hypercube_graph The example provided in the issue linked to this PR works closes #1767 Authors: - Joseph Nke (https://github.com/jnke2016) Approvers: - Brad Rees (https://github.com/BradReesWork) - Rick Ratzel (https://github.com/rlratzel) URL: #1779
1 parent 7565dc3 commit ebf0f28

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

python/cugraph/layout/force_atlas2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# limitations under the License.
1313

1414
from cugraph.layout import force_atlas2_wrapper
15+
import cugraph
1516

1617

1718
def force_atlas2(
@@ -106,6 +107,7 @@ def on_train_end(self, positions):
106107
GPU data frame of size V containing three columns:
107108
the vertex identifiers and the x and y positions.
108109
"""
110+
input_graph, isNx = cugraph.utilities.check_nx_graph(input_graph)
109111

110112
if pos_list is not None:
111113
if input_graph.renumbered is True:

python/cugraph/utilities/nx_factory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -14,6 +14,7 @@
1414
import cugraph
1515
from .utils import import_optional
1616
from cudf import from_pandas
17+
import numpy as np
1718

1819
nx = import_optional("networkx")
1920

@@ -35,6 +36,13 @@ def convert_from_nx(nxG, weight=None):
3536
raise ValueError("nxG does not appear to be a NetworkX graph type")
3637

3738
pdf = nx.to_pandas_edgelist(nxG)
39+
# Convert vertex columns to strings if they are not integers
40+
# This allows support for any vertex input type
41+
if pdf["source"].dtype not in [np.int32, np.int64] or \
42+
pdf["target"].dtype not in [np.int32, np.int64]:
43+
pdf['source'] = pdf['source'].astype(str)
44+
pdf['target'] = pdf['target'].astype(str)
45+
3846
num_col = len(pdf.columns)
3947

4048
if num_col < 2:

0 commit comments

Comments
 (0)