Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 63 additions & 41 deletions src/backend/cpu/kernel/Array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
namespace cpu {
namespace kernel {

/// Clones nodes and update the child pointers
/// Clones the nodes in node_index_map and updates the child pointers
std::vector<std::shared_ptr<common::Node>> cloneNodes(
const std::vector<common::Node *> &nodes,
const std::vector<common::Node *> &node_index_map,
const std::vector<common::Node_ids> &ids) {
using common::Node;
// find all moddims in the tree
std::vector<std::shared_ptr<Node>> node_clones;
node_clones.reserve(nodes.size());
transform(begin(nodes), end(nodes), back_inserter(node_clones),
[](Node *n) { return n->clone(); });
node_clones.reserve(node_index_map.size());
transform(begin(node_index_map), end(node_index_map),
back_inserter(node_clones), [](Node *n) { return n->clone(); });

for (common::Node_ids id : ids) {
auto &children = node_clones[id.id]->m_children;
Expand All @@ -41,7 +41,8 @@ std::vector<std::shared_ptr<common::Node>> cloneNodes(
return node_clones;
}

/// Sets the shape of the buffer nodes under the moddims node to the new shape
/// Sets the shape of the buffer nodes under the moddims node to the
/// new shape
void propagateModdimsShape(
std::vector<std::shared_ptr<common::Node>> &node_clones) {
using common::NodeIterator;
Expand All @@ -63,30 +64,61 @@ void propagateModdimsShape(
}
}

/// Removes nodes whos operation matchs a unary operation \p op.
void removeNodeOfOperation(std::vector<std::shared_ptr<common::Node>> &nodes,
std::vector<common::Node_ids> &ids, af_op_t op) {
/// Removes nodes whose operation matches a unary operation \p op.
void removeNodeOfOperation(
std::vector<std::shared_ptr<common::Node>> &node_index_map, af_op_t op) {
using common::Node;

std::vector<std::vector<std::shared_ptr<Node>>::iterator> moddims_loc;
for (size_t nid = 0; nid < nodes.size(); nid++) {
auto &node = nodes[nid];
for (size_t nid = 0; nid < node_index_map.size(); nid++) {
auto &node = node_index_map[nid];

for (int i = 0;
i < Node::kMaxChildren && node->m_children[i] != nullptr; i++) {
if (node->m_children[i]->getOp() == op) {
// replace moddims
auto moddim_node = node->m_children[i];
node->m_children[i] = moddim_node->m_children[0];

int parent_id = ids[nid].id;
int moddim_id = ids[parent_id].child_ids[i];
moddims_loc.emplace_back(begin(nodes) + moddim_id);
}
}
}

for (auto &loc : moddims_loc) { nodes.erase(loc); }
node_index_map.erase(remove_if(begin(node_index_map), end(node_index_map),
[op](std::shared_ptr<Node> &node) {
return node->getOp() == op;
}),
end(node_index_map));
}

/// Maps each original output node onto its counterpart in \p node_clones.
///
/// For every entry of \p output_nodes_ the index stored in \p node_index_map
/// selects the matching clone. When an output node's operation is
/// af_moddims_t, the lookup goes through its first child instead, and any
/// further moddims clones are skipped by following their first child, so the
/// returned pointer refers to the first non-moddims node below the output.
///
/// \param node_index_map  lookup from an original node pointer to its index
///                        in \p node_clones
/// \param node_clones     cloned nodes, addressed by the indices in the map
/// \param output_nodes_   original output nodes
/// \returns one raw (non-owning) TNode<T> pointer per output node, in the
///          same order as \p output_nodes_
template<typename T>
std::vector<TNode<T> *> getClonedOutputNodes(
    common::Node_map_t &node_index_map,
    const std::vector<std::shared_ptr<common::Node>> &node_clones,
    const std::vector<common::Node_ptr> &output_nodes_) {
    std::vector<TNode<T> *> result;
    result.reserve(output_nodes_.size());

    for (const auto &output : output_nodes_) {
        const bool is_moddims = output->getOp() == af_moddims_t;

        // A moddims output is resolved through its first child, since the
        // moddims nodes themselves are removed from the tree later.
        const int clone_index =
            is_moddims ? node_index_map[output->m_children[0].get()]
                       : node_index_map[output.get()];

        auto *cloned = static_cast<TNode<T> *>(node_clones[clone_index].get());

        if (is_moddims) {
            // Step over any remaining chain of moddims clones.
            while (cloned->getOp() == af_moddims_t) {
                cloned = static_cast<TNode<T> *>(cloned->m_children[0].get());
            }
        }

        result.push_back(cloned);
    }

    return result;
}

template<typename T>
Expand All @@ -100,41 +132,29 @@ void evalMultiple(std::vector<Param<T>> arrays,
af::dim4 odims = arrays[0].dims();
af::dim4 ostrs = arrays[0].strides();

Node_map_t nodes;
Node_map_t node_index_map;
std::vector<T *> ptrs;
std::vector<TNode<T> *> output_nodes;
std::vector<common::Node *> full_nodes;
std::vector<common::Node_ids> ids;

int narrays = static_cast<int>(arrays.size());
ptrs.reserve(narrays);
for (int i = 0; i < narrays; i++) {
ptrs.push_back(arrays[i].get());
output_nodes_[i]->getNodesMap(nodes, full_nodes, ids);
output_nodes_[i]->getNodesMap(node_index_map, full_nodes, ids);
}

auto node_clones = cloneNodes(full_nodes, ids);

for (auto &n : output_nodes_) {
if (n->getOp() == af_moddims_t) {
// if the output node is a moddims node, then set the output node to
// be the child of the moddims node. This is necessary because we
// remove the moddim nodes from the tree later
output_nodes.push_back(static_cast<TNode<T> *>(
node_clones[nodes[n->m_children[0].get()]].get()));
} else {
output_nodes.push_back(
static_cast<TNode<T> *>(node_clones[nodes[n.get()]].get()));
}
}

std::vector<TNode<T> *> cloned_output_nodes =
getClonedOutputNodes<T>(node_index_map, node_clones, output_nodes_);
propagateModdimsShape(node_clones);
removeNodeOfOperation(node_clones, ids, af_moddims_t);
removeNodeOfOperation(node_clones, af_moddims_t);

bool is_linear = true;
for (auto &node : node_clones) { is_linear &= node->isLinear(odims.get()); }

int num_nodes = node_clones.size();
int num_output_nodes = output_nodes.size();
int num_output_nodes = cloned_output_nodes.size();
if (is_linear) {
int num = arrays[0].dims().elements();
int cnum =
Expand All @@ -145,8 +165,9 @@ void evalMultiple(std::vector<Param<T>> arrays,
node_clones[n]->calc(i, lim);
}
for (int n = 0; n < num_output_nodes; n++) {
std::copy(output_nodes[n]->m_val.begin(),
output_nodes[n]->m_val.begin() + lim, ptrs[n] + i);
std::copy(cloned_output_nodes[n]->m_val.begin(),
cloned_output_nodes[n]->m_val.begin() + lim,
ptrs[n] + i);
}
}
} else {
Expand All @@ -170,9 +191,10 @@ void evalMultiple(std::vector<Param<T>> arrays,
node_clones[n]->calc(x, y, z, w, lim);
}
for (int n = 0; n < num_output_nodes; n++) {
std::copy(output_nodes[n]->m_val.begin(),
output_nodes[n]->m_val.begin() + lim,
ptrs[n] + id);
std::copy(
cloned_output_nodes[n]->m_val.begin(),
cloned_output_nodes[n]->m_val.begin() + lim,
ptrs[n] + id);
}
}
}
Expand Down
67 changes: 67 additions & 0 deletions test/moddims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,70 @@ TEST(Moddims, jit) {
gold = moddims(gold, 5, 10);
ASSERT_ARRAYS_EQ(gold, a);
}

// Chains three moddims calls on a 5x5 constant of ones
// (5x5 -> 25 -> 1x5x5 -> 5x5) and expects the result to equal a 5x5 constant
// of ones.
TEST(Moddims, JitNested) {
array a = af::constant(1, 5, 5);
// Nested moddims: the final shape matches the shape of `a`.
array b = moddims(moddims(moddims(a, 25), 1, 5, 5), 5, 5);
array gold = af::constant(1, 5, 5);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, b);
}

// Uses one moddims result as both operands of an addition and expects a
// 25-element array of twos.
TEST(Moddims, JitDuplicate) {
array a = af::constant(1, 5, 5);
array b = af::moddims(a, 25);
// The same moddims array appears twice in the expression.
array c = b + b;

array gold = af::constant(2, 25);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, c);
}

// Combines nested moddims calls with a duplicated operand.
// Value trace: a + b = 2, c = 2 + 2 = 4, d = 2 + 4 = 6, e = d + d = 12.
TEST(Moddims, JitNestedAndDuplicate) {
array a = af::constant(1, 10, 10);
array b = af::constant(1, 10, 10);
// moddims applied to the sum a + b, flattened to 100 elements.
array c = af::constant(2, 100) + moddims(a + b, 100);
// Two moddims calls stacked on top of an expression that itself contains a
// moddims of `c`.
array d = moddims(
moddims(af::constant(2, 1, 10, 10) + moddims(c, 1, 10, 10), 100), 10,
10);
// `d` is used as both operands of the addition.
array e = d + d;
array gold = af::constant(12, 10, 10);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, e);
}

// Applies moddims to the result of tile and expects a 100-element array of
// ones.
TEST(Moddims, JitTileThenModdims) {
array a = af::constant(1, 10);
array b = tile(a, 1, 10);
// moddims is applied on top of the tiled array.
array c = moddims(b, 100);
array gold = af::constant(1, 100);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, c);
}

// Applies tile to the result of moddims and expects a 10x10 array of ones.
TEST(Moddims, JitModdimsThenTiled) {
array a = af::constant(1, 10);
// Reshape the 10-element array to 1x10 before tiling.
array b = moddims(a, 1, 10);
array c = tile(b, 10);
array gold = af::constant(1, 10, 10);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, c);
}

// Applies two stacked moddims calls to the result of tile and expects a
// 10x10 array of ones.
TEST(Moddims, JitTileThenMultipleModdims) {
array a = af::constant(1, 10);
array b = tile(a, 1, 10);
// Flatten to 100 elements, then reshape to 10x10.
array c = moddims(moddims(b, 100), 10, 10);
array gold = af::constant(1, 10, 10);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, c);
}

// Applies tile to the result of two stacked moddims calls and expects a
// 10x1x10 array of ones.
TEST(Moddims, JitMultipleModdimsThenTiled) {
array a = af::constant(1, 10);
// Reshape 10 -> 1x10 -> 1x1x10 before tiling.
array b = moddims(moddims(a, 1, 10), 1, 1, 10);
array c = tile(b, 10);
array gold = af::constant(1, 10, 1, 10);
// The reference array is evaluated explicitly before the comparison.
gold.eval();
ASSERT_ARRAYS_EQ(gold, c);
}