"""
network_visualizer.py
Description: This package is meant to help with the visualization of network-
connection data, in conjunction with network_views.py. A good example of the
type of data this is meant to visualize can be found at:
https://networkx.org/
The `Graph` class is meant as a wrapper for `networkx.Graph` objects, which can
then act as a model for `AbstractNetworkView` and the associated network view
classes defined in network_views.py.
Copyright Schrodinger, LLC. All rights reserved.
"""
#Author: Pat Lorton
import copy
import math
from sys import maxsize
import networkx as nx
import numpy as np
from schrodinger.Qt.QtCore import QObject
from schrodinger.Qt.QtCore import pyqtSignal
#### Spring layout parameters ####
STARTING_TEMPERATURE = 0.1
# Number of spring-relaxation iterations per layout (used by Graph.springLayout)
ITERATIONS = 100
# Scale of distances to node size; a higher number gives greater separation.
SCALE = 3.0
# Node-node repulsion: PUSH is a multiplier, PUSHEXP its exponential dependence.
PUSH = 1
PUSHEXP = 6
#===============================================================================
# Graph Model Classes
#===============================================================================
class GraphSignals(QObject):
    """
    Qt signals emitted by `Graph` (and its nodes/edges) to notify listeners,
    such as network views, of changes to the model.
    """
    # Emitted with (set of selected nodes and edges, object that requested the
    # selection change) by Graph.setSelectedObjs
    selectionChanged = pyqtSignal(set, object)
    # Emitted with the set of nodes whose positions changed
    positionChanged = pyqtSignal(set)
    # Emitted with the set of nodes whose data dictionary changed
    nodesChanged = pyqtSignal(set)
    # Emitted with the set of newly added nodes
    nodesAdded = pyqtSignal(set)
    # Emitted with the set of removed nodes
    nodesDeleted = pyqtSignal(set)
    # Emitted with the set of added, removed or modified edges
    edgesChanged = pyqtSignal(set)
    # Emitted when graph-level data changes or a saved state is restored
    graphChanged = pyqtSignal()
    # Emitted after Graph.setUndoPoint stores a new undo state
    undoPointSet = pyqtSignal()
