Skip to content
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ branch = true

[tool.ruff]
src = ["src"]
exclude = ["docs/source/conf.py", "tests/test_graph_app.py"]
exclude = ["docs/source/conf.py"]
lint.select = [
"F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
"E", "W", # https://docs.astral.sh/ruff/rules/#pycodestyle-e-w
Expand Down
3 changes: 3 additions & 0 deletions src/dgipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_sources,
)
from .graph_app import generate_app
from .network_graph import create_network, generate_cytoscape

__all__ = [
"get_drugs",
Expand All @@ -24,4 +25,6 @@
"get_drug_applications",
"generate_app",
"get_clinical_trials",
"create_network",
"generate_cytoscape",
]
3 changes: 2 additions & 1 deletion src/dgipy/graph_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _update_cytoscape(app: dash.Dash) -> None:
def update(terms: list | None, search_mode: str) -> dict:
if len(terms) != 0:
interactions = dgidb.get_interactions(terms, search_mode)
network_graph = ng.initalize_network(interactions, terms, search_mode)
network_graph = ng.create_network(interactions, terms, search_mode)
return ng.generate_cytoscape(network_graph)
return {}

Expand Down Expand Up @@ -262,6 +262,7 @@ def update(selected_element: str | dict) -> tuple[list, None]:
if (
selected_element != ""
and selected_element["group"] == "nodes"
and "node_degree" in selected_element["data"]
and selected_element["data"]["node_degree"] != 1
):
neighbor_set = set()
Expand Down
118 changes: 42 additions & 76 deletions src/dgipy/network_graph.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,38 @@
"""Provides functionality to create networkx graphs and pltoly figures for network visualization"""

import networkx as nx
import pandas as pd

LAYOUT_SEED = 7


def initalize_network(
interactions: pd.DataFrame, terms: list, search_mode: str
) -> nx.Graph:
"""Create a networkx graph representing interactions between genes and drugs

:param interactions: DataFrame containing drug-gene interaction data
:param terms: List containing terms used to query interaction data
:param search_mode: String indicating whether query was gene-focused or drug-focused
:return: a networkx graph of drug-gene interactions
"""
def _initalize_network(interactions: dict, terms: list, search_mode: str) -> nx.Graph:
interactions_graph = nx.Graph()
graphed_terms = set()

for index in range(len(interactions["gene_name"]) - 1):
for row in zip(*interactions.values(), strict=True):
row_dict = dict(zip(interactions.keys(), row, strict=True))
if search_mode == "genes":
graphed_terms.add(interactions["gene_name"][index])
graphed_terms.add(row_dict["gene_name"])
if search_mode == "drugs":
graphed_terms.add(interactions["drug_name"][index])
graphed_terms.add(row_dict["drug_name"])
interactions_graph.add_node(
interactions["gene_name"][index],
label=interactions["gene_name"][index],
row_dict["gene_name"],
label=row_dict["gene_name"],
isGene=True,
)
interactions_graph.add_node(
interactions["drug_name"][index],
label=interactions["drug_name"][index],
row_dict["drug_name"],
label=row_dict["drug_name"],
isGene=False,
)
interactions_graph.add_edge(
interactions["gene_name"][index],
interactions["drug_name"][index],
id=interactions["gene_name"][index]
+ " - "
+ interactions["drug_name"][index],
approval=interactions["drug_approved"][index],
score=interactions["interaction_score"][index],
attributes=interactions["interaction_attributes"][index],
sourcedata=interactions["interaction_sources"][index],
pmid=interactions["interaction_pmids"][index],
row_dict["gene_name"],
row_dict["drug_name"],
id=row_dict["gene_name"] + " - " + row_dict["drug_name"],
approval=row_dict["drug_approved"],
score=row_dict["interaction_score"],
attributes=row_dict["interaction_attributes"],
sourcedata=row_dict["interaction_sources"],
pmid=row_dict["interaction_pmids"],
)

