Skip to content

Commit a1b67de

Browse files
committed
Add SACA from imagej-ops dev branch
This commit adds Shulei and Ellen's SACA code from imagej/imagej-ops coloc-multithread-saca branch from commit b12ce2904945f7cf14db105e86f917d113279b63. This is right before Ellen commited a WIP first attempt to multithread SACA. While it works there are notes indicating that the output is still slow and does not conform to Shulei's output from the original R package. Given this, I felt it best to start the SciJava Ops port from the last known working state and then make progress towards resolving the Op performance and output results in the SciJava Ops framework.
1 parent aff6770 commit a1b67de

File tree

3 files changed

+590
-0
lines changed

3 files changed

+590
-0
lines changed
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
2+
package net.imagej.ops.coloc.saca;
3+
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.Random;
7+
import java.util.function.Function;
8+
9+
import net.imglib2.Cursor;
10+
import net.imglib2.Localizable;
11+
import net.imglib2.RandomAccess;
12+
import net.imglib2.RandomAccessibleInterval;
13+
import net.imglib2.loops.LoopBuilder;
14+
import net.imglib2.type.numeric.RealType;
15+
import net.imglib2.type.numeric.real.DoubleType;
16+
import net.imglib2.util.Intervals;
17+
import net.imglib2.util.Localizables;
18+
import net.imglib2.util.Util;
19+
import net.imglib2.view.IntervalView;
20+
import net.imglib2.view.Views;
21+
import net.imglib2.view.composite.CompositeIntervalView;
22+
import net.imglib2.view.composite.GenericComposite;
23+
24+
/**
25+
* Adapted from Shulei's original Java code for AdaptiveSmoothedKendallTau from
26+
* his RKColocal R package.
27+
* (https://github.com/lakerwsl/RKColocal/blob/master/RKColocal_0.0.1.0000.tar.gz)
28+
*
29+
* @author Shulei Wang
30+
* @author Curtis Rueden
31+
* @author Ellen T Arena
32+
*/
33+
public final class AdaptiveSmoothedKendallTau {
34+
35+
private AdaptiveSmoothedKendallTau() {}
36+
37+
public static <I extends RealType<I>, O extends RealType<O>> void execute(
38+
final RandomAccessibleInterval<I> image1,
39+
final RandomAccessibleInterval<I> image2, final I thres1, final I thres2,
40+
final RandomAccessibleInterval<O> result, final long seed)
41+
{
42+
execute(image1, image2, thres1, thres2, new DoubleType(), result, seed);
43+
}
44+
45+
public static <I extends RealType<I>, T extends RealType<T>, O extends RealType<O>> void execute(
46+
final RandomAccessibleInterval<I> image1,
47+
final RandomAccessibleInterval<I> image2, final I thres1, final I thres2,
48+
final T intermediate, final RandomAccessibleInterval<O> result,
49+
final long seed)
50+
{
51+
final Function<RandomAccessibleInterval<I>, RandomAccessibleInterval<T>> factory =
52+
img -> Util.getSuitableImgFactory(img, intermediate).create(img);
53+
execute(image1, image2, thres1, thres2, factory, result, seed);
54+
}
55+
56+
public static <I extends RealType<I>, T extends RealType<T>, O extends RealType<O>> void execute(
57+
final RandomAccessibleInterval<I> image1,
58+
final RandomAccessibleInterval<I> image2, final I thres1, final I thres2,
59+
Function<RandomAccessibleInterval<I>, RandomAccessibleInterval<T>> factory,
60+
final RandomAccessibleInterval<O> result, final long seed)
61+
{
62+
final long nr = image1.dimension(1);
63+
final long nc = image1.dimension(0);
64+
final RandomAccessibleInterval<T> oldtau = factory.apply(image1);
65+
final RandomAccessibleInterval<T> newtau = factory.apply(image1);
66+
final RandomAccessibleInterval<T> oldsqrtN = factory.apply(image1);
67+
final RandomAccessibleInterval<T> newsqrtN = factory.apply(image1);
68+
final List<RandomAccessibleInterval<T>> stop = new ArrayList<>();
69+
for (int s = 0; s < 3; s++)
70+
stop.add(factory.apply(image1));
71+
final double Dn = Math.sqrt(Math.log(nr * nc)) * 2;
72+
final int TU = 15;
73+
final int TL = 8;
74+
final double Lambda = Dn;
75+
76+
LoopBuilder.setImages(oldsqrtN).forEachPixel(t -> t.setOne());
77+
78+
double size = 1;
79+
final double stepsize = 1.15; // empirically the best, but could have users
80+
// set
81+
int intSize;
82+
boolean IsCheck = false;
83+
84+
final Random rng = new Random(seed);
85+
86+
for (int s = 0; s < TU; s++) {
87+
intSize = (int) Math.floor(size);
88+
singleiteration(image1, image2, thres1, thres2, stop, oldtau, oldsqrtN,
89+
newtau, newsqrtN, result, Lambda, Dn, intSize, IsCheck, rng);
90+
size *= stepsize;
91+
if (s == TL) {
92+
IsCheck = true;
93+
LoopBuilder.setImages(stop.get(1), stop.get(2), newtau, newsqrtN)
94+
.forEachPixel((ts1, ts2, tTau, tSqrtN) -> {
95+
ts1.set(tTau);
96+
ts2.set(tSqrtN);
97+
});
98+
}
99+
}
100+
}
101+
102+
private static <I extends RealType<I>, T extends RealType<T>, O extends RealType<O>>
103+
void singleiteration(final RandomAccessibleInterval<I> image1,
104+
final RandomAccessibleInterval<I> image2, final I thres1, final I thres2,
105+
final List<RandomAccessibleInterval<T>> stop,
106+
final RandomAccessibleInterval<T> oldtau,
107+
final RandomAccessibleInterval<T> oldsqrtN,
108+
final RandomAccessibleInterval<T> newtau,
109+
final RandomAccessibleInterval<T> newsqrtN,
110+
final RandomAccessibleInterval<O> result, final double Lambda,
111+
final double Dn, final int Bsize, final boolean isCheck, final Random rng)
112+
{
113+
final double[][] kernel = kernelGenerate(Bsize);
114+
115+
final long[] rowrange = new long[4];
116+
final long[] colrange = new long[4];
117+
final int totnum = (2 * Bsize + 1) * (2 * Bsize + 1);
118+
final double[] LocX = new double[totnum];
119+
final double[] LocY = new double[totnum];
120+
final double[] LocW = new double[totnum];
121+
final double[][] combinedData = new double[totnum][3];
122+
final int[] rankedindex = new int[totnum];
123+
final double[] rankedw = new double[totnum];
124+
final int[] index1 = new int[totnum];
125+
final int[] index2 = new int[totnum];
126+
final double[] w1 = new double[totnum];
127+
final double[] w2 = new double[totnum];
128+
final double[] cumw = new double[totnum];
129+
130+
RandomAccessibleInterval<T> workingImageStack = Views.stack(oldtau, newtau, oldsqrtN, newsqrtN, stop.get(0), stop.get(1), stop.get(2));
131+
CompositeIntervalView<T, ? extends GenericComposite<T>> workingImage =
132+
Views.collapse(workingImageStack);
133+
134+
IntervalView<Localizable> positions = Views.interval( Localizables.randomAccessible(result.numDimensions() ), result );
135+
final long nr = result.dimension(1);
136+
final long nc = result.dimension(0);
137+
final RandomAccess<I> gdImage1 = image1.randomAccess();
138+
final RandomAccess<I> gdImage2 = image2.randomAccess();
139+
final RandomAccess<T> gdTau = oldtau.randomAccess();
140+
final RandomAccess<T> gdSqrtN = oldsqrtN.randomAccess();
141+
LoopBuilder.setImages(positions, result, workingImage).forEachPixel((pos, resPixel, workingPixel) -> {
142+
T oldtauPix = workingPixel.get(0);
143+
T newtauPix = workingPixel.get(1);
144+
T oldsqrtNPix = workingPixel.get(2);
145+
T newsqrtNPix = workingPixel.get(3);
146+
T stop0Pix = workingPixel.get(4);
147+
T stop1Pix = workingPixel.get(5);
148+
T stop2Pix = workingPixel.get(6);
149+
final long row = pos.getLongPosition(1);
150+
updateRange(row, Bsize, nr, rowrange);
151+
if (isCheck) {
152+
if (stop0Pix.getRealDouble() != 0) {
153+
return;
154+
}
155+
}
156+
final long col = pos.getLongPosition(0);
157+
updateRange(col, Bsize, nc, colrange);
158+
getData(Dn, kernel, gdImage1, gdImage2, gdTau, gdSqrtN, LocX, LocY, LocW,
159+
rowrange, colrange, totnum);
160+
newsqrtNPix.setReal(Math.sqrt(NTau(thres1, thres2, LocW, LocX,
161+
LocY)));
162+
if (newsqrtNPix.getRealDouble() <= 0) {
163+
newtauPix.setZero();
164+
resPixel.setZero();
165+
}
166+
else {
167+
final double tau = WtKendallTau.calculate(LocX, LocY, LocW, combinedData,
168+
rankedindex, rankedw, index1, index2, w1, w2, cumw, rng);
169+
newtauPix.setReal(tau);
170+
resPixel.setReal(tau * newsqrtNPix.getRealDouble() * 1.5);
171+
}
172+
173+
if (isCheck) {
174+
final double taudiff = Math.abs(stop1Pix.getRealDouble() - newtauPix
175+
.getRealDouble()) * stop2Pix.getRealDouble();
176+
if (taudiff > Lambda) {
177+
stop0Pix.setOne();
178+
newtauPix.set(oldtauPix);
179+
newsqrtNPix.set(oldsqrtNPix);
180+
}
181+
}
182+
});
183+
184+
// TODO: instead of copying pixels here, swap oldTau and newTau every time.
185+
// :-)
186+
LoopBuilder.setImages(oldtau, newtau, oldsqrtN, newsqrtN).forEachPixel((
187+
tOldTau, tNewTau, tOldSqrtN, tNewSqrtN) -> {
188+
tOldTau.set(tNewTau);
189+
tOldSqrtN.set(tNewSqrtN);
190+
});
191+
192+
}
193+
194+
private static <I extends RealType<I>, T extends RealType<T>> void getData(
195+
final double Dn, final double[][] w, final RandomAccess<I> i1RA,
196+
final RandomAccess<I> i2RA, final RandomAccess<T> tau,
197+
final RandomAccess<T> sqrtN, final double[] sx, final double[] sy,
198+
final double[] sw, final long[] rowrange, final long[] colrange,
199+
final int totnum)
200+
{
201+
// TODO: Decide if this cast is OK.
202+
int kernelk = (int) (rowrange[0] - rowrange[2] + rowrange[3]);
203+
int kernell;
204+
int index = 0;
205+
double taudiffabs;
206+
207+
sqrtN.setPosition(colrange[2], 0);
208+
sqrtN.setPosition(rowrange[2], 1);
209+
final double sqrtNValue = sqrtN.get().getRealDouble();
210+
211+
for (long k = rowrange[0]; k <= rowrange[1]; k++) {
212+
i1RA.setPosition(k, 1);
213+
i2RA.setPosition(k, 1);
214+
sqrtN.setPosition(k, 1);
215+
// TODO: Double check cast.
216+
kernell = (int) (colrange[0] - colrange[2] + colrange[3]);
217+
for (long l = colrange[0]; l <= colrange[1]; l++) {
218+
i1RA.setPosition(l, 0);
219+
i2RA.setPosition(l, 0);
220+
sqrtN.setPosition(l, 0);
221+
sx[index] = i1RA.get().getRealDouble();
222+
sy[index] = i2RA.get().getRealDouble();
223+
sw[index] = w[kernelk][kernell];
224+
225+
tau.setPosition(l, 0);
226+
tau.setPosition(k, 1);
227+
final double tau1 = tau.get().getRealDouble();
228+
229+
tau.setPosition(colrange[2], 0);
230+
tau.setPosition(rowrange[2], 1);
231+
final double tau2 = tau.get().getRealDouble();
232+
233+
taudiffabs = Math.abs(tau1 - tau2) * sqrtNValue;
234+
taudiffabs = taudiffabs / Dn;
235+
if (taudiffabs < 1) sw[index] = sw[index] * (1 - taudiffabs) * (1 -
236+
taudiffabs);
237+
else sw[index] = sw[index] * 0;
238+
kernell++;
239+
index++;
240+
}
241+
kernelk++;
242+
}
243+
while (index < totnum) {
244+
sx[index] = 0;
245+
sy[index] = 0;
246+
sw[index] = 0;
247+
index++;
248+
}
249+
}
250+
251+
private static void updateRange(final long location, final int radius,
252+
final long boundary, final long[] range)
253+
{
254+
range[0] = location - radius;
255+
if (range[0] < 0) range[0] = 0;
256+
range[1] = location + radius;
257+
if (range[1] >= boundary) range[1] = boundary - 1;
258+
range[2] = location;
259+
range[3] = radius;
260+
}
261+
262+
private static <I extends RealType<I>> double NTau(final I thres1,
263+
final I thres2, final double[] w, final double[] x, final double[] y)
264+
{
265+
double sumW = 0;
266+
double sumsqrtW = 0;
267+
double tempW;
268+
269+
for (int index = 0; index < w.length; index++) {
270+
if (x[index] < thres1.getRealDouble() || y[index] < thres2
271+
.getRealDouble()) w[index] = 0;
272+
tempW = w[index];
273+
sumW += tempW;
274+
tempW = tempW * w[index];
275+
sumsqrtW += tempW;
276+
}
277+
double NW;
278+
final double Denomi = sumW * sumW;
279+
if (Denomi <= 0) {
280+
NW = 0;
281+
}
282+
else {
283+
NW = Denomi / sumsqrtW;
284+
}
285+
return NW;
286+
}
287+
288+
private static double[][] kernelGenerate(final int size) {
289+
final int L = size * 2 + 1;
290+
final double[][] kernel = new double[L][L];
291+
final int center = size;
292+
double temp;
293+
final double Rsize = size * Math.sqrt(2.5);
294+
295+
for (int i = 0; i <= size; i++) {
296+
for (int j = 0; j <= size; j++) {
297+
temp = Math.sqrt(i * i + j * j) / Rsize;
298+
if (temp >= 1) temp = 0;
299+
else temp = 1 - temp;
300+
kernel[center + i][center + j] = temp;
301+
kernel[center - i][center + j] = temp;
302+
kernel[center + i][center - j] = temp;
303+
kernel[center - i][center - j] = temp;
304+
}
305+
}
306+
return kernel;
307+
}
308+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*-
2+
* #%L
3+
* ImageJ software for multidimensional image processing and analysis.
4+
* %%
5+
* Copyright (C) 2014 - 2018 ImageJ developers.
6+
* %%
7+
* Redistribution and use in source and binary forms, with or without
8+
* modification, are permitted provided that the following conditions are met:
9+
*
10+
* 1. Redistributions of source code must retain the above copyright notice,
11+
* this list of conditions and the following disclaimer.
12+
* 2. Redistributions in binary form must reproduce the above copyright notice,
13+
* this list of conditions and the following disclaimer in the documentation
14+
* and/or other materials provided with the distribution.
15+
*
16+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
20+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26+
* POSSIBILITY OF SUCH DAMAGE.
27+
* #L%
28+
*/
29+
30+
package net.imagej.ops.coloc.saca;
31+
32+
import net.imagej.ops.Ops;
33+
import net.imagej.ops.special.computer.AbstractBinaryComputerOp;
34+
import net.imglib2.RandomAccessibleInterval;
35+
import net.imglib2.histogram.Histogram1d;
36+
import net.imglib2.type.numeric.RealType;
37+
import net.imglib2.util.Intervals;
38+
import net.imglib2.view.Views;
39+
40+
import org.scijava.plugin.Parameter;
41+
import org.scijava.plugin.Plugin;
42+
43+
/**
44+
* This algorithm is adapted from Spatially Adaptive Colocalization Analysis
45+
* (SACA) by Wang et al (2019); computes thresholds using Otsu method.
46+
*
47+
* @param <I> Type of the input images
48+
* @param <O> Type of the output image
49+
*/
50+
@Plugin(type = Ops.Coloc.SACA.class)
51+
public class SACA<I extends RealType<I>, O extends RealType<O>> extends
52+
AbstractBinaryComputerOp<RandomAccessibleInterval<I>, RandomAccessibleInterval<I>, RandomAccessibleInterval<O>>
53+
implements Ops.Coloc.SACA
54+
{
55+
56+
@Parameter(required = false)
57+
private I thres1;
58+
59+
@Parameter(required = false)
60+
private I thres2;
61+
62+
@Parameter(required = false)
63+
private long seed = 0xdeadbeef;
64+
65+
@Override
66+
public void compute(final RandomAccessibleInterval<I> image1,
67+
final RandomAccessibleInterval<I> image2,
68+
final RandomAccessibleInterval<O> result)
69+
{
70+
71+
// check image sizes
72+
if (!(Intervals.equalDimensions(image1, image2))) {
73+
throw new IllegalArgumentException("Image dimensions do not match");
74+
}
75+
76+
// compute thresholds if necessary
77+
if (thres1 == null) thres1 = threshold(image1);
78+
if (thres2 == null) thres2 = threshold(image2);
79+
80+
AdaptiveSmoothedKendallTau.execute(image1, image2, thres1, thres2, result, seed);
81+
}
82+
83+
<V extends RealType<V>> V threshold(final RandomAccessibleInterval<V> image) {
84+
final Histogram1d<V> histogram = ops().image().histogram(Views.iterable(
85+
image));
86+
return ops().threshold().otsu(histogram);
87+
}
88+
}

0 commit comments

Comments
 (0)