class Graph:
    """
    A model class for an undirected graph. This wraps around the NetworkX Graph
    class and provides Qt signals, an easier-to-use API, and access control.

    All persistent data should be stored in self._ggraph.

    Note that Graph itself cannot be pickled; Graph has Graph.signals, which is
    a QObject and cannot be pickled. For this reason selection information
    (which contains references to Graph) is not placed in self._ggraph, so that
    self._ggraph can be pickled.
    """

    def __init__(self, ggraph=None, node_class=None, edge_class=None):
        """
        Construct a new Graph object.

        :param ggraph: the graph underlying this graph; a new empty
            `networkx.Graph` is created if not supplied
        :type ggraph: `networkx.Graph` or NoneType
        :param node_class: the class to represent the graph's nodes (should be
            a subclass of `Node`); defaults to `Node`
        :type node_class: class
        :param edge_class: the class to represent the graph's edges (should be
            a subclass of `Edge`); defaults to `Edge`
        :type edge_class: class
        """
        self.signals = GraphSignals()
        if ggraph is None:
            ggraph = nx.Graph()
        self.node_class = node_class or Node
        self.edge_class = edge_class or Edge
        self._ggraph = ggraph
        self.selected_nodes = set()
        self.selected_edges = set()
        # Maps str(gnode) -> node wrapper object; see _updateNodeMap()
        self.node_objects = {}
        self._updateNodeMap()
        self.connection_validator = None
        self.undo_stack = []
        self.max_undo_stack = 100
        self.redo_stack = []

    @property
    def ggraph(self):
        """
        :return: the underlying NetworkX graph (not a copy)
        :rtype: networkx.Graph
        """
        return self._ggraph

    def update(self):
        """
        Update any derived aspects of the graph after changes. Does nothing in
        this base class; subclasses may override it.
        """
        return

    def _updateNodeMap(self):
        """
        Update the `node_objects` dictionary with any new nodes from the
        underlying ggraph.

        :return: the node objects created for gnodes that were not yet mapped
        :rtype: set(Node)
        """
        new_nodes = set()
        for gnode in self.ggraph.nodes:
            if str(gnode) not in self.node_objects:
                node = self.node_class(gnode, self)
                self.node_objects[node.name] = node
                new_nodes.add(node)
        return new_nodes

    def setEdgeValidator(self, validator):
        """
        Set an edge validator that will be run when adding edges between nodes.

        :param validator: the validator
        :type validator: ConnectionValidator
        :raise TypeError: if `validator` is not a `ConnectionValidator`
        """
        if not isinstance(validator, ConnectionValidator):
            raise TypeError('Validator must be a subclass of '
                            'ConnectionValidator')
        self.connection_validator = validator

    def toNetworkX(self):
        """
        Return a deep copy of the underlying NetworkX graph.
        """
        return copy.deepcopy(self._ggraph)

    def getData(self, key):
        """
        Return the requested item from the graph's data dictionary. Returns None
        if the key is not found.
        """
        return self._ggraph.graph.get(key)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the graph's data dictionary.

        :param signal: whether to emit `graphChanged` afterwards
        :type signal: bool
        """
        self._ggraph.graph[key] = value
        if signal:
            self.signals.graphChanged.emit()

    def isConnected(self):
        """
        Checks whether the graph is connected, that is, whether every node is
        connected by some path to every other node.

        :return: whether the graph is connected; False for a graph without
            nodes
        :rtype: bool
        """
        # nx.is_connected raises on the null graph, so check for nodes first.
        # Use len() so a bool (not the graph object itself) is returned.
        return len(self._ggraph) > 0 and nx.is_connected(self._ggraph)

    #===========================================================================
    # Node methods
    #===========================================================================

    def nodeCount(self):
        """
        :return: the number of nodes in the graph
        :rtype: int
        """
        return self.ggraph.number_of_nodes()

    def getIsolates(self):
        """
        :return: a complete set of nodes in the graph that have degree 0
        :rtype: set(Node)
        """
        return {self.getNode(gnode) for gnode in nx.isolates(self.ggraph)}

    def getConnectedComponents(self, nodes=None):
        """
        Return a set of nodes for each connected component in the graph.

        :param nodes: optionally, a set of nodes to filter the returned
            components. If provided, this method will only return components
            for which at least one node is in `nodes`
        :type nodes: set(Node) or NoneType
        :return: a generator over each connected component in the graph
        :rtype: typing.Generator[set[Node], None, None]
        """
        for gnodes in nx.connected_components(self.ggraph):
            component_nodes = {self.getNode(gnode) for gnode in gnodes}
            if nodes is None or nodes.intersection(component_nodes):
                yield component_nodes

    def getNodeConnectedComponent(self, node):
        """
        Return a set of nodes that are part of the same connected component as
        `node`.

        :param node: a node
        :type node: Node
        :return: a set of nodes connected to `node` through any path of edges
        :rtype: set(Node)
        """
        return {
            self.getNode(gnode)
            for gnode in nx.node_connected_component(self.ggraph, node.gnode)
        }

    def _getNodeInstances(self, objects):
        """
        Given a list of objects, return only the `Node` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of node instances among the supplied list of objects
        :rtype: set(Node)
        """
        return {obj for obj in objects if isinstance(obj, self.node_class)}

    def _getEdgeInstances(self, objects):
        """
        Given a list of objects, return only the `Edge` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of edge instances among the supplied list of objects
        :rtype: set(Edge)
        """
        return {obj for obj in objects if isinstance(obj, self.edge_class)}

    def _resolveNode(self, node_key):
        """
        Return the node object for a node or gnode.

        :param node_key: a `Node`, or a gnode/string identifying a node of
            this graph
        :type node_key: object
        :return: the corresponding node
        :rtype: Node
        :raise ValueError: if `node_key` is not a `Node` and no matching node
            exists in this graph
        """
        if isinstance(node_key, Node):
            return node_key
        node = self.getNode(node_key)
        if node is None:
            raise ValueError('Node %s not found in graph.' % (node_key,))
        return node

    def getNode(self, node_key):
        """
        Retrieve a node via its name. Retrieved nodes are cached, so getting the
        same Node again will return the same instance. Returns None if no
        matching Node exists.

        :param node_key: a node, gnode, or string that corresponds to the
            desired node
        :type node_key: object
        :return: a node if found, else `None`
        :rtype: Node or NoneType
        """
        if isinstance(node_key, self.node_class):
            return node_key
        return self.node_objects.get(str(node_key))

    def getNodes(self, node_keys=None):
        """
        Retrieve a set of nodes optionally indicated by a list of keys. If none
        is provided, return all nodes.

        :param node_keys: optionally, a list of nodes, gnodes, or strings that
            correspond to the desired nodes
        :type node_keys: list(object) or NoneType
        :return: a set of nodes; keys that match no node are skipped
        :rtype: set(Node)
        """
        if node_keys is None:
            return set(self.node_objects.values())
        nodes = set()
        for node_key in node_keys:
            node = self.getNode(node_key)
            if node:
                nodes.add(node)
        return nodes

    def getNeighbors(self, node):
        """
        Return a set of all nodes connected to a specified node

        :param node: center node
        :type node: Node
        :return: neighboring nodes
        :rtype: set of Node
        """
        gnodes = self._ggraph.neighbors(node.gnode)
        return self.getNodes(gnodes)

    def addNodes(self, nodes, signal=True):
        """
        Add a list of nodes to this graph. The `nodes` argument can either be a
        list of `Node` objects or a list of hashable objects that can be used
        as new gnodes.

        Note that any time a new gnode is created for use in this graph, its
        string representation must be unique among the other nodes in this
        graph: nodes are keyed in the `node_objects` dictionary by the string
        representation of their corresponding gnode.

        :param nodes: list of gnodes or nodes
        :type nodes: `list(object)` or `list(Node)`
        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool
        :return: a set of added nodes
        :rtype: set(Node)
        :raise ValueError: if a node (or a node with the same string
            representation) already exists in the graph
        """
        for node in nodes:
            if not isinstance(node, Node):
                if str(node) in self.node_objects:
                    msg = ('A node with the same string representation ("{0}")'
                           ' already exists in this graph. New nodes must have'
                           ' a unique string representation.').format(node)
                    raise ValueError(msg)
                # Node() tolerates a gnode that is not yet in the graph
                new_node = self.node_class(node, self)
            else:
                new_node = node
            if new_node.gnode in self.ggraph.nodes:
                msg = 'Node %s already exists in graph.' % new_node.gnode
                raise ValueError(msg)
            self._ggraph.add_node(new_node.gnode, **new_node.gdata())
        new_nodes = self._updateNodeMap()
        if signal and new_nodes:
            self.signals.nodesAdded.emit(new_nodes)
        return new_nodes

    def addNode(self, node, signal=True):
        """
        Convenience method for adding a single node to the graph. See
        `addNodes()` for full documentation.

        :param node: gnode or node
        :type node: hashable or `Node`
        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool
        :return: the added node
        :rtype: Node
        """
        new_nodes = self.addNodes([node], signal=signal)
        return new_nodes.pop()

    def removeNodes(self, nodes, signal=True):
        """
        Remove specified nodes from the graph and optionally emit a signal.

        :param nodes: the nodes to be removed
        :type nodes: iterable(Node)
        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        """
        # Materialize once; `nodes` is iterated more than once below
        nodes = list(nodes)
        self._ggraph.remove_nodes_from([n.gnode for n in nodes])
        for node in nodes:
            del self.node_objects[node.name]
        if signal:
            self.signals.nodesDeleted.emit(set(nodes))

    def removeNode(self, node, signal=True):
        """
        Convenience function for removing a single node. See `removeNodes()`
        for full documentation.

        :param node: a gnode or node to remove
        :type node: `object` or `Node`
        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        :raise ValueError: if a gnode is given that matches no node
        """
        self.removeNodes([self._resolveNode(node)], signal=signal)

    def setMultipleNodePos(self, pos_dict, signal=True):
        """
        Set the positions of nodes from a dictionary.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param signal: whether to emit a single `positionChanged` signal for
            all of the nodes when done
        :type signal: bool
        """
        changednodes = set()
        for node, pos in pos_dict.items():
            node.setPos(pos[0], pos[1], False)
            changednodes.add(node)
        if signal:
            self.signals.positionChanged.emit(changednodes)

    #===========================================================================
    # Edge methods
    #===========================================================================

    def edgeCount(self):
        """
        :return: the number of edges in the graph
        :rtype: int
        """
        return self.ggraph.number_of_edges()

    def hasEdge(self, node1, node2):
        """
        Return whether there is an edge between the supplied nodes.

        :param node1: a node from this graph
        :type node1: Node
        :param node2: a node from this graph
        :type node2: Node
        :return: whether there exists an edge between the two supplied nodes
        :rtype: bool
        """
        return self._ggraph.has_edge(node1.gnode, node2.gnode)

    def getGEdge(self, node0, node1):
        """
        Return the underlying gedge object corresponding to two supplied nodes.
        This can be overwritten in subclasses, but the returned class should
        define a consistent edge ordering that is independent of the order of
        the supplied node parameters.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the underlying gedge between the two nodes, if it exists
        :rtype: tuple(networkx.Node) or NoneType
        """
        if not self.hasEdge(node0, node1):
            return None
        # Sorting gives an order independent of the argument order (requires
        # the two gnodes to be comparable)
        return tuple(sorted([node0.gnode, node1.gnode]))

    def getEdge(self, node0, node1):
        """
        Given two nodes, return the corresponding edge if it exists.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the edge connecting the two nodes if it exists
        :rtype: Edge or NoneType
        """
        gedge = self.getGEdge(node0, node1)
        if gedge:
            return self.edge_class(gedge, self)
        return None

    def getEdges(self, nodes=None):
        """
        Return all edges connected to a node or set of nodes. If no node is
        specified, all the edges in the graph are returned.

        :param nodes: optionally a node or iterable of nodes
        :type nodes: `iterable(Node)`, `Node`, or `None`
        :return: a set of edges connected to at least one of the supplied nodes,
            or a set of all edges if `nodes` is not specified
        :rtype: set(Edge)
        """
        if nodes is None:
            gnodes = None
        else:
            if not hasattr(nodes, '__iter__'):
                nodes = [nodes]
            gnodes = [node.gnode for node in nodes]
        edges = set()
        for gnode1, gnode2 in self.ggraph.edges(gnodes):
            node1, node2 = self.getNode(gnode1), self.getNode(gnode2)
            edges.add(self.getEdge(node1, node2))
        return edges

    def addEdges(self, edge_tuples, signal=True):
        """
        Add edges to graph.

        :param edge_tuples: list of tuples indicating the edges to add,
            containing two gnodes or nodes and an edge attribute dictionary (or
            `None`)
        :type edge_tuples: list(tuple(Node, Node, dict)) or
            list(tuple(Node, Node, None))
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        :raise ValueError: if an edge already exists, or a gnode matches no
            node of this graph
        """
        new_edges = set()
        for node1, node2, data_dict in edge_tuples:
            # Accept gnodes as documented, not only Node objects
            node1 = self._resolveNode(node1)
            node2 = self._resolveNode(node2)
            if data_dict is None:
                data_dict = {}
            if self.hasEdge(node1, node2):
                msg = 'Edge {} already exists in graph.'.format((node1, node2))
                raise ValueError(msg)
            self._ggraph.add_edge(node1.gnode, node2.gnode, **data_dict)
            new_edges.add(self.getEdge(node1, node2))
        if signal:
            self.signals.edgesChanged.emit(new_edges)

    def addEdge(self, node1, node2, signal=True, data=None):
        """
        Convenience function to add a single edge to the graph given two nodes.
        The order of the nodes does not matter.

        :param node1: a gnode or node connected by the edge
        :type node1: `object` or `Node`
        :param node2: a gnode or node connected by the edge
        :type node2: `object` or `Node`
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        :param data: optional edge attribute dictionary
        :type data: dict or NoneType
        """
        self.addEdges([(node1, node2, data)], signal=signal)

    def removeEdges(self, edges, signal=True):
        """
        Removes specified edges from the graph.

        :param edges: a list of edges (or pairs of nodes)
        :type edges: list(Edge)
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        :raise ValueError: if one of the edges is not in the graph
        """
        edges = list(edges)
        for edge in edges:
            node1, node2 = edge
            if not self.hasEdge(node1, node2):
                raise ValueError('Edge not found between %s and %s' %
                                 (node1, node2))
            self._ggraph.remove_edge(node1.gnode, node2.gnode)
        if signal:
            self.signals.edgesChanged.emit(set(edges))

    def removeEdge(self, edge, signal=True):
        """
        Convenience function to remove a single edge from the graph.

        :param edge: an edge
        :type edge: Edge
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        """
        self.removeEdges([edge], signal=signal)

    def getEdgeApproval(self, node1, node2):
        """
        Test whether a new edge can be added between two nodes. Doesn't actually
        add an edge, just returns whether it is allowable to add.

        :return: (whether the edge is allowed, explanatory message)
        :rtype: tuple(bool, str)
        """
        if self.hasEdge(node1, node2):
            return False, "This connection already exists"
        if node1 == node2:
            return False, "Can't connect a node to itself."
        if self.connection_validator:
            return self.connection_validator.validate(node1, node2)
        return True, "No Problem"

    #===========================================================================
    # Selection
    #===========================================================================

    def selectedNodes(self):
        """
        Return the currently selected nodes.

        :rtype: set of Nodes
        """
        return self.selected_nodes

    def selectedEdges(self):
        """
        :return: the set of selected edges
        :rtype: set(Edge)
        """
        return self.selected_edges

    def setSelectedObjs(self, objs, source=None, signal=True):
        """
        Specify the current selection.

        :param objs: a list of objects (nodes or edges) to be selected
        :type objs: list(Node or Edge)
        :param source: the class instance calling this method (used to avoid
            infinite recursion when updating selection state)
        :type source: object
        :param signal: whether to emit a signal when changing selection state
        :type signal: bool
        """
        # Materialize once; objs is scanned twice below
        objs = list(objs)
        nodes = self._getNodeInstances(objs)
        edges = self._getEdgeInstances(objs)
        if set.symmetric_difference(nodes, self.selectedNodes()):
            self.selected_nodes = nodes
        if set.symmetric_difference(edges, self.selectedEdges()):
            self.selected_edges = edges
        items = nodes.union(edges)
        if signal:
            self.signals.selectionChanged.emit(items, source)

    #===========================================================================
    # Layout methods
    #===========================================================================

    def springLayout(self, signal=True):
        """
        Performs a spring layout on the current graph.

        :param signal: whether to emit `positionChanged` when done
        :type signal: bool
        """
        node_coord_map = self._getSpringLayoutCoords(iterations=ITERATIONS,
                                                     weight_attr=None,
                                                     scale=SCALE)
        self.setMultipleNodePos(node_coord_map, signal)

    def _getSpringLayoutCoords(self,
                               dim=2,
                               node_pos_map=None,
                               fixed_nodes=None,
                               iterations=50,
                               weight_attr='weight',
                               scale=1):
        """
        Calculate and return a dictionary mapping nodes to optimally-computed
        Cartesian coordinates for each node. Convenience method that wraps
        `spring_layout()`.

        :param dim: number of dimensions of the layout
        :type dim: int
        :param node_pos_map: optionally, initial positions for nodes; otherwise,
            use random initial positions
        :type node_pos_map: dict(Node, tuple(float))
        :param fixed_nodes: optionally, a list of nodes to keep fixed at their
            initial positions
        :type fixed_nodes: list(Node)
        :param iterations: number of iterations of spring-force relaxation
        :type iterations: int
        :param weight_attr: the edge attribute that holds the numerical value
            used for the edge weight. If None, then all edge weights are 1.
        :type weight_attr: str or None
        :param scale: scale factor for positions
        :type scale: float
        :return: a dictionary mapping nodes to their calculated positions
        :rtype: dict(Node, tuple(float))
        """
        if fixed_nodes:
            fixed_gnodes = [n.gnode for n in fixed_nodes]
        else:
            fixed_gnodes = None
        if node_pos_map:
            gnode_pos_map = {n.gnode: pos for n, pos in node_pos_map.items()}
        else:
            gnode_pos_map = None
        node_coords = spring_layout(self.ggraph,
                                    dim=dim,
                                    pos=gnode_pos_map,
                                    fixed=fixed_gnodes,
                                    iterations=iterations,
                                    weight=weight_attr,
                                    scale=scale)
        node_coord_map = {}
        for name, coords in node_coords.items():
            node = self.getNode(name)
            node_coord_map[node] = coords
        return node_coord_map

    def minCrossingSpringLayout(self,
                                num_iterations=100,
                                fixed_nodes=None,
                                fraction=1.0):
        """
        Perform multiple spring layouts and keep the one with the fewest edge
        intersections, keeping the original positions if the layout could not
        be improved. Emits `positionChanged` when the positions are updated.

        :param num_iterations: number of spring layouts to try
        :type num_iterations: int
        :param fixed_nodes: nodes for which the position should be fixed
        :type fixed_nodes: iterable of Node or NoneType
        :param fraction: stop iterating if no reduction in crossings is found
            within this fraction of num_iterations
        :type fraction: float
        """
        min_crossings = maxsize
        best_pos_map = None
        edges = self.getEdges()
        fixed_pos_map = None
        if fixed_nodes is not None:
            fixed_pos_map = {node: node.pos() for node in fixed_nodes}
        initial_pos_map = fixed_pos_map
        # If every node already has a position, that layout is the one to beat
        if self.hasPositions():
            initial_pos_map = {node: node.pos() for node in self.getNodes()}
            _, min_crossings = _has_fewer_crossings(edges, initial_pos_map,
                                                    maxsize)
            best_pos_map = initial_pos_map
        # Layouts are computed in unscaled coordinates and the winner is scaled
        # back at the end, so every map used to seed a layout must be unscaled
        # to preserve location. The fixed-node map seeds the layout after each
        # non-improving iteration, so it must be unscaled too even when it is
        # not the initial map; otherwise fixed nodes would be placed at scaled
        # coordinates and then moved again by the final rescale.
        if initial_pos_map:
            self._scaleNodeCoords(initial_pos_map, reverse=True)
        if fixed_pos_map and fixed_pos_map is not initial_pos_map:
            self._scaleNodeCoords(fixed_pos_map, reverse=True)
        new_pos_map = initial_pos_map
        max_not_better_iters = min(num_iterations, fraction * num_iterations)
        not_better_iters = 0
        for _ in range(num_iterations):
            not_better_iters += 1
            if min_crossings == 0:
                break
            if not_better_iters > max_not_better_iters:
                break
            new_pos_map = self._getSpringLayoutCoords(iterations=ITERATIONS,
                                                      node_pos_map=new_pos_map,
                                                      fixed_nodes=fixed_nodes,
                                                      weight_attr=None,
                                                      scale=SCALE)
            is_better, crossings = _has_fewer_crossings(edges, new_pos_map,
                                                        min_crossings)
            if is_better:
                not_better_iters = 0
                min_crossings = crossings
                best_pos_map = new_pos_map
            else:
                # Restart from the fixed positions only
                new_pos_map = fixed_pos_map
        # No layout was produced, or none improved on the current positions
        if best_pos_map is None or best_pos_map is initial_pos_map:
            return
        self._scaleNodeCoords(best_pos_map)
        self.setMultipleNodePos(best_pos_map)

    def _scaleNodeCoords(self, pos_dict, reverse=False):
        """
        Scales the positions in the pos_dict dictionary (in place) by factor,
        or by 1/factor if reverse is True, where
        factor = 0.5 * sqrt(number of nodes) (at least one node is assumed).
        Through manual testing 0.5 was determined to be a good multiplier.

        :param pos_dict: A dictionary mapping nodes or node names to (x,y)
            tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param reverse: Whether to reverse the scaling
        :type reverse: bool
        """
        num_nodes = len(self._ggraph.nodes) or 1
        scale = 0.5 * math.sqrt(num_nodes)
        if reverse:
            scale = 1.0 / scale
        self._scaleDictPositions(pos_dict, scale)

    @staticmethod
    def _scaleDictPositions(pos_dict, factor):
        """
        Multiplies the positions in {node: (x_pos, y_pos)} dictionary by factor,
        in place. Each position is replaced by a two-element list.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param factor: multiplication factor for positions
        :type factor: float
        :return: the same (modified) dictionary
        :rtype: dict
        """
        for node, xy in pos_dict.items():
            scaled_x_pos = xy[0] * factor
            scaled_y_pos = xy[1] * factor
            pos_dict[node] = [scaled_x_pos, scaled_y_pos]
        return pos_dict

    def hasPositions(self, accept_partial=False):
        """
        Determines whether the nodes in this graph have x-y coordinates.

        :param accept_partial: if set to True, the method will check whether
            at least one node has coordinates. Otherwise it requires that all
            nodes have coordinates.
        :type accept_partial: bool
        :return: whether the nodes are positioned; True for a graph without
            nodes
        :rtype: bool
        """
        positions = [node.pos() for node in self.getNodes()]
        if accept_partial and any(pos is not None for pos in positions):
            return True
        return all(pos is not None for pos in positions)

    #===========================================================================
    # Undo/redo
    #===========================================================================

    def getState(self):
        """
        Get the current state of the Graph: a deep copy of the underlying
        NetworkX graph and a shallow copy of the node object map.
        """
        ggraph = copy.deepcopy(self._ggraph)
        node_objects = self.node_objects.copy()
        return ggraph, node_objects

    def setState(self, state):
        """
        Set the current state of the Graph, clear the selection and emit
        `graphChanged`.

        :param state: a state as returned by `getState`
        :type state: tuple(networkx.Graph, dict(str, Node))
        """
        ggraph, node_objects = state
        self._ggraph = ggraph
        self.node_objects = node_objects
        # getState only shallow-copies node_objects, so node objects are shared
        # between states and still wrap the data dicts of whichever graph was
        # live when they were created. Point them at the restored graph's data
        # so that e.g. node positions are actually restored by undo/redo.
        for node in node_objects.values():
            gdata = ggraph.nodes.get(node.gnode)
            if gdata is not None:
                node._gdata = gdata
        self.selected_nodes = set()
        self.selected_edges = set()
        self.signals.graphChanged.emit()

    def setUndoPoint(self, signal=True):
        """
        Store the current state to the undo stack. Also wipes out the redo
        stack.

        :param signal: whether to emit `undoPointSet`
        :type signal: bool
        """
        self.undo_stack.append(self.getState())
        while len(self.undo_stack) > self.max_undo_stack:
            self.undo_stack.pop(0)
        self.redo_stack = []
        if signal:
            self.signals.undoPointSet.emit()

    def undo(self):
        """
        Revert to the last state on the undo stack. Does nothing if the undo
        stack is empty.
        """
        if not self.undo_stack:
            return
        self.redo_stack.append(self.getState())
        while len(self.redo_stack) > self.max_undo_stack:
            self.redo_stack.pop(0)
        state = self.undo_stack.pop()
        self.setState(state)

    def redo(self):
        """
        Undo the undo. Does nothing if the redo stack is empty.
        """
        if not self.redo_stack:
            return
        self.undo_stack.append(self.getState())
        state = self.redo_stack.pop()
        self.setState(state)

    def clearUndoHistory(self):
        """
        Clears both undo and redo stacks
        """
        self.undo_stack = []
        self.redo_stack = []

    def merge(self, g):
        """
        Merge data from another graph into this graph. Nodes with
        duplicate names will be considered to be the same node. No signals are
        emitted by this graph.

        Note that edges of `g` without a 'direction' data item get one set (in
        `g`), ordered by the names of the connected nodes.

        :param g: graph from which data is being merged.
        :type g: `Graph`
        """
        for edge in g.getEdges():
            data_dict = edge.data()
            n1, n2 = edge
            if 'direction' not in data_dict:
                hex1, hex2 = n1.name, n2.name
                d = (hex1, hex2) if hex1 < hex2 else (hex2, hex1)
                edge.setData('direction', d)
        self._ggraph.add_nodes_from(g._ggraph.nodes(data=True))
        self._ggraph.add_edges_from(g._ggraph.edges(data=True))
        # Create node objects for the newly merged gnodes so getNode() etc.
        # can find them
        self._updateNodeMap()

    def deleteSelectedItems(self, include_edges=True, include_nodes=True):
        """
        Delete selected nodes and/or selected edges, after setting an undo
        point and clearing the selection.

        :param include_edges: whether selected edges should be deleted
        :type include_edges: bool
        :param include_nodes: whether selected nodes should be deleted
        :type include_nodes: bool
        """
        nodes = self.selectedNodes() if include_nodes else set()
        edges = self.selectedEdges() if include_edges else set()
        if not nodes and not edges:
            return
        self.setUndoPoint()
        self.setSelectedObjs([])
        self.deleteItems(nodes, edges)

    def deleteItems(self, nodes=None, edges=None):
        """
        Delete specified nodes and edges from the graph. Edges connected to the
        deleted nodes are deleted as well.

        :param nodes: nodes to delete
        :type nodes: Set[Node]
        :param edges: edges to delete
        :type edges: Set[Edge]
        """
        nodes = nodes or set()
        edges = edges or set()
        connected_edges = set(self.getEdges(nodes))
        edges = set(edges).union(connected_edges)
        if edges:
            self.removeEdges(edges)
        if nodes:
            self.removeNodes(nodes)
        if edges or nodes:
            self.update()
class Node:
    """
    Model class for a graph node. Wraps the data dictionary of one node of a
    NetworkX graph (an entry of ``networkx.Graph.nodes``).
    """
    # Keys under which the node's coordinates are stored in its data dict
    x_key = 'storedX'
    y_key = 'storedY'

    def __init__(self, name, graph=None):
        """
        Construct a Node object. Most of the time, this will be constructed
        around an existing NetworkX node. If `graph` is specified and its
        underlying NetworkX graph contains a node keyed by `name`, this object
        shares that node's data dictionary. Otherwise the node starts with an
        empty data dictionary that is not attached to any graph
        (`Graph.addNodes` relies on this to build a node before inserting it).

        Qt signals will only be emitted if a graph is specified.

        :param name: a unique identifier for this node
        :type name: hashable
        :param graph: the graph object to which this node belongs
        :type graph: `Graph`

        :ivar _gnode: the underlying graph node that this node wraps. In this
            class, we use the node name as the graph node, but any hashable
            object can be used.
        :ivar _gdata: dictionary that stores data belonging to the underlying
            graph node.
        """
        gdata = {}
        if graph is not None:
            # Mapping.get never raises for a missing key; an absent node just
            # gets a detached, empty data dict.
            gdata = graph.ggraph.nodes.get(name, {})
        self._gnode = name
        self._gdata = gdata
        self.graph = graph

    @property
    def gnode(self):
        """
        Return the underlying graph node object wrapped by this `Node` instance
        (not the data dictionary `_gdata`).
        """
        return self._gnode

    @property
    def name(self):
        """
        Return unique string associated with this node. Convert to string for
        subclasses which do not necessarily use strings as graph nodes.
        """
        return str(self.gnode)

    #===========================================================================
    # Positioning
    #===========================================================================

    def x(self):
        """
        :return: the stored x coordinate, or None if not set
        """
        return self._gdata.get(self.x_key, None)

    def y(self):
        """
        :return: the stored y coordinate, or None if not set
        """
        return self._gdata.get(self.y_key, None)

    def pos(self):
        """
        Returns the Node's current position coordinates. Returns None if there
        are no coordinates.

        :rtype: tuple (float, float) or NoneType
        """
        pos = (self.x(), self.y())
        if None in pos:
            return None
        return pos

    def setX(self, x, signal=True):
        """
        Set the x coordinate; emits `positionChanged` (if `signal` and the node
        belongs to a graph) only when the value changes.
        """
        if self.x() == x:
            return
        self._gdata[self.x_key] = x
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setY(self, y, signal=True):
        """
        Set the y coordinate; emits `positionChanged` (if `signal` and the node
        belongs to a graph) only when the value changes.
        """
        if self.y() == y:
            return
        self._gdata[self.y_key] = y
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setPos(self, x, y, signal=True):
        """
        Set the node's position coordinates, emitting at most one
        `positionChanged` signal.

        :param x: x coordinate
        :type x: float
        :param y: y coordinate
        :type y: float
        :param signal: whether to emit `positionChanged` if the position changed
        :type signal: bool
        """
        if self.x() == x and self.y() == y:
            return
        self.setX(x, False)
        self.setY(y, False)
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    #===========================================================================
    # General node properties
    #===========================================================================

    def gdata(self):
        """
        Directly access the node data dictionary. Use this object carefully, as
        directly altering its contents can lead to internal inconsistencies.
        This may be wrapped to restrict access.
        """
        return self._gdata

    def getData(self, key):
        """
        Return the requested item from the node's data dictionary. Returns None
        if the key is not found.
        """
        return self._gdata.get(key, None)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the node's data dictionary, emitting
        `nodesChanged` if `signal` is True and the node belongs to a graph.
        """
        self._gdata[key] = value
        if signal and self.graph:
            self.graph.signals.nodesChanged.emit({self})

    @property
    def degree(self):
        """
        :return: the degree (number of edges) of the node; requires the node to
            belong to a graph
        :rtype: int
        """
        return self.graph.ggraph.degree(self.gnode)

    def __repr__(self):
        return '<Node("%s")>' % self.name

    def __str__(self):
        return self.__repr__()

    def __eq__(self, rhs):
        # Nodes are equal when they belong to the same graph object and share
        # a name
        try:
            return id(self.graph) == id(rhs.graph) and self.name == rhs.name
        except AttributeError:
            return False

    def __ne__(self, rhs):
        return not self == rhs

    def __hash__(self):
        return hash((id(self.graph), self.name))
