Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 49 additions & 19 deletions src/backend/cpu/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,77 @@
#include <queue.hpp>

#include <algorithm>
#include <cassert>
#include <numeric>

using af::dim4;
using arrayfire::common::half;

namespace arrayfire {
namespace cpu {

template<typename T>
Array<T> join(const int dim, const Array<T> &first, const Array<T> &second) {
// All dimensions except join dimension must be equal
Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
// Compute output dims
af::dim4 odims;
af::dim4 fdims = first.dims();
af::dim4 sdims = second.dims();

for (int i = 0; i < 4; i++) {
if (i == dim) {
odims[i] = fdims[i] + sdims[i];
} else {
odims[i] = fdims[i];
}
}

const dim4 &fdims = first.dims();
const dim4 &sdims = second.dims();
// All dimensions except join dimension must be equal
assert((jdim == 0 ? true : fdims.dims[0] == sdims.dims[0]) &&
(jdim == 1 ? true : fdims.dims[1] == sdims.dims[1]) &&
(jdim == 2 ? true : fdims.dims[2] == sdims.dims[2]) &&
(jdim == 3 ? true : fdims.dims[3] == sdims.dims[3]));

// compute output dms
dim4 odims(fdims);
odims.dims[jdim] += sdims.dims[jdim];
Array<T> out = createEmptyArray<T>(odims);
std::vector<CParam<T>> v{first, second};
getQueue().enqueue(kernel::join<T>, dim, out, v, 2);
getQueue().enqueue(kernel::join<T>, jdim, out, v, 2);

return out;
}

template<typename T>
void join(Array<T> &out, const int dim, const std::vector<Array<T>> &inputs) {
void join(Array<T> &out, const int jdim, const std::vector<Array<T>> &inputs) {
const dim_t n_arrays = inputs.size();

std::vector<Array<T> *> input_ptrs(inputs.size());
if (n_arrays == 0) return;

// avoid buffer overflow
const dim4 &odims{out.dims()};
const dim4 &fdims{inputs[0].dims()};
// All dimensions of inputs needs to be equal except for the join
// dimension
assert(std::all_of(inputs.begin(), inputs.end(),
[jdim, &fdims](const Array<T> &in) {
bool eq{true};
for (int i = 0; i < 4; ++i) {
if (i != jdim) {
eq &= fdims.dims[i] == in.dims().dims[i];
};
};
return eq;
}));
// All dimensions of out needs to cover all input dimensions
assert(
(odims.dims[0] >= fdims.dims[0]) && (odims.dims[1] >= fdims.dims[1]) &&
(odims.dims[2] >= fdims.dims[2]) && (odims.dims[3] >= fdims.dims[3]));
// The join dimension of out needs to be larger than the
// sum of all input join dimensions
assert(odims.dims[jdim] >=
std::accumulate(inputs.begin(), inputs.end(), 0,
[jdim](dim_t dim, const Array<T> &in) {
return dim += in.dims()[jdim];
}));
assert(out.strides().dims[0] == 1);

std::vector<Array<T> *> input_ptrs(n_arrays);
std::transform(
begin(inputs), end(inputs), begin(input_ptrs),
[](const Array<T> &input) { return const_cast<Array<T> *>(&input); });
evalMultiple(input_ptrs);
std::vector<CParam<T>> inputParams(inputs.begin(), inputs.end());

getQueue().enqueue(kernel::join<T>, dim, out, inputParams, n_arrays);
getQueue().enqueue(kernel::join<T>, jdim, out, inputParams, n_arrays);
}

#define INSTANTIATE(T) \
Expand Down
66 changes: 58 additions & 8 deletions src/backend/cuda/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#include <kernel/memcopy.hpp>

#include <algorithm>
#include <cassert>
#include <map>
#include <numeric>
#include <stdexcept>
#include <vector>

Expand All @@ -29,9 +31,14 @@ namespace cuda {

template<typename T>
Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
// All dimensions except join dimension must be equal
const dim4 &fdims{first.dims()};
const dim4 &sdims{second.dims()};
// All dimensions except join dimension must be equal
assert((jdim == 0 ? true : fdims.dims[0] == sdims.dims[0]) &&
(jdim == 1 ? true : fdims.dims[1] == sdims.dims[1]) &&
(jdim == 2 ? true : fdims.dims[2] == sdims.dims[2]) &&
(jdim == 3 ? true : fdims.dims[3] == sdims.dims[3]));

// Compute output dims
dim4 odims(fdims);
odims.dims[jdim] += sdims.dims[jdim];
Expand Down Expand Up @@ -117,6 +124,42 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {

template<typename T>
void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
const dim_t n_arrays = inputs.size();
if (n_arrays == 0) return;

// avoid buffer overflow
const dim4 &odims{out.dims()};
const dim4 &fdims{inputs[0].dims()};
// All dimensions of inputs needs to be equal except for the join
// dimension
assert(std::all_of(inputs.begin(), inputs.end(),
[jdim, &fdims](const Array<T> &in) {
bool eq{true};
for (int i = 0; i < 4; ++i) {
if (i != jdim) {
eq &= fdims.dims[i] == in.dims().dims[i];
};
};
return eq;
}));
// All dimensions of out needs to cover all input dimensions
assert(
(odims.dims[0] >= fdims.dims[0]) && (odims.dims[1] >= fdims.dims[1]) &&
(odims.dims[2] >= fdims.dims[2]) && (odims.dims[3] >= fdims.dims[3]));
// The join dimension of out needs to be larger than the
// sum of all input join dimensions
assert(odims.dims[jdim] >=
std::accumulate(inputs.begin(), inputs.end(), 0,
[jdim](dim_t dim, const Array<T> &in) {
return dim += in.dims().dims[jdim];
}));
assert(out.strides().dims[0] == 1);

// out is an external defined array:
// - with the only restriction that the dims have to be larger than the
// joined inputs.
// - no restrictions on the strides.
// The part of out, that is not overwritten by the join remains as is!!
class eval {
public:
vector<Param<T>> outputs;
Expand All @@ -126,6 +169,7 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
};
std::map<dim_t, eval> evals;
const cudaStream_t activeStream{getActiveStream()};
const dim4 &ostrides{out.strides()};
const size_t L2CacheSize{getL2CacheSize(getActiveDeviceId())};

// topspeed is achieved when byte size(in+out) ~= L2CacheSize
Expand All @@ -148,18 +192,19 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
// has to be called multiple times

// Group all arrays according to size
dim_t outOffset{0};
dim_t odim{0}, outOffset{0};
for (const Array<T> &iArray : inputs) {
const dim_t *idims{iArray.dims().dims};
eval &e{evals[idims[jdim]]};
e.outputs.emplace_back(out.get() + outOffset, idims,
out.strides().dims);
const dim4 &idims{iArray.dims()};
eval &e{evals[idims.dims[jdim]]};
e.outputs.emplace_back(out.get() + outOffset, idims.dims,
ostrides.dims);
// Extend life of the returned node by saving the corresponding
// shared_ptr
e.nodePtrs.emplace_back(iArray.getNode());
e.nodes.push_back(e.nodePtrs.back().get());
e.ins.push_back(&iArray);
outOffset += idims[jdim] * out.strides().dims[jdim];
odim += idims.dims[jdim];
outOffset = odim * ostrides.dims[jdim];
}

for (auto &eval : evals) {
Expand All @@ -173,7 +218,12 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
auto outputIt{begin(s.outputs)};
for (const Array<T> *in : s.ins) {
if (in->isReady()) {
if (1LL + jdim >= in->ndims() && in->isLinear()) {
const dim4 &istrides{in->strides()};
bool lin = in->isLinear() & (ostrides.dims[0] == 1);
for (int i{1}; i < in->ndims(); ++i) {
lin &= (ostrides.dims[i] == istrides.dims[i]);
}
if (lin) {
CUDA_CHECK(cudaMemcpyAsync(outputIt->ptr, in->get(),
in->elements() * sizeof(T),
cudaMemcpyHostToDevice,
Expand Down
76 changes: 61 additions & 15 deletions src/backend/oneapi/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
// All dimensions except join dimension must be equal
const dim4 &fdims{first.dims()};
const dim4 &sdims{second.dims()};
// All dimensions except join dimension must be equal
assert((jdim == 0 ? true : fdims.dims[0] == sdims.dims[0]) &&
(jdim == 1 ? true : fdims.dims[1] == sdims.dims[1]) &&
(jdim == 2 ? true : fdims.dims[2] == sdims.dims[2]) &&
(jdim == 3 ? true : fdims.dims[3] == sdims.dims[3]));

// Compute output dims
dim4 odims(fdims);
Expand Down Expand Up @@ -69,16 +74,18 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
// Both arrays have same size & everything fits into the cache,
// so thread in 1 JIT kernel, iso individual copies which is
// always slower
const dim_t *outStrides{out.strides().dims};
const dim4 &outStrides{out.strides()};
vector<Param<T>> outputs{
{out.get(),
{{fdims.dims[0], fdims.dims[1], fdims.dims[2], fdims.dims[3]},
{outStrides[0], outStrides[1], outStrides[2], outStrides[3]},
{outStrides.dims[0], outStrides.dims[1], outStrides.dims[2],
outStrides.dims[3]},
0}},
{out.get(),
{{sdims.dims[0], sdims.dims[1], sdims.dims[2], sdims.dims[3]},
{outStrides[0], outStrides[1], outStrides[2], outStrides[3]},
fdims.dims[jdim] * outStrides[jdim]}}};
{outStrides.dims[0], outStrides.dims[1], outStrides.dims[2],
outStrides.dims[3]},
fdims.dims[jdim] * outStrides.dims[jdim]}}};
// Extend the life of the returned node, bij saving the
// corresponding shared_ptr
const Node_ptr fNode{first.getNode()};
Expand Down Expand Up @@ -113,11 +120,12 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
}
} else {
// Write the result directly in the out array
const dim_t *outStrides{out.strides().dims};
const dim4 &outStrides{out.strides()};
Param<T> output{
out.get(),
{{fdims.dims[0], fdims.dims[1], fdims.dims[2], fdims.dims[3]},
{outStrides[0], outStrides[1], outStrides[2], outStrides[3]},
{outStrides.dims[0], outStrides.dims[1], outStrides.dims[2],
outStrides.dims[3]},
0}};
evalNodes(output, first.getNode().get());
}
Expand Down Expand Up @@ -146,19 +154,51 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {
}
} else {
// Write the result directly in the out array
const dim_t *outStrides{out.strides().dims};
const dim4 &outStrides{out.strides()};
Param<T> output{
out.get(),
{{sdims.dims[0], sdims.dims[1], sdims.dims[2], sdims.dims[3]},
{outStrides[0], outStrides[1], outStrides[2], outStrides[3]},
fdims.dims[jdim] * outStrides[jdim]}};
{outStrides.dims[0], outStrides.dims[1], outStrides.dims[2],
outStrides.dims[3]},
fdims.dims[jdim] * outStrides.dims[jdim]}};
evalNodes(output, second.getNode().get());
}
return out;
}

template<typename T>
void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
const dim_t n_arrays = inputs.size();
if (n_arrays == 0) return;

// avoid buffer overflow
const dim4 &odims{out.dims()};
const dim4 &fdims{inputs[0].dims()};
// All dimensions of inputs needs to be equal except for the join
// dimension
assert(std::all_of(inputs.begin(), inputs.end(),
[jdim, &fdims](const Array<T> &in) {
bool eq{true};
for (int i = 0; i < 4; ++i) {
if (i != jdim) {
eq &= fdims.dims[i] == in.dims().dims[i];
};
};
return eq;
}));
// All dimensions of out needs to cover all input dimensions
assert(
(odims.dims[0] >= fdims.dims[0]) && (odims.dims[1] >= fdims.dims[1]) &&
(odims.dims[2] >= fdims.dims[2]) && (odims.dims[3] >= fdims.dims[3]));
// The join dimension of out needs to be larger than the
// sum of all input join dimensions
assert(odims.dims[jdim] >=
std::accumulate(inputs.begin(), inputs.end(), 0,
[jdim](dim_t dim, const Array<T> &in) {
return dim += in.dims().dims[jdim];
}));
assert(out.strides().dims[0] == 1);

class eval {
public:
vector<Param<T>> outputs;
Expand All @@ -167,7 +207,7 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
vector<const Array<T> *> ins;
};
std::map<dim_t, eval> evals;
const dim_t *ostrides{out.strides().dims};
const dim4 &ostrides{out.strides()};
const size_t L2CacheSize{getL2CacheSize(oneapi::getDevice())};

// topspeed is achieved when byte size(in+out) ~= L2CacheSize
Expand All @@ -188,20 +228,21 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
// Group all arrays according to size
dim_t outOffset{0};
for (const Array<T> &iArray : inputs) {
const dim_t *idims{iArray.dims().dims};
const dim4 &idims{iArray.dims()};
eval &e{evals[idims[jdim]]};
const Param output{
out.get(),
{{idims[0], idims[1], idims[2], idims[3]},
{ostrides[0], ostrides[1], ostrides[2], ostrides[3]},
{{idims.dims[0], idims.dims[1], idims.dims[2], idims.dims[3]},
{ostrides.dims[0], ostrides.dims[1], ostrides.dims[2],
ostrides.dims[3]},
outOffset}};
e.outputs.push_back(output);
// Extend life of the returned node by saving the corresponding
// shared_ptr
e.nodePtrs.emplace_back(iArray.getNode());
e.nodes.push_back(e.nodePtrs.back().get());
e.ins.push_back(&iArray);
outOffset += idims[jdim] * ostrides[jdim];
outOffset += idims.dims[jdim] * ostrides.dims[jdim];
}

for (auto &eval : evals) {
Expand All @@ -215,7 +256,12 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
auto outputIt{begin(s.outputs)};
for (const Array<T> *in : s.ins) {
if (in->isReady()) {
if (1LL + jdim >= in->ndims() && in->isLinear()) {
const dim_t *istrides{in->strides().dims};
bool lin = in->isLinear() & (ostrides.dims[0] == 1);
for (int i{1}; i < in->ndims(); ++i) {
lin &= (ostrides.dims[i] == istrides[i]);
}
if (lin) {
getQueue().submit([&](sycl::handler &h) {
sycl::range sz(in->elements());
sycl::id src_offset(in->getOffset());
Expand Down
Loading