Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,23 @@
dask_cudf = import_optional("dask_cudf")

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")

Tensor = None if isinstance(torch, MissingModule) else torch.Tensor
NdArray = None if isinstance(cupy, MissingModule) else cupy.ndarray
DaskCudfSeries = None if isinstance(dask_cudf, MissingModule) else dask_cudf.Series

TensorType = Union[Tensor, NdArray, cudf.Series, DaskCudfSeries]
# Type aliases borrowed from torch_geometric.typing. When torch_geometric
# could not be imported (it is a MissingModule placeholder), both aliases
# fall back to None so that this module can still be imported.
if isinstance(torch_geometric, MissingModule):
    NodeType = None
    EdgeType = None
else:
    NodeType = torch_geometric.typing.NodeType
    EdgeType = torch_geometric.typing.EdgeType


class EdgeLayout(Enum):
Expand Down Expand Up @@ -456,6 +467,20 @@ def __construct_graph(
def _edge_types_to_attrs(self) -> dict:
return dict(self.__edge_types_to_attrs)

@property
def node_types(self) -> List[NodeType]:
    """Return the entries of the vertex type offsets' "type" field as a list."""
    type_column = self.__vertex_type_offsets["type"]
    return list(type_column)

@property
def edge_types(self) -> List[EdgeType]:
    """Return the keys of the edge-type-to-attributes mapping as a list."""
    return [*self.__edge_types_to_attrs.keys()]

def canonical_edge_type_to_numeric(self, etype: EdgeType) -> int:
    """
    Convert a canonical edge type into its numeric edge type id.

    The parts of ``etype`` are joined with "__" and the resulting name is
    located in the "type" field of the edge type offsets with a binary
    search; the position found is the numeric id.

    Parameters
    ----------
    etype: EdgeType
        The canonical edge type, an iterable of strings.

    Returns
    -------
    int
        The position of the joined name in the edge type table, as a plain
        Python int.

    Raises
    ------
    ValueError
        If the joined name is not present in the edge type table.
    """
    type_names = self.__edge_type_offsets["type"]
    joined_name = "__".join(etype)

    # searchsorted returns an insertion point (as a NumPy integer) even when
    # the key is absent, so convert it and confirm the match explicitly
    # instead of handing back an unrelated or out-of-range id.
    position = int(np.searchsorted(type_names, joined_name))
    if position >= len(type_names) or type_names[position] != joined_name:
        raise ValueError(f"Unknown edge type: {etype!r}")

    return position

def numeric_edge_type_to_canonical(self, etype: int) -> EdgeType:
    """
    Convert a numeric edge type id into its canonical edge type.

    The entry at position ``etype`` of the edge type offsets' "type" field
    is split on "__" and the pieces are returned as a tuple.
    """
    joined_name = self.__edge_type_offsets["type"][etype]
    name_parts = joined_name.split("__")
    return tuple(name_parts)

@cached_property
def _is_delayed(self):
if self.__graph is None:
Expand Down Expand Up @@ -657,17 +682,15 @@ def _get_vertex_groups_from_sample(
self, nodes_of_interest: TensorType, is_sorted: bool = False
) -> dict:
"""
Given a cudf (NOT dask_cudf) Series of nodes of interest, this
Given a tensor of nodes of interest, this
method returns a single dictionary, noi_index.

noi_index is the original vertex ids grouped by vertex type.

Example Input: [5, 2, 10, 11, 8]
Output: {'red_vertex': [5, 8], 'blue_vertex': [2], 'green_vertex': [10, 11]}
Example Input: [5, 2, 1, 10, 11, 8]
Output: {'red_vertex': [5, 1, 8], 'blue_vertex': [2], 'green_vertex': [10, 11]}

"""
if not is_sorted:
nodes_of_interest, _ = torch.sort(nodes_of_interest)

noi_index = {}

Expand Down Expand Up @@ -802,8 +825,14 @@ def _get_renumbered_edge_groups_from_sample(

# Create the row entry for this type
src_id_table = noi_index[src_type]
src = torch.searchsorted(src_id_table, sources)
row_dict[pyg_can_edge_type] = src
src_id_map = (
cudf.Series(cupy.asarray(src_id_table), name="src")
.reset_index()
.rename(columns={"index": "new_id"})
.set_index("src")
)
src = src_id_map["new_id"].loc[cupy.asarray(sources)]
row_dict[pyg_can_edge_type] = torch.as_tensor(src.values, device="cuda")

# Get the de-offsetted destinations
destinations = torch.as_tensor(
Expand All @@ -816,8 +845,14 @@ def _get_renumbered_edge_groups_from_sample(

# Create the col entry for this type
dst_id_table = noi_index[dst_type]
dst = torch.searchsorted(dst_id_table, destinations)
col_dict[pyg_can_edge_type] = dst
dst_id_map = (
cudf.Series(cupy.asarray(dst_id_table), name="dst")
.reset_index()
.rename(columns={"index": "new_id"})
.set_index("dst")
)
dst = dst_id_map["new_id"].loc[cupy.asarray(destinations)]
col_dict[pyg_can_edge_type] = torch.as_tensor(dst.values, device="cuda")

return row_dict, col_dict

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __next__(self):
# 'edge_id':'int64',
"edge_type": "int32",
"batch_id": "int32",
# 'hop_id':'int32'
"hop_id": "int32",
}
self.__data = cudf.read_parquet(parquet_path)
self.__data = self.__data[list(columns.keys())].astype(columns)
Expand Down
8 changes: 0 additions & 8 deletions python/cugraph-pyg/cugraph_pyg/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cugraph.utilities.api_tools import experimental_warning_wrapper

from cugraph_pyg.sampler.cugraph_sampler import (
EXPERIMENTAL__CuGraphSampler,
)

# Expose the sampler under its public name, wrapped by experimental_warning_wrapper.
CuGraphSampler = experimental_warning_wrapper(EXPERIMENTAL__CuGraphSampler)
Loading