[docs]class Edge:
[docs] def __init__(self, gedge, graph):
"""
:param gedge: the underlying edge object wrapped by this object
:type gedge: object
:param graph: the graph object to which this edge belongs
:type graph: Graph
"""
self._gedge = gedge
self._graph = graph
@property
def gedge(self):
"""
:return: the underlying edge object wrapped by this object
:rtype: fep.graph.Edge
"""
return self._gedge
@property
def graph(self):
"""
:return: the graph to which this edge belongs
:rtype: Graph
"""
return self._graph
@property
def nodes(self):
"""
:return: the nodes connected by this edge in a consistent order, as
determined by the underlying graph edge
:rtype: tuple(Node, Node)
"""
return tuple(self.graph.getNode(gnode) for gnode in self.gedge)
[docs] def data(self):
"""
:return: the data dictionary associated with this edge
:rtype: dict(str, object)
"""
ggraph, gedge = self.graph.ggraph, self.gedge
return dict(ggraph[gedge[0]][gedge[1]])
[docs] def getData(self, key):
"""
Return the requested item from the edge's data dictionary. Returns None
if the key is not found.
:param key: the data item key
:type key: str
:return: the value stored under the specified key in the edge's data
dictionary, or `None` if it is not found
:rtype: object
"""
data_dict = self.data()
return data_dict.get(key)
[docs] def setData(self, key, value, signal=True):
"""
Set the specified item in the edge's data dictionary.
:param key: the data item key
:type key: str
:param value: the value to set for the data item
:type value: object
"""
ggraph, gedge = self.graph.ggraph, self.gedge
data_dict = ggraph[gedge[0]][gedge[1]]
old_value = data_dict.get(key)
data_dict[key] = value
if signal and old_value != value:
self.graph.signals.edgesChanged.emit({self})
@property
def name(self):
"""
:return: the name of the edge, a composite of the connected node names
:rtype: str
"""
node0, node1 = self.nodes
name0 = 'None' if node0 is None else node0.name
name1 = 'None' if node1 is None else node1.name
return f'"{name0}" - "{name1}"'
def __getitem__(self, idx):
"""
Return a node connected by this edge. Only accepts indices 0 and 1.
:param idx: node index
:type idx: `int`
:return: node corresponding to supplied index
:rtype: `LigandNode`
"""
return self.nodes[idx]
def __eq__(self, rhs):
try:
return self.graph == rhs.graph and self.gedge == rhs.gedge
except AttributeError:
return False
def __ne__(self, rhs):
return not self == rhs
def __hash__(self):
return hash((self.graph, self.gedge))
def __str__(self):
return f'<{self.__class__.__name__}({self.name})>'
def __repr__(self):
    """Same text as `__str__`."""
    return str(self)
