-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathvector.lua
More file actions
191 lines (151 loc) · 3.84 KB
/
vector.lua
File metadata and controls
191 lines (151 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
--- Vector functions.
-- Standard library imports --
local assert = assert
local min = math.min
local select = select
local type = type
-- Modules --
local array = require("impl.array")
-- Imports --
local Call = array.Call
local CallWrap = array.CallWrap
local GetLib = array.GetLib
local HandleDim = array.HandleDim
local IsArray = array.IsArray
local ToType = array.ToType
-- Exports --
local M = {}
-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/reduce.cpp
-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/where.cpp
--
--- Normalize an optional flag argument to a strict boolean.
-- An omitted flag (`nil`) counts as `true`; anything else is converted by Lua
-- truthiness (so only `false` yields `false`).
-- @param value Flag value, or nil.
-- @treturn boolean
local function Bool (value)
	if value == nil then
		return true
	end

	return value and true or false
end
--
--- Build the pair of C entry-point names used by a reduction.
-- @string name Base operation name, e.g. "sum".
-- @string[opt="af_"] prefix Prefix prepended to `name`.
-- @treturn string Per-dimension function name (prefix .. name).
-- @treturn string Whole-array function name (same, with "_all" appended).
local function Funcs (name, prefix)
	local head = prefix
	if not head then
		head = "af_"
	end

	local per_dim = head .. name
	local whole = per_dim .. "_all"

	return per_dim, whole
end
--
--- Create a reduction function backed by the `af_<name>` / `af_<name>_all` pair.
-- The returned function accepts two call forms:
--   * `(in_arr, dim)`   - reduce along `dim` through `HandleDim`.
--   * `(rtype, in_arr)` - reduce every element via the `_all` variant and convert
--                         the (real, imaginary) result with `ToType(rtype, ...)`.
-- @string name Base operation name passed to `Funcs`.
-- @treturn function Reduction wrapper.
local function Reduce (name)
	local per_dim, whole = Funcs(name)

	return function(in_arr, dim)
		-- Anything other than a leading type string is the per-dimension form.
		if type(in_arr) ~= "string" then
			return HandleDim(per_dim, in_arr, dim)
		end

		-- Shifted arguments: first is the result type, second is the array.
		local rtype, arr = in_arr, dim
		local real, imag = Call(whole, arr:get())

		return ToType(rtype, real, imag)
	end
end
--
--- Create a max/min wrapper that dispatches on the shape of its arguments.
-- Call forms handled by the returned function:
--   1. `(rtype, in_arr)`              - whole-array reduction, converted via `ToType`.
--      `(rtype, in_arr, "get_index")` - same, using `af_i<name>_all`; also returns the index.
--   2. `(val, idx, in_arr, dim)`      - per-dimension reduction with indices; writes the
--                                       results into `val` and `idx` via `:set()` and returns nothing.
--   3. `(in_arr[, dim[, "dim"]])`     - per-dimension reduction through `HandleDim`.
--   4. `(lhs, rhs)`                   - forwards to the library's `<name>of` function.
-- @string name "max" or "min" (any name with matching af_* / af_i* entry points).
-- @treturn function Dispatching wrapper.
local function ReduceMaxMin (name)
	local per_dim, whole = Funcs(name)
	local index_per_dim, index_whole = Funcs(name, "af_i")
	local elementwise = name .. "of"

	return function(a, b, c, d)
		-- Form 1: leading type string.
		-- TODO: The "get_index" flag does not resemble the C++ interface; a table as
		-- the first argument might be a reasonable compromise.
		if type(a) == "string" then
			if c ~= "get_index" then
				return ToType(a, Call(whole, b:get()))
			end

			local real, imag, index = Call(index_whole, b:get())

			return ToType(a, real, imag), index
		end

		-- Form 2: output arrays supplied up front; results are stored, not returned.
		if IsArray(c) then
			local out, idx = HandleDim(index_per_dim, c, d, "no_wrap")

			a:set(out)
			b:set(idx)

			return
		end

		-- Form 3: plain reduction. The trailing "dim" marker disambiguates a numeric
		-- `dim` from a scalar right-hand side (TODO: still awkward; maybe IsConstant()?).
		if not b or c == "dim" then
			return HandleDim(per_dim, a, b)
		end

		-- Form 4: element-wise comparison of two operands.
		return GetLib()[elementwise](a, b)
	end
end
--
--- Create a reduction wrapper that optionally substitutes a value for NaNs.
-- Call forms handled by the returned function:
--   * `(rtype, in_arr[, nanval])` - whole-array reduction via `af_<name>_all` (or the
--                                   `_nan_all` variant when `nanval` is given), converted
--                                   with `ToType(rtype, ...)`.
--   * `(in_arr, dim, nanval)`     - per-dimension `af_<name>_nan` through `CallWrap`.
--   * `(in_arr[, dim])`           - per-dimension `af_<name>` through `HandleDim`.
-- @string name Base operation name, e.g. "sum" or "product".
-- @treturn function Reduction wrapper.
local function ReduceNaN (name)
	local per_dim, whole = Funcs(name)
	local per_dim_nan, whole_nan = Funcs(name .. "_nan")

	return function(in_arr, dim, nanval)
		if type(in_arr) == "string" then
			-- Shifted arguments: type string, then the array; `nanval` stays third.
			local arr = dim
			local real, imag

			if nanval then
				real, imag = Call(whole_nan, arr:get(), nanval)
			else
				real, imag = Call(whole, arr:get())
			end

			return ToType(in_arr, real, imag)
		elseif nanval then
			return CallWrap(per_dim_nan, in_arr:get(), dim, nanval)
		end

		return HandleDim(per_dim, in_arr, dim)
	end
end
-- TODO: lost in macroland :P (probably missing some stuff)
--
-- Built once so that both exported spellings (alltrue/allTrue, anytrue/anyTrue)
-- in M.Add share the same function.
local AllTrue = Reduce("all_true")
local AnyTrue = Reduce("any_true")
--
--- Install the vector (reduction, difference, sort, and search) functions.
-- Each function is stored in `into` under its key, overwriting any existing entry.
-- Some functions are registered under two spellings (e.g. `alltrue` and `allTrue`).
-- @tparam table into Table that receives the functions.
function M.Add (into)
	for k, v in pairs{
		-- Wrap af_all_true / af_all_true_all; see Reduce for the call forms.
		alltrue = AllTrue, allTrue = AllTrue,

		-- Wrap af_any_true / af_any_true_all; see Reduce for the call forms.
		anytrue = AnyTrue, anyTrue = AnyTrue,

		-- Wraps af_count / af_count_all.
		count = Reduce("count"),

		-- First-order difference along `dim`.
		-- `dim` defaults to 0, matching the ArrayFire C++ `diff1(in, dim = 0)` and the
		-- `or 0` default used by `sort` below; previously a nil `dim` reached CallWrap.
		diff1 = function(in_arr, dim)
			return CallWrap("af_diff1", in_arr:get(), dim or 0)
		end,

		-- Second-order difference along `dim` (defaults to 0, as for diff1).
		diff2 = function(in_arr, dim)
			return CallWrap("af_diff2", in_arr:get(), dim or 0)
		end,

		-- Max/min reductions and element-wise max/min; see ReduceMaxMin.
		max = ReduceMaxMin("max"),

		min = ReduceMaxMin("min"),

		-- Product/sum reductions with optional NaN substitution; see ReduceNaN.
		product = ReduceNaN("product"),

		-- Sort in one of three call forms (dim defaults to 0, ascending defaults to true):
		--   (out_keys, out_values, keys, values[, dim[, ascending]]) - af_sort_by_key,
		--        results written into the first two arrays via :set()
		--   (out, indices, in_arr[, dim[, ascending]])               - af_sort_index,
		--        results written into the first two arrays via :set()
		--   (in_arr[, dim[, ascending]])                             - af_sort, returned
		sort = function(a, b, c, d, e, f)
			if IsArray(d) then -- four arrays
				local keys, values = Call("af_sort_by_key", c:get(), d:get(), e or 0, Bool(f))
				a:set(keys)
				b:set(values)
			elseif IsArray(c) then -- three arrays
				local arr, indices = Call("af_sort_index", c:get(), d or 0, Bool(e))
				a:set(arr)
				b:set(indices)
			else -- one array
				return CallWrap("af_sort", a:get(), b or 0, Bool(c))
			end
		end,

		sum = ReduceNaN("sum"),

		-- Wraps af_where; asserts that no GFOR loop is active.
		where = function(in_arr)
			assert(not GetLib().gforGet(), "WHERE can not be used inside GFOR") -- TODO: AF_ERR_RUNTIME);
			return CallWrap("af_where", in_arr:get())
		end
	} do
		into[k] = v
	end
end
-- Export the module.
return M