Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,104 @@ class Graph {
return createNodeInternal(Node<T, U...>());
}

// Note:
// The move functions below are unsafe. Use them with caution
// and be sure to call isValid() after each use.

// Move a node from this graph to the destGraph
void moveNode(NodeRef node, Graph<T, U...>* destGraph) {
assert(hasNode(node));
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
if (&(*it) == node) {
std::list<Node<T, U...>>& destNodes = destGraph->nodes_;
destNodes.splice(destNodes.end(), nodes_, it);
nodeRefs_.erase(node);
destGraph->nodeRefs_.insert(node);
break;
}
}
}

// Move an edge from this graph to the destGraph
void moveEdge(EdgeRef edge, Graph<T, U...>* destGraph) {
assert(hasEdge(edge));
assert(destGraph->hasNode(edge->tail()));
assert(destGraph->hasNode(edge->head()));
std::list<Edge<T, U...>>& destEdges = destGraph->edges_;
for (auto it = edges_.begin(); it != edges_.end(); ++it) {
if (&(*it) == edge) {
destEdges.splice(destEdges.end(), edges_, it);
break;
}
}
}

// Move entire subgraph to destGraph.
// Be sure to delete in/out edges from this graph first.
void moveSubgraph(
const Subgraph<T, U...>& subgraph,
Graph<T, U...>* destGraph) {
auto sg = subgraph; // Copy to check that all nodes and edges are matched
std::list<Edge<T, U...>>& destEdges = destGraph->edges_;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
auto node = &(*it);
if (sg.hasNode(node)) {
std::list<Node<T, U...>>& destNodes = destGraph->nodes_;
destNodes.splice(destNodes.end(), nodes_, it--);
nodeRefs_.erase(node);
destGraph->nodeRefs_.insert(node);
sg.removeNode(node);
}
}
for (auto it = edges_.begin(); it != edges_.end(); ++it) {
auto edge = &(*it);
if (sg.hasEdge(edge)) {
assert(destGraph->hasNode(edge->tail()));
assert(destGraph->hasNode(edge->head()));
destEdges.splice(destEdges.end(), edges_, it--);
sg.removeEdge(edge);
}
}
assert(sg.getNodes().size() == 0);
assert(sg.getEdges().size() == 0);
}

// Validates the graph. Returns true if the graph is valid
// and false if any node or edge referenced in the graph
// is not actually present in the graph.
bool isValid() {
for (auto& node : getMutableNodes()) {
for (auto& inEdge : node->getInEdges()) {
if (!hasEdge(inEdge)) {
DEBUG_PRINT("Invalid inEdge %p on node %p\n", inEdge, node);
return false;
}
}
for (auto& outEdge : node->getOutEdges()) {
if (!hasEdge(outEdge)) {
DEBUG_PRINT("invalid outEdge %p on node %p\n", outEdge, node);
return false;
}
}
// Check validity of nodeRefs_
if (!hasNode(node)) {
DEBUG_PRINT("Invalid node %p\n", node);
return false;
}
}
for (auto& edge : getMutableEdges()) {
if (!hasNode(edge->tail())) {
DEBUG_PRINT("Invalid tail on edge %p\n", edge);
return false;
}
if (!hasNode(edge->head())) {
DEBUG_PRINT("Invalid head on edge %p\n", edge);
return false;
}
}
return true;
}

// Swap two nodes.
// Any edge V -> N1 becomes V -> N2, and N1 -> V becomes N2 -> V.
void swapNodes(NodeRef n1, NodeRef n2) {
Expand Down Expand Up @@ -334,6 +432,15 @@ class Graph {
return getEdgeIfExists(tail, head);
}

bool hasEdge(EdgeRef e) const {
for (auto& edge : edges_) {
if (e == &edge) {
return true;
}
}
return false;
}

/// \brief Get a reference to the edge between two nodes if it exists.
/// note: will fail assertion if the edge does not exist.
EdgeRef getEdge(NodeRef tail, NodeRef head) const {
Expand Down
61 changes: 61 additions & 0 deletions caffe2/core/nomnigraph/tests/GraphTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,65 @@ TEST(Basic, HasNode) {
// Current graph: 3 -> 4 , 2
// replaceNode doesn't delete n2.
EXPECT_TRUE(g.hasNode(n2));

// Create a second graph g2, and move the nodes from g2 to g.
TestClass t5;
nom::Graph<TestClass> g2;
nom::Graph<TestClass>::NodeRef n5 = g2.createNode(std::move(t5));
EXPECT_TRUE(g2.hasNode(n5));

EXPECT_FALSE(g.hasNode(n5));
g2.moveNode(n5, &g);
// Current graph (g1): 3 -> 4, 2, 5
EXPECT_TRUE(g.hasNode(n5));
}

TEST(Basic, Moves) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto e1 = g.createEdge(n1, n2);
auto e2 = g.createEdge(n1, n3);
// Current graph: 1 -> 2 -> 3

TestGraph g2;
g.deleteEdge(e2);
g.moveNode(n1, &g2);
g.moveNode(n2, &g2);
g.moveEdge(e1, &g2);
EXPECT_TRUE(g.isValid());
EXPECT_TRUE(g2.isValid());
EXPECT_EQ(g.getMutableNodes().size(), 1);
EXPECT_EQ(g2.getMutableNodes().size(), 2);
EXPECT_EQ(g.getMutableEdges().size(), 0);
EXPECT_EQ(g2.getMutableEdges().size(), 1);
}

TEST(Basic, MoveSubgraph) {
TestGraph g;
auto n1 = createTestNode(g);
auto n2 = createTestNode(g);
auto n3 = createTestNode(g);
auto e1 = g.createEdge(n1, n2);
auto e2 = g.createEdge(n1, n3);
// Current graph: 1 -> 2 -> 3

TestGraph g2;

g.deleteEdge(e2);

TestGraph::SubgraphType sg;
sg.addNode(n1);
sg.addNode(n2);
sg.addEdge(e1);

g.moveSubgraph(sg, &g2);

EXPECT_TRUE(g.isValid());
EXPECT_TRUE(g2.isValid());
EXPECT_EQ(g.getMutableNodes().size(), 1);
EXPECT_EQ(g2.getMutableNodes().size(), 2);
EXPECT_EQ(g.getMutableEdges().size(), 0);
EXPECT_EQ(g2.getMutableEdges().size(), 1);
}