class ConnectionValidator:
    """
    Hook for deciding whether two nodes may be connected.

    Create a subclass of this, override `validate`, and assign it using
    ``NetworkViewer.setConnectionValidator()`` to do extra work making sure
    nodes are compatible to connect.

    Note that `validate` receives the ``val`` attributes of the two nodes
    (``node1.val`` and ``node2.val``), not the node objects themselves.
    """

    def __init__(self):
        # First node of a pending connection, set via setFirstNode()
        self.first_node = None

    def validate(self, node1, node2):
        """
        Decide whether two node values may be connected. The base
        implementation accepts every pair.

        :param node1: the ``val`` of the first node
        :param node2: the ``val`` of the second node
        :return: whether the connection is allowed, and a message
        :rtype: tuple(bool, str)
        """
        return True, "No problem"

    def firstNode(self):
        """
        :return: the first node of the pending connection, if one was set
        :rtype: object or None
        """
        return self.first_node

    def setFirstNode(self, node):
        """
        Remember the first node of a pending connection.

        :param node: the first node; `validateSecondVal` reads its ``val``
            attribute
        :type node: object
        """
        self.first_node = node

    def validateSecondVal(self, val):
        """
        Validate connecting the stored first node to a node with value `val`.

        :param val: the ``val`` of the candidate second node
        :return: the result of `validate`, or `None` if no (truthy) first node
            has been set
        :rtype: tuple(bool, str) or None
        """
        # Query firstNode() once so that subclasses overriding it are honored
        # and the node that was checked is the node that gets validated.
        first = self.firstNode()
        if not first:
            return None
        return self.validate(first.val, val)
