Implement EdgesConnecting filtering Edges iterator
diff --git a/src/graph_impl/mod.rs b/src/graph_impl/mod.rs
index a8777e5..6ed3d3e 100644
--- a/src/graph_impl/mod.rs
+++ b/src/graph_impl/mod.rs
@@ -862,12 +862,8 @@
pub fn edges_connecting(&self, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> EdgesConnecting<E, Ty, Ix>
{
EdgesConnecting {
- match_end: b,
- edges: &self.edges,
- next: match self.nodes.get(a.index()) {
- None => [EdgeIndex::end(), EdgeIndex::end()],
- Some(n) => n.next,
- },
+ target_node: b,
+ edges: self.edges_directed(a, Direction::Outgoing),
ty: PhantomData,
}
}
@@ -1610,19 +1606,13 @@
}
}
-/// Iterator over the multiple edges between two nodes
+/// Iterator over the multiple directed edges connecting a source node to a target node
pub struct EdgesConnecting<'a, E: 'a, Ty, Ix: 'a = DefaultIx>
where Ty: EdgeType,
Ix: IndexType,
{
- /// The end node
- match_end: NodeIndex<Ix>,
-
- edges: &'a [Edge<E, Ix>],
-
- /// Next edge to visit.
- next: [EdgeIndex<Ix>; 2],
-
+ target_node: NodeIndex<Ix>,
+ edges: Edges<'a, E, Ty, Ix>,
ty: PhantomData<Ty>,
}
@@ -1633,33 +1623,12 @@
type Item = EdgeReference<'a, E, Ix>;
fn next(&mut self) -> Option<EdgeReference<'a, E, Ix>> {
- // First any outgoing edges
- while let Some(edge) = self.edges.get(self.next[0].index()) {
- let i = self.next[0].index();
- self.next[0] = edge.next[0];
- // Make sure we only return edges with the right end node.
- if edge.node[1] == self.match_end {
- return Some(EdgeReference {
- index: edge_index(i),
- node: edge.node,
- weight: &edge.weight,
- });
+ while let Some(edge) = self.edges.next() {
+ if edge.node[1] == self.target_node {
+ return Some(edge);
}
}
- // Then incoming edges
- while let Some(edge) = self.edges.get(self.next[1].index()) {
- let i = self.next[1].index();
- self.next[1] = edge.next[1];
- // Make sure we only return edges with the right end node.
- if edge.node[0] == self.match_end {
- return Some(EdgeReference {
- index: edge_index(i),
- node: edge.node,
- weight: &edge.weight,
- });
- }
- }
None
}
}
diff --git a/tests/graph.rs b/tests/graph.rs
index 476793c..a3fd109 100644
--- a/tests/graph.rs
+++ b/tests/graph.rs
@@ -290,16 +290,27 @@
let a = gr.add_node("a");
let b = gr.add_node("b");
let c = gr.add_node("c");
- gr.add_edge(a, b, ());
+
+ let mut connecting_edges = HashSet::new();
+
+ gr.add_edge(a, a, ());
+ connecting_edges.insert(gr.add_edge(a, b, ()));
gr.add_edge(a, c, ());
- gr.add_edge(a, b, ());
- gr.add_edge(b, c, ());
+ gr.add_edge(c, b, ());
+ connecting_edges.insert(gr.add_edge(a, b, ()));
let mut iter = gr.edges_connecting(a, b);
- assert_eq!(EdgeIndex::new(2), iter.next().unwrap().id());
- assert_eq!(EdgeIndex::new(0), iter.next().unwrap().id());
+ let edge_id = iter.next().unwrap().id();
+ assert!(connecting_edges.contains(&edge_id));
+ connecting_edges.remove(&edge_id);
+
+ let edge_id = iter.next().unwrap().id();
+ assert!(connecting_edges.contains(&edge_id));
+ connecting_edges.remove(&edge_id);
+
assert_eq!(None, iter.next());
+ assert!(connecting_edges.is_empty());
}
#[test]