Skip to content

Commit a8c647d

Browse files
committed
Add support for cluster prediction and update example
1 parent 3609273 commit a8c647d

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

lib/node_modules/@stdlib/ml/incr/kmeans/examples/index.js

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ var discreteUniform = require( '@stdlib/random/base/discrete-uniform' );
2424
var normal = require( '@stdlib/random/base/normal' ).factory;
2525
var ndarray = require( '@stdlib/ndarray/ctor' );
2626
var Float64Array = require( '@stdlib/array/float64' );
27+
var Int8Array = require( '@stdlib/array/int8' );
2728
var incrkmeans = require( './../lib' );
2829

2930
var clusters;
@@ -37,6 +38,9 @@ var N;
3738
var k;
3839
var c;
3940
var v;
41+
var m;
42+
var p;
43+
var r;
4044
var X;
4145
var Y;
4246
var x;
@@ -77,10 +81,18 @@ for ( i = 0; i < k; i++ ) {
7781
randn.set( i, 1, normal( clusters.get( i, 2 ), clusters.get( i, 3 ) ) );
7882
}
7983

80-
// Create a 2-d vector for storing simulated data:
84+
// Create a vector for storing simulated data:
8185
ctor = ndarray( 'float64', 1 );
8286
v = ctor( new Float64Array( 2 ), [ 2 ], [ 1 ], 0, 'row-major' );
8387

88+
// Wrap the vector in a matrix for generating cluster predictions:
89+
ctor = ndarray( 'float64', 2 );
90+
m = ctor( v.data, [ 1, 2 ], [ 2, 1 ], 0, 'row-major' );
91+
92+
// Create a vector for storing cluster predictions:
93+
ctor = ndarray( 'int8', 1 );
94+
p = ctor( new Int8Array( 1 ), [ 1 ], [ 1 ], 0, 'row-major' );
95+
8496
// Simulate data points and incrementally perform k-means clustering...
8597
totals = [ 0, 0, 0, 0, 0 ];
8698
X = [];
@@ -99,6 +111,11 @@ for ( i = 0; i < N; i++ ) {
99111
v.set( 1, y );
100112
Y.push( y );
101113

114+
// Generate a cluster prediction:
115+
r = acc.predict( p, m );
116+
if ( r ) {
117+
console.log( 'Data point: (%d, %d). Prediction: %d.', x.toFixed( 3 ), y.toFixed( 3 ), r.get( 0 )+1 );
118+
}
102119
// Update the accumulator:
103120
results = acc( v );
104121
}

lib/node_modules/@stdlib/ml/incr/kmeans/lib/main.js

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,24 +233,69 @@ function incrkmeans( K, ndims, options ) {
233233
* Computes data point distances to centroids and returns centroid assignment predictions.
234234
*
235235
* @private
236+
* @param {ndarray} [out] - output vector for storing centroid assignment predictions
236237
* @param {ndarray} X - matrix containing data points (`n x d`, where `n` is the number of data points and `d` is the number of dimensions)
238+
* @throws {TypeError} output argument must be a vector
237239
* @throws {TypeError} must provide a matrix
240+
* @throws {Error} vector length must match number of data points
238241
* @throws {Error} number of matrix columns must match centroid dimensions
239-
* @returns {ndarray} vector containing centroid (index) predictions
242+
* @returns {(ndarray|null)} vector containing centroid (index) predictions or null
240243
*/
241-
function predict( X ) {
242-
var out;
243-
if ( !isMatrixLike( X ) ) {
244-
throw new TypeError( 'invalid argument. Must provide a 2-dimensional ndarray. Value: `' + X + '`.' );
244+
function predict( out, X ) {
245+
var xbuf;
246+
var cbuf;
247+
var npts;
248+
var sx1;
249+
var sx2;
250+
var sc;
251+
var ox;
252+
var x;
253+
var o;
254+
var c;
255+
var i;
256+
if ( arguments.length > 1 ) {
257+
if ( !isVectorLike( out ) ) {
258+
throw new TypeError( 'invalid argument. Output argument must be a 1-dimensional ndarray. Value: `' + out + '`.' );
259+
}
260+
o = out;
261+
x = X;
262+
} else {
263+
x = out;
264+
}
265+
if ( !isMatrixLike( x ) ) {
266+
throw new TypeError( 'invalid argument. Must provide a 2-dimensional ndarray. Value: `' + x + '`.' );
245267
}
246-
if ( X.shape[ 1 ] !== ndims ) {
247-
throw new Error( 'invalid input argument. Number of matrix columns must match centroid dimensions. Expected: ' + ndims + '. Actual: ' + X.shape[ 1 ] + '.' );
268+
if ( x.shape[ 1 ] !== ndims ) {
269+
throw new Error( 'invalid input argument. Number of matrix columns must match centroid dimensions. Expected: ' + ndims + '. Actual: ' + x.shape[ 1 ] + '.' );
248270
}
249-
out = createVector( out.shape[ 0 ], false ); // high-level
271+
if ( o && o.length !== x.shape[ 0 ] ) {
272+
throw new Error( 'invalid input argument. Output vector length must match the number of data points. Expected: ' + x.shape[ 0 ] + '. Actual: ' + o.length + '.' );
273+
} else {
274+
o = createVector( x.shape[ 0 ], false ); // high-level
275+
}
276+
if ( init ) {
277+
return null;
278+
}
279+
cbuf = centroids.data;
280+
sc = centroids.strides[ 0 ];
250281

251-
// TODO: (prediction) compute distances and assign data points to centroids
282+
xbuf = x.data;
283+
npts = x.shape[ 0 ];
284+
sx1 = x.strides[ 0 ];
285+
sx2 = x.strides[ 1 ];
286+
ox = x.offset;
252287

253-
return out;
288+
// For each data point, find the closest centroid...
289+
for ( i = 0; i < npts; i++ ) {
290+
c = closestCentroid( dist, k, ndims, cbuf, sc, 0, xbuf, sx2, ox ); // Magic number `0` for offset as we know that the matrix view begins at the first buffer element
291+
292+
// Update the output vector:
293+
o.set( i, c );
294+
295+
// Compute the data point buffer index offset to point to the next data point:
296+
ox += sx1;
297+
}
298+
return o;
254299
}
255300
}
256301

0 commit comments

Comments
 (0)