#===============================================================================
# Network View Classes
#===============================================================================
class AbstractNetworkView:
    """
    A base class for views on Graph models. Use setModel to replace the model
    object. Signals from the model are automatically connected to appropriate
    synchronization slots.

    The abstract view does not provide any built-in support for effecting
    changes back into the model (ex. deleting nodes, changing selection). Any
    such operations should be implemented in the subclass by making calls
    directly to the model. These changes will then be automatically
    synchronized forward to all views.

    self.nodes is a dictionary mapping model node objects to view node objects.
    self.edges is a dictionary mapping model edge objects (`Edge`) to view edge
    objects.

    Note that the words node and edge in method names generally refer to view
    objects. For example, makeNode() will make a view node, addEdge() will add
    an edge view object to the view. Exceptions are documented per method
    (e.g. `updateNodes`, `updateEdges` and the `getModel*` methods work with
    model objects).

    :cvar MODEL_CLASS: an instance of this class will be created as the default
        model when `setModel` is called with `None`
    :vartype MODEL_CLASS: `Graph` or subclass of `Graph`

    :ivar _sync_with_model: whether to automatically synchronize this view (and
        its subviews) with the model
    :vartype _sync_with_model: bool
    """

    MODEL_CLASS = Graph

    def __init__(self):
        self.model = None
        # model node -> view node
        self.nodes = {}
        # model edge -> view edge
        self.edges = {}
        # True while syncSelection() pushes a model selection into the view via
        # selectItems(); presumably checked by subclasses so that they do not
        # echo that selection back to the model
        self.skip_selectionChanged = False
        self._subviews = set()
        self._sync_with_model = True

    #===========================================================================
    # Model-View Connections
    #===========================================================================

    def syncAll(self):
        """
        Synchronize the full model and selection state. Does nothing if no
        model has been set yet.
        """
        model = self.model
        if model is None:
            # e.g. setModelSyncEnabled(True) before any setModel() call
            return
        self.syncModel()
        selection = model.selected_nodes.union(model.selected_edges)
        self.syncSelection(selection, model)

    def syncRecursive(self):
        """
        Synchronize the full model and selection state on this view and all
        subviews.
        """
        self.syncAll()
        for subview in self._subviews:
            subview.syncRecursive()

    def setModelSyncEnabled(self, enable):
        """
        Enable or disable automatic synchronization with the model for this view
        and all subviews. Re-enabling reconnects the model signals and performs
        a full synchronization.

        :param enable: whether to synchronize with the model
        :type enable: bool
        """
        for subview in self._subviews:
            subview.setModelSyncEnabled(enable)
        if self._sync_with_model == enable:
            return
        self._sync_with_model = enable
        if enable:
            self._connectSignals()
            self.syncAll()
        else:
            self._disconnectSignals()

    def setModel(self, model):
        """
        Set the model for this view and synchronize to it. Any subviews will
        have the model set on them as well.

        :param model: the graph model; if `None`, a new `MODEL_CLASS` instance
            is created and used
        :type model: Graph or None
        """
        if model is None:
            model = self.MODEL_CLASS()
        if self._sync_with_model:
            self._disconnectSignals()
        self.model = model
        for subview in self._subviews:
            subview.setModel(model)
        if self._sync_with_model:
            self._connectSignals()
            self.syncAll()

    def _connectSignals(self):
        """
        If a model is defined, connect all signal/slot pairs.
        """
        # Compare against None explicitly: a model defining __len__/__bool__
        # could be falsy while empty, and its signals must still be connected.
        if self.model is not None:
            for signal, slot in self.getSignalsAndSlots(self.model):
                signal.connect(slot)

    def _disconnectSignals(self):
        """
        If a model is defined, disconnect all signal/slot pairs.
        """
        # Must mirror _connectSignals() exactly so we never disconnect pairs
        # that were not connected.
        if self.model is not None:
            for signal, slot in self.getSignalsAndSlots(self.model):
                signal.disconnect(slot)

    def getSignalsAndSlots(self, model):
        """
        Get a list of signal/slot pairs for a model. This list will be used when
        setting a new model to disconnect the old model signals from their slots
        and connect the new model's signals to those slots.

        Override this method to modify or extend signals/slots in derived
        classes.

        :param model: the graph model
        :type model: Graph
        :return: (signal, slot) pairs
        :rtype: list(tuple)
        """
        signals = model.signals
        ss_list = [
            (signals.graphChanged, self.syncModel),
            (signals.nodesAdded, self.syncNodesAdded),
            (signals.nodesDeleted, self.syncNodesDeleted),
            (signals.nodesChanged, self.syncNodesChanged),
            (signals.edgesChanged, self.syncModel),
            (signals.selectionChanged, self.syncSelection),
        ]
        return ss_list

    def addSubview(self, subview):
        """
        Add a subview to this view. A subview is another AbstractNetworkView
        that should always have the same model as its parent view (this view).
        Adding will automatically set its model to the current model. Changing
        the model on this view will result in all its subviews getting the new
        model set.

        :param subview: the new subview to add to this view
        :type subview: AbstractNetworkView
        """
        self._subviews.add(subview)
        subview.setModel(self.model)

    def removeSubview(self, subview):
        """
        Removes the specified subview. The subview is not deleted or altered,
        and the model remains set.

        :param subview: the subview to remove
        :type subview: AbstractNetworkView
        :raise KeyError: if `subview` is not a subview of this view
        """
        self._subviews.remove(subview)

    #===========================================================================
    # Model-View Synchronization
    #===========================================================================

    def syncModel(self):
        """
        Synchronize all view nodes and view edges with the model.
        """
        self.syncNodes()
        self.syncEdges()

    def syncNodes(self):
        """
        Add, update and remove view nodes so that they match the model nodes.
        """
        nodeset = self.model.getNodes()
        delnodes = set(self.nodes).difference(nodeset)
        self.syncNodesAdded(nodeset)
        self.syncNodesChanged(nodeset)
        self.syncNodesDeleted(delnodes)

    def syncNodesDeleted(self, nodes):
        """
        Remove the view nodes for the given model nodes, then resync edges.
        Model nodes that have no view node are ignored.

        :param nodes: deleted model nodes
        :type nodes: set(Node)
        """
        self._removeNodes(nodes)
        self.syncEdges()

    def syncNodesAdded(self, nodes):
        """
        Create and add view nodes for any of the given model nodes that do not
        have one yet, then resync edges.

        :param nodes: added model nodes
        :type nodes: set(Node) or iterable(Node)
        """
        new_nodes = set(nodes).difference(self.nodes)
        self._addNodes(new_nodes)
        self.syncEdges()

    def syncNodesChanged(self, nodes):
        """
        Update the views of the given model nodes and of their edges.

        :param nodes: changed model nodes
        :type nodes: set(Node)
        """
        self.updateNodes(nodes)
        if self.edges:
            edges = self.model.getEdges(nodes)
            self.updateEdges(edges)

    def syncEdges(self):
        """
        Add, update and remove view edges so that they match the model edges.
        """
        model_edges = set(self.model.getEdges())
        known_edges = set(self.edges)
        del_edges = known_edges.difference(model_edges)
        self._removeEdges(del_edges)
        add_edges = model_edges.difference(known_edges)
        self._addEdges(add_edges)
        up_edges = model_edges.intersection(known_edges)
        self.updateEdges(up_edges)

    def syncSelection(self, selection, source):
        """
        Mirror a model selection in this view.

        :param selection: selected model nodes and/or edges
        :type selection: set(Node or Edge)
        :param source: the originator of the selection change; changes that
            originate from this view are ignored
        :type source: object
        """
        if source == self:
            return
        selected_view_objects = []
        for model_obj in selection:
            if isinstance(model_obj, Node):
                view_obj = self.nodes.get(model_obj)
            elif isinstance(model_obj, Edge):
                view_obj = self.getEdge(model_obj)
            else:
                continue
            if view_obj:
                selected_view_objects.append(view_obj)
        self.skip_selectionChanged = True
        try:
            self.selectItems(selected_view_objects)
        finally:
            # Always clear the flag, even if selectItems() raises; otherwise
            # subsequent selection changes would be skipped forever.
            self.skip_selectionChanged = False

    #===========================================================================
    # Node operations
    #===========================================================================

    def _addNodes(self, nodes):
        node_map = self.makeNodes(nodes)
        self.nodes.update(node_map)
        self.addNodes(set(node_map.values()))

    def _removeNodes(self, nodes):
        # Skip nodes this view does not hold (e.g. already removed by an
        # earlier full sync) instead of failing with a KeyError after the
        # view nodes have been partially removed.
        present = [node for node in nodes if node in self.nodes]
        viewnodes = [self.getNode(node) for node in present]
        self.removeNodes(viewnodes)
        for node in present:
            self.nodes.pop(node)

    def makeNodes(self, nodes):
        """
        Create new view nodes and return a dictionary mapping supplied model
        nodes to corresponding view nodes. Do not add new view nodes to the
        view.

        By default this method returns an "identity dictionary" that maps nodes
        to themselves. Subclasses should override this method to implement their
        own view nodes.

        :param nodes: model nodes
        :type nodes: list(Node)
        :return: a dictionary mapping supplied nodes to view nodes
        :rtype: dict(Node, object)
        """
        return {node: node for node in nodes}

    def makeNode(self, node):
        """
        Convenience method for calling `makeNodes()` with a single node. Rather
        than returning a dictionary mapping nodes to view nodes, returns the
        view node corresponding to the supplied node.

        :param node: the model node
        :type node: Node
        :return: the view node
        :rtype: object
        """
        node_map = self.makeNodes([node])
        return node_map.get(node)

    def addNode(self, viewnode):
        """
        A convenience function for calling `addNodes()` for a single node.

        :param viewnode: a view node
        :type viewnode: object
        """
        self.addNodes([viewnode])

    def removeNode(self, viewnode):
        """
        Convenience method for calling `removeNodes()` for a single node.

        :param viewnode: a view node
        :type viewnode: object
        """
        self.removeNodes([viewnode])

    def updateNode(self, node):
        """
        Convenience method for calling `updateNodes()` for a single node.

        :param node: the model node to update to
        :type node: Node
        """
        self.updateNodes([node])

    def getModelNodes(self, node_keys=None):
        """
        Retrieve a set of model nodes optionally indicated by a list of keys. If
        none is provided, return all nodes. Only nodes present in this view are
        returned.

        :param node_keys: optionally, a list of nodes, gnodes, or strings that
            correspond to the desired model nodes
        :type node_keys: list(object) or NoneType
        :return: a set of nodes
        :rtype: set(Node)
        """
        nodes_in_model = self.model.getNodes(node_keys)
        return set(self.nodes).intersection(nodes_in_model)

    def getNode(self, node):
        """
        :param node: a model node
        :type node: Node
        :return: corresponding view node, if available
        :rtype: `object` or `None`
        """
        return self.nodes.get(node, None)

    #===========================================================================
    # Edge operations
    #===========================================================================

    def _addEdges(self, edges):
        for edge in edges:
            if self.getEdge(edge) is not None:
                msg = f'A view edge already exists for {edge}.'
                raise ValueError(msg)
        edge_map = self.makeEdges(edges)
        self.edges.update(edge_map)
        self.addEdges(list(edge_map.values()))

    def _removeEdges(self, edges):
        view_edges = [self.getEdge(edge) for edge in edges]
        self.removeEdges(view_edges)
        for edge in edges:
            self.edges.pop(edge)

    def makeEdges(self, edges):
        """
        Given a list of model edges, return a dictionary mapping them to
        corresponding view edges. Does not add view edges to the view.

        By default this method returns an identity dictionary, mapping model
        edges to themselves. Subclasses should override this method if they
        want to implement their own view edges.

        :param edges: a list of model edges
        :type edges: list(Edge)
        :return: a dictionary mapping model edges to view edges
        :rtype: dict(Edge, object)
        """
        return {edge: edge for edge in edges}

    def makeEdge(self, edge):
        """
        Convenience method for calling `makeEdges()` for a single edge. Rather
        than return a dictionary mapping model edges to view edges, returns a
        single view edge. Does not add a view edge to the view.

        :param edge: a model edge
        :type edge: Edge
        :return: a view edge
        :rtype: object
        """
        edge_map = self.makeEdges([edge])
        return edge_map.get(edge)

    def addEdge(self, viewedge):
        """
        Convenience method for calling `addEdges()` for a single edge.

        :param viewedge: the view edge to add to the view
        :type viewedge: object
        """
        self.addEdges([viewedge])

    def removeEdge(self, viewedge):
        """
        Convenience method for calling `removeEdges()` for a single edge.

        :param viewedge: the view edge to remove from the view
        :type viewedge: object
        """
        self.removeEdges([viewedge])

    def updateEdge(self, edge):
        """
        A convenience method for calling `updateEdges()` for a single edge.

        :param edge: the model edge corresponding to the view edge to update
        :type edge: Edge
        """
        self.updateEdges([edge])

    def getModelEdges(self, nodes=None):
        """
        Return all model edges connected to a model node or set of model nodes.
        If no node is specified, all the edges in the graph are returned. This
        method acts like `Graph.getEdges()`, but it filters for model edges that
        are available in this view.

        :param nodes: optionally a node or list of nodes
        :type nodes: `list(Node)`, `Node`, or `None`
        :return: a list of model edges
        :rtype: list(Edge)
        """
        model_edges = set(self.model.getEdges(nodes=nodes))
        return list(model_edges.intersection(self.edges))

    def getEdge(self, edge):
        """
        Return the view edge corresponding to the supplied model edge.

        :param edge: a model edge
        :type edge: Edge
        :return: the corresponding view edge if available
        :rtype: object or None
        """
        return self.edges.get(edge)

    def getEdges(self, nodes=None):
        """
        Return a list of view edges, filtering the list so that the edges are
        connected to the optionally-supplied node or iterable of nodes. Model
        edges that have no view edge contribute `None` entries.

        :param nodes: a node or iterable of nodes
        :type nodes: iterable[Node] or Node or NoneType
        :return: list of view edges
        :rtype: list[object or NoneType]
        """
        return [self.getEdge(edge) for edge in self.model.getEdges(nodes)]

    #===========================================================================
    # Pure virtual methods
    #===========================================================================

    def addNodes(self, viewnodes):
        """
        Takes view nodes and adds them to the view if that makes sense (eg. add
        graphics items to scene, add rows to table, etc.) It should not add
        the view node to `self.nodes`; that is handled in `_addNodes()`.

        :param viewnodes: view nodes to add to the view
        :type viewnodes: list(object)
        """

    def removeNodes(self, viewnodes):
        """
        Removes view nodes from the view if that makes sense (eg. remove
        graphics items from scene, remove table rows, etc.) It should not remove
        view nodes from `self.nodes`; that is handled in `_removeNodes()`.

        :param viewnodes: a list of view nodes
        :type viewnodes: list(object)
        """

    def updateNodes(self, nodes):
        """
        Performs any operations necessary to update the view to the current
        model state. Note that this method takes model nodes, not view nodes.

        :param nodes: model nodes which must have their views updated
        :type nodes: list(Node)
        """

    def addEdges(self, viewedges):
        """
        Adds view edges to the view. Does not add view edges to `self.edges`.

        :param viewedges: view edges to add to the view
        :type viewedges: list(object)
        """

    def removeEdges(self, viewedges):
        """
        Removes view edges from the view. Does not remove view edges from
        `self.edges`.

        :param viewedges: view edges to remove from the view
        :type viewedges: list(object)
        """

    def updateEdges(self, edges):
        """
        Performs any operations necessary to update the view to the current
        model state.

        :param edges: a list of model edges corresponding to view edges that
            should be updated
        :type edges: list(Edge)
        """

    def selectItems(self, selected_view_objects):
        """
        Selects view objects in the view. The list may contain view nodes
        and/or view edges, as built by `syncSelection()`.

        :param selected_view_objects: a list of view objects to be selected
        :type selected_view_objects: list(object)
        """