graphed_terms = set(terms).difference(graphed_terms)
Expand All @@ -54,61 +42,34 @@ def initalize_network(
if search_mode == "drugs":
interactions_graph.add_node(term, label=term, isGene=False)

nx.set_node_attributes(
interactions_graph, dict(interactions_graph.degree()), "node_degree"
)
return interactions_graph


def _add_node_attributes(interactions_graph: nx.Graph, search_mode: str) -> None:
nx.set_node_attributes(
interactions_graph, dict(interactions_graph.degree()), "node_degree"
)
for node in interactions_graph.nodes:
is_gene = interactions_graph.nodes[node]["isGene"]
degree = interactions_graph.degree[node]
if search_mode == "genes":
if is_gene:
if degree > 1:
set_color = "cyan"
set_size = 10
else:
set_color = "blue"
set_size = 10
else:
if degree > 1:
set_color = "orange"
set_size = 7
else:
set_color = "red"
set_size = 7
if search_mode == "drugs":
if is_gene:
if degree > 1:
set_color = "cyan"
set_size = 7
else:
set_color = "blue"
set_size = 7
else:
if degree > 1:
set_color = "orange"
set_size = 10
else:
set_color = "red"
set_size = 10
interactions_graph.nodes[node]["node_color"] = set_color
interactions_graph.nodes[node]["node_size"] = set_size

if (search_mode == "genes" and (not is_gene)) or (
search_mode == "drugs" and is_gene
):
neighbors = "Group: " + "-".join(list(interactions_graph.neighbors(node)))
interactions_graph.nodes[node]["group"] = neighbors
else:
interactions_graph.nodes[node]["group"] = None


def create_network(
interactions: pd.DataFrame, terms: list, search_mode: str
) -> nx.Graph:
def create_network(interactions: dict, terms: list, search_mode: str) -> nx.Graph:
"""Create a networkx graph representing interactions between genes and drugs

:param interactions: DataFrame containing drug-gene interaction data
:param interactions: Dictionary containing drug-gene interaction data
:param terms: List containing terms used to query interaction data
:param search_mode: String indicating whether query was gene-focused or drug-focused
:return: a networkx graph of drug-gene interactions
"""
interactions_graph = initalize_network(interactions, terms, search_mode)
interactions_graph = _initalize_network(interactions, terms, search_mode)
_add_node_attributes(interactions_graph, search_mode)
return interactions_graph

Expand All @@ -123,10 +84,15 @@ def generate_cytoscape(graph: nx.Graph) -> dict:
cytoscape_data = nx.cytoscape_data(graph)["elements"]
cytoscape_node_data = cytoscape_data["nodes"]
cytoscape_edge_data = cytoscape_data["edges"]
for node in range(len(cytoscape_node_data)):
node_pos = pos[cytoscape_node_data[node]["data"]["id"]]
node_pos = {
"position": {"x": int(node_pos[0].item()), "y": int(node_pos[1].item())}
}
cytoscape_node_data[node].update(node_pos)
groups = set()
for node in cytoscape_node_data:
node_pos = pos[node["data"]["id"]]
node.update({"position": {"x": node_pos[0], "y": node_pos[1]}})
if "group" in node["data"]:
group = node["data"].pop("group")
groups.add(group)
node["data"]["parent"] = group
groups.remove(None)
for group in groups:
cytoscape_node_data.append({"data": {"id": group}})
return cytoscape_node_data + cytoscape_edge_data
2 changes: 0 additions & 2 deletions tests/test_graph_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from dgipy.graph_app import generate_app


Expand Down
17 changes: 17 additions & 0 deletions tests/test_network_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dgipy.dgidb import get_interactions
from dgipy.network_graph import create_network, generate_cytoscape


def test_create_network():
interactions = get_interactions("BRAF")
terms = ["BRAF"]
search_mode = "genes"
assert create_network(interactions, terms, search_mode)


def test_generate_cytoscape():
interactions = get_interactions("BRAF")
terms = ["BRAF"]
search_mode = "genes"
network = create_network(interactions, terms, search_mode)
assert generate_cytoscape(network)