#===============================================================================
# Layout calculations
#===============================================================================
#
# line segment intersection using vectors
# see Computer Graphics by F.S. Hill
#
def perp(a):
    """
    Return a new vector equal to 2D vector `a` rotated by 90 degrees,
    i.e. ``(x, y) -> (-y, x)``.

    :param a: a 2D vector (only its first two components are read)
    :type a: numpy.ndarray
    :return: the perpendicular vector, with the same shape and dtype as `a`
    :rtype: numpy.ndarray
    """
    rotated = np.empty_like(a)
    rotated[0], rotated[1] = -a[1], a[0]
    return rotated
# Tolerance used by `seg_intersect` (in units of fractional position along each
# segment) so that segments which merely touch at an endpoint, such as two
# edges sharing a node, are not counted as crossing.
_eps = 1e-8


def seg_intersect(a1, a2, b1, b2):
    """
    Checks whether two line segments cross each other.

    The segments are written parametrically as ``a1 + u * (a2 - a1)`` and
    ``b1 + t * (b2 - b1)``. They cross when the intersection of their lines
    lies strictly inside both segments, i.e. ``0 < u < 1`` and ``0 < t < 1``
    (within `_eps`). This works for segments of any orientation, including
    vertical ones. Parallel (including collinear) segments are reported as not
    crossing. See "line segment intersection using vectors", Computer
    Graphics by F.S. Hill.

    :param a1: first endpoint of line segment a
    :type a1: numpy.array
    :param a2: second endpoint of line segment a
    :type a2: numpy.array
    :param b1: first endpoint of line segment b
    :type b1: numpy.array
    :param b2: second endpoint of line segment b
    :type b2: numpy.array
    :return: whether the line segments intersect
    :rtype: bool
    """
    da = a2 - a1
    db = b2 - b1
    dp = a1 - b1
    # 2D cross product da x db, equivalently dot(perp(da), db)
    denom = da[0] * db[1] - da[1] * db[0]
    if denom == 0:  # Line segments are parallel
        return False
    # Fractional positions of the lines' intersection point along b and a
    t = (da[0] * dp[1] - da[1] * dp[0]) / denom
    u = (db[0] * dp[1] - db[1] * dp[0]) / denom
    return bool(_eps < t < 1 - _eps and _eps < u < 1 - _eps)
def _has_fewer_crossings(edges, node_coords, goal):
    """
    Determine whether drawing `edges` at `node_coords` yields fewer edge
    crossings than `goal`.

    Counting stops as soon as a crossing beyond the first `goal` is found; in
    that case the returned count is `goal` (meaning "more than `goal`").

    :param edges: pairs of node keys
    :type edges: iterable(tuple)
    :param node_coords: maps each node key to an ``(x, y)`` pair
    :type node_coords: dict
    :param goal: the crossing count to beat
    :type goal: int
    :return: whether there are fewer than `goal` crossings, and the number of
        crossings counted
    :rtype: tuple(bool, int)
    """

    def to_point(node):
        x, y = node_coords[node]
        return np.array([x, y])

    segments = [(to_point(n1), to_point(n2)) for n1, n2 in edges]
    crossings = 0
    for index, (a1, a2) in enumerate(segments):
        # Compare each segment only with the ones after it, so every
        # unordered pair is tested exactly once.
        for b1, b2 in segments[index + 1:]:
            if not seg_intersect(a1, a2, b1, b2):
                continue
            if crossings == goal:
                return False, crossings
            crossings += 1
    return crossings < goal, crossings
#===============================================================================
# Code copied and modified from networkx.drawing.layout
#===============================================================================
def fruchterman_reingold_layout(G,
                                dim=2,
                                pos=None,
                                fixed=None,
                                iterations=50,
                                weight='weight',
                                scale=1):
    """
    Position nodes using Fruchterman-Reingold force-directed algorithm.

    :param G: NetworkX graph
    :param dim: Dimension of layout
    :type dim: int
    :param pos: Initial positions for nodes as a dictionary with node as keys
        and values as a list or tuple. If None, then use random initial
        positions. Nodes missing from the dictionary start at random
        positions.
    :type pos: dict
    :param fixed: Nodes to keep fixed at initial position. optional
    :type fixed: list
    :param iterations: Number of iterations of spring-force relaxation
    :type iterations: int
    :param weight: The edge attribute that holds the numerical value used for
        the edge weight. If None, then all edge weights are 1.
    :type weight: str or None
    :param scale: Scale factor for positions; only applied when `fixed` is
        None
    :type scale: float
    :rtype: dict
    :returns: A dictionary of positions keyed by gnode

    Examples::

        >>> G = nx.path_graph(4)
        >>> pos = spring_layout(G)
        # The same using longer function name
        >>> pos = fruchterman_reingold_layout(G)
    """
    if fixed is not None:
        # Translate fixed gnodes into row indices of the adjacency matrix
        gnode_idx_map = {gnode: idx for idx, gnode in enumerate(G)}
        fixed = np.asarray([gnode_idx_map[v] for v in fixed])
    if pos is not None:
        # Start from random positions and overwrite the supplied ones
        pos_arr = np.asarray(np.random.random((len(G), dim)))
        for i, n in enumerate(G):
            if n in pos:
                pos_arr[i] = np.asarray(pos[n])
    else:
        pos_arr = None
    if len(G) == 0:
        return {}
    if len(G) == 1:
        return {next(iter(G.nodes())): (1,) * dim}
    # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array exists
    # since NetworkX 2.0. Fall back to the old name only when needed.
    to_numpy = getattr(nx, 'to_numpy_array', None) or nx.to_numpy_matrix
    A = to_numpy(G, weight=weight)
    pos = _fruchterman_reingold(A, dim, pos_arr, fixed, iterations)
    if fixed is None:
        pos = _rescale_layout(pos, scale=scale)
    return dict(zip(G, pos))


spring_layout = fruchterman_reingold_layout
def _fruchterman_reingold(A, dim=2, pos=None, fixed=None, iterations=50):
    """
    Position nodes in adjacency matrix `A` using the Fruchterman-Reingold
    force-directed algorithm. The entry point for NetworkX graphs is
    `fruchterman_reingold_layout()`.

    Each iteration, every node is pushed by a repulsive term
    ``PUSH * k**PUSHEXP / d**PUSHEXP`` and pulled by an attractive term
    ``A * d / k`` (``d`` = distance between nodes, ``k`` = sqrt(1 / n)), then
    moved at most the current temperature, which starts at
    `STARTING_TEMPERATURE` and cools linearly.

    :param A: square adjacency (weight) matrix
    :type A: numpy.ndarray or numpy.matrix
    :param dim: dimension of the layout
    :type dim: int
    :param pos: initial positions, one row per node; random if None
    :type pos: numpy.ndarray or None
    :param fixed: indices of nodes whose positions must not change
    :type fixed: numpy.ndarray or None
    :param iterations: number of spring-relaxation iterations
    :type iterations: int
    :return: final positions, one row per node
    :rtype: numpy.ndarray
    :raise networkx.NetworkXError: if `A` has no ``shape`` attribute
    """
    try:
        nnodes, _ = A.shape
    except AttributeError:
        raise nx.NetworkXError(
            "fruchterman_reingold() takes an adjacency matrix as input")
    A = np.asarray(A)  # make sure we have an array instead of a matrix
    if pos is None:
        # random initial positions
        pos = np.asarray(np.random.random((nnodes, dim)), dtype=A.dtype)
    else:
        # make sure positions are of same type as matrix
        pos = pos.astype(A.dtype)
    # optimal distance between nodes
    k = np.sqrt(1.0 / nnodes)
    # the initial "temperature" is the largest step allowed in the dynamics
    t = STARTING_TEMPERATURE
    # simple cooling scheme: linearly step down by dt on each iteration so the
    # last iteration is size dt
    dt = t / (iterations + 1)
    delta = np.zeros((pos.shape[0], pos.shape[0], pos.shape[1]), dtype=A.dtype)
    # vectorized, but still O(V^2) per iteration
    for _ in range(iterations):
        # delta[i, j] = pos[i] - pos[j]
        for i in range(pos.shape[1]):
            delta[:, :, i] = pos[:, i, None] - pos[:, i]
        # distance between points
        distance = np.sqrt((delta**2).sum(axis=-1))
        # enforce minimum distance of 0.01 (avoids division by zero)
        distance = np.where(distance < 0.01, 0.01, distance)
        # displacement "force": repulsion minus spring attraction, summed over
        # all other nodes
        displacement = np.transpose(
            np.transpose(delta) *
            (PUSH * k**PUSHEXP / distance**PUSHEXP - A * distance / k)).sum(
                axis=1)
        # move each node along its displacement by at most t
        length = np.sqrt((displacement**2).sum(axis=1))
        length = np.where(length < 0.01, 0.01, length)
        delta_pos = np.transpose(np.transpose(displacement) * t / length)
        if fixed is not None:
            # don't change positions of fixed nodes
            delta_pos[fixed] = 0.0
        pos += delta_pos
        # cool temperature
        t -= dt
    return pos
def _rescale_layout(pos, scale=1):
    """
    Shift and scale positions in place so that every coordinate lies in
    ``[0, scale]``, preserving the aspect ratio.

    :param pos: positions, one row per node; modified in place
    :type pos: numpy.ndarray
    :param scale: the largest coordinate value after rescaling
    :type scale: float
    :return: `pos`, rescaled
    :rtype: numpy.ndarray
    """
    # shift origin to (0, 0)
    lim = 0  # max coordinate for all axes
    for i in range(pos.shape[1]):
        pos[:, i] -= pos[:, i].min()
        lim = max(pos[:, i].max(), lim)
    # rescale to (0, scale) in all directions, preserving aspect. If every node
    # sits at the same point (lim == 0) there is nothing to scale, and dividing
    # would fill the layout with NaNs.
    if lim > 0:
        for i in range(pos.shape[1]):
            pos[:, i] *= scale / lim
    return pos