1 module rmathd.hyper;
2 
3 public import rmathd.common;
4 public import rmathd.choose;
5 public import rmathd.binomial;
6 
7 
8 /* 
9 ** normal.d poisson.d exponential.d
10 ** dmd hyper.d common.d choose.d binomial.d && ./hyper
11 */
12 
13 
/**
 * Hypergeometric probability mass P[X = x], where X is the number of red
 * balls in a sample of n drawn without replacement from r red and b black.
 *
 * The value is formed as a ratio of binomial densities sharing the success
 * probability n / (r + b):
 *     dbinom(x; r) * dbinom(n - x; b) / dbinom(n; r + b)
 * or, when give_log is set, the corresponding sum/difference of logs.
 *
 * Params:
 *   x        = number of red balls in the sample
 *   r        = red balls in the urn (non-negative integer)
 *   b        = black balls in the urn (non-negative integer)
 *   n        = sample size (non-negative integer, at most r + b)
 *   give_log = nonzero to return the natural log of the probability
 * Returns: the (log) probability; NaN for invalid r, b or n. NaN inputs
 *   propagate.
 */
T dhyper(T: double)(T x, T r, T b, T n, int give_log)
{
    // R_D__0 / R_D__1: probability 0 / 1 on the scale selected by give_log.
    mixin R_D__0!give_log;
    mixin R_D__1!give_log;

    if (isNaN(x) || isNaN(r) || isNaN(b) || isNaN(n))
        return x + r + b + n;

    // r, b and n must be non-negative integers (R_D_negInonint presumably
    // flags negative or non-integer values) and n cannot exceed the urn.
    if (R_D_negInonint!T(r) || R_D_negInonint!T(b) || R_D_negInonint!T(n) || n > r + b)
        return T.nan;
    if (x < 0)
        return R_D__0;
    mixin (R_D_nonint_check!(x)); // non-integer x (in R: zero probability + warning)

    x = nearbyint(x);
    r = nearbyint(r);
    b = nearbyint(b);
    n = nearbyint(n);

    // Outside the support max(0, n - b) <= x <= min(n, r).
    if (x > n || x > r || n - x > b)
        return R_D__0;
    // Empty sample: X == 0 with certainty.
    if (n == 0)
        return x == 0 ? R_D__1 : R_D__0;

    immutable T urn = r + b;
    immutable T pSuccess = n / urn;
    immutable T pFailure = (urn - n) / urn;

    immutable T redPart   = dbinom_raw!T(x,     r,   pSuccess, pFailure, give_log);
    immutable T blackPart = dbinom_raw!T(n - x, b,   pSuccess, pFailure, give_log);
    immutable T wholePart = dbinom_raw!T(n,     urn, pSuccess, pFailure, give_log);

    if (give_log)
        return redPart + blackPart - wholePart;
    return redPart * blackPart / wholePart;
}
48 
49 
static T pdhyper(T)(T x, T NR, T NB, T n, int log_p)
{
/*
 * Returns
 *
 *          phyper(x, NR, NB, n, lower_tail = TRUE, log_p = FALSE)
 *   [log]  -------------------------------------------------------
 *          dhyper(x, NR, NB, n, give_log = FALSE)
 *
 * (the log of the ratio when log_p is nonzero), computed without calling
 * phyper. It sums the series 1 + sum_j prod_{i <= j} P[X = x - i] / P[X = x - i + 1],
 * stopping when x reaches 0 or a term becomes negligible relative to the sum.
 *
 * Precondition (not checked):  x * (NR + NB) <= n * NR
 */
    real acc = 0;   // running sum of terms, kept in extended precision
    real ratio = 1; // current term

    for (; x > 0; --x)
    {
        // Negated >= so that a NaN term also ends the series.
        if (!(ratio >= DBL_EPSILON * acc))
            break;
        ratio *= x * (NB - n + x) / (n + 1 - x) / (NR + 1 - x);
        acc += ratio;
    }

    immutable T total = cast(T) acc;
    if (log_p)
        return log1p(total);
    return 1 + total;
}
76 
77 
/**
 * Hypergeometric distribution function for a sample of n balls drawn without
 * replacement from NR red and NB black balls, X = number of red balls drawn.
 *
 * FIXME: The old phyper() code was basically used in qhyper as well;
 *        the two need to be synced again.
 *
 * Params:
 *   x          = quantile (floored after adding 1e-7 to absorb rounding error)
 *   NR         = red balls in the urn (rounded to the nearest integer)
 *   NB         = black balls in the urn (rounded to the nearest integer)
 *   n          = sample size (rounded to the nearest integer)
 *   lower_tail = nonzero for P[X <= x], zero for P[X > x]
 *   log_p      = nonzero to return the natural log of the probability
 * Returns: the requested (log) probability; NaN for invalid parameters.
 *   NaN inputs propagate.
 */
T phyper(T: double)(T x, T NR, T NB, T n, int lower_tail, int log_p)
{
    T d, pd;

    if (isNaN(x) || isNaN(NR) || isNaN(NB) || isNaN(n))
        return x + NR + NB + n;

    x = floor(x + 1e-7); // treat x slightly below an integer as that integer
    NR = nearbyint(NR);
    NB = nearbyint(NB);
    n  = nearbyint(n);

    if (NR < 0 || NB < 0 || !isFinite(NR + NB) || n < 0 || n > NR + NB)
        return T.nan;

    // pdhyper() requires x * (NR + NB) <= n * NR. Otherwise use the
    // colour-swapped problem: with X' = n - X (black balls drawn),
    // P[X <= x] = 1 - P[X' <= n - x - 1], i.e. the opposite tail.
    if (x * (NR + NB) > n * NR) {
        T oldNB = NB;
        NB = NR;
        NR = oldNB;
        x = n - x - 1;
        lower_tail = !lower_tail;
    }

    // The support is [max(0, n - NB), min(n, NR)]. Below it the probability
    // is exactly 0; without the `x < n - NB` test pdhyper() would sum a
    // series with negative terms (log1p of a value < -1 gives NaN).
    if (x < 0 || x < n - NB)
        return R_DT_0!T(lower_tail, log_p);
    if (x >= NR || x >= n)
        return R_DT_1!T(lower_tail, log_p);

    d = dhyper!T(x, NR, NB, n, log_p);
    // dhyper() is positive inside the support mathematically, but can
    // underflow to 0 (-Inf on the log scale); combining that with pd could
    // give 0 * Inf or -Inf + NaN. The probability is then 0 to working precision.
    if ((!log_p && d == 0) || (log_p && d == -T.infinity))
        return R_DT_0!T(lower_tail, log_p);
    pd = pdhyper!T(x, NR, NB, n, log_p);

    return log_p ? R_DT_Log!T(d + pd, lower_tail) : R_D_Lval!T(d * pd, lower_tail);
}
117 
118 
/**
 * Hypergeometric quantile function: the smallest xr in the support such that
 *     phyper(xr, NR, NB, n) >= p > phyper(xr - 1, NR, NB, n)
 * where p is first converted from the requested tail / log scale.
 *
 * This is basically the same search R's phyper() used to perform (R's own
 * source carries a FIXME to that effect).
 *
 * Params:
 *   p          = probability (on the scale given by lower_tail / log_p)
 *   NR         = red balls in the urn (rounded to the nearest integer)
 *   NB         = black balls in the urn (rounded to the nearest integer)
 *   n          = sample size (rounded to the nearest integer)
 *   lower_tail = nonzero if p is P[X <= x], zero if p is P[X > x]
 *   log_p      = nonzero if p is given as a log probability
 * Returns: the quantile; NaN for non-finite or invalid parameters.
 */
T qhyper(T: double)(T p, T NR, T NB, T n, int lower_tail, int log_p)
{
    if (isNaN(p) || isNaN(NR) || isNaN(NB) || isNaN(n))
        return p + NR + NB + n;
    if (!isFinite(p) || !isFinite(NR) || !isFinite(NB) || !isFinite(n))
        return T.nan;

    NR = nearbyint(NR);
    NB = nearbyint(NB);
    n  = nearbyint(n);
    immutable T N = NR + NB; // urn size

    if (NR < 0 || NB < 0 || n < 0 || n > N)
        return T.nan;

    // Support of X: [xstart, xend].
    T xstart = fmax2!T(0, n - NB);
    T xend   = fmin2!T(n, NR);

    // Boundary handling for p (defined in rmathd.common; presumably returns
    // xstart / xend at the extremes and NaN outside the valid range).
    mixin (R_Q_P01_boundaries!(p, xstart, xend));

    T xr = xstart; // red balls in the sample
    T xb = n - xr; // black balls in the sample; xr + xb == n throughout

    // Small urns carry P[X = xr] directly (no underflow in the product);
    // larger urns carry its logarithm instead.
    immutable bool smallUrn = N < 1000;
    T term = lfastchoose!T(NR, xr) + lfastchoose!T(NB, xb) - lfastchoose!T(N, n);
    if (smallUrn)
        term = exp(term);

    // From here on NR / NB count the red / black balls left in the urn.
    NR -= xr;
    NB -= xb;

    // Convert p to a lower-tail, non-log probability.
    mixin R_DT_qIv!p;
    if (!lower_tail || log_p)
        p = R_DT_qIv;
    p *= 1 - 1000*DBL_EPSILON; /* was 64, but failed on FreeBSD sometimes */

    T sum = smallUrn ? term : exp(term);
    while (sum < p && xr < xend)
    {
        ++xr;
        ++NB;
        // Advance term by the ratio P[X = xr] / P[X = xr - 1].
        if (smallUrn)
            term *= (NR / xr) * (xb / NB);
        else
            term += log((NR / xr) * (xb / NB));
        sum += smallUrn ? term : exp(term);
        --xb;
        --NR;
    }
    return xr;
}
177 
178 
/**
 * ln(i!) for a non-negative integer i.
 *
 * Uses an exact table for 0 <= i <= 7 and, for i >= 8, the Stirling series
 *     (i + 1/2) ln i - i + ln sqrt(2 pi) + 1/(12 i) - 1/(360 i^3)
 * (M_LN_SQRT_2PI is taken from outside this file). At i = 8 the series gives
 * 10.6046028788027 vs the exact 10.60460290274525..., rel. error 2.26e-9.
 *
 * FIXME: Use constants and if(n > ..) decisions from R's stirlerr.c;
 *        that would be even faster for n > 500 (or so).
 *
 * Returns -1 for negative i, which should never happen (R's nmath emits a
 * warning in that case; the warning is not ported here).
 */
static T afc(T)(int i)
{
    static immutable T[8] lnFactorial = [
        0.0,                                /* ln(0!) = ln(1) */
        0.0,                                /* ln(1!) = ln(1) */
        0.69314718055994530941723212145817, /* ln(2)    */
        1.79175946922805500081247735838070, /* ln(6)    */
        3.17805383034794561964694160129705, /* ln(24)   */
        4.78749174278204599424770093452324, /* ln(120)  */
        6.57925121201010099506017829290394, /* ln(720)  */
        8.52516136106541430016553103634712  /* ln(5040) */
    ];

    if (i > 7)
    {
        T v = i, vsq = v*v;
        return (v + 0.5) * log(v) - v + M_LN_SQRT_2PI + (0.0833333333333333 - 0.00277777777777778 / vsq) / v;
    }
    if (i >= 0)
        return lnFactorial[i];
    return -1; // invalid (negative) argument
}
209 
210 
/**
 * Random variate from the hypergeometric distribution:
 * rhyper(NR, NB, n) -- NR 'red', NB 'blue', n drawn, how many are 'red'.
 *
 * Three cases: a degenerate distribution (I), scaled inversion "HIN" when the
 * mode is close to the lower end of the support (II), and the H2PE
 * accept/reject algorithm (TOMS 668) otherwise (III).
 *
 * Params:
 *   nn1in = red balls (rounded to the nearest integer)
 *   nn2in = blue balls (rounded to the nearest integer)
 *   kkin  = balls drawn (rounded to the nearest integer)
 * Returns: number of red balls drawn; NaN for non-finite or invalid
 *   parameters, or if branch III rejects 10000 candidates in a row.
 *
 * Setup quantities are cached in function-level statics and recomputed only
 * when the parameters change. (D function statics are thread-local by default.)
 */
T rhyper(T: double)(T nn1in, T nn2in, T kkin)
{
    int nn1, nn2, kk;
    int ix; // return value (coerced to T at the very end)
    int setup1, setup2; // booleans: redo the n1/n2 part / the k part of the setup

    /* Cached setup, reused while the parameters do not change. */
    static int ks = -1, n1s = -1, n2s = -1;
    static int m, minjx, maxjx;
    static int k, n1, n2; // <- not allowing larger integer par
    static T tn;

    // II :
    static T w;
    // III:
    static T a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;

    // afc() takes an int, but n1 + n2 (and n1 + n2 - k) can exceed INT_MAX even
    // though n1 and n2 individually fit. Above INT_MAX evaluate the same
    // Stirling series afc() uses beyond its lookup table, in floating point.
    static T afcWide(T di)
    {
        if (di <= INT_MAX)
            return afc!T(cast(int) di);
        T i2 = di*di;
        return (di + 0.5) * log(di) - di + M_LN_SQRT_2PI + (0.0833333333333333 - 0.00277777777777778 / i2) / di;
    }

    /* check parameter validity */
    if (!isFinite(nn1in) || !isFinite(nn2in) || !isFinite(kkin))
        return T.nan;

    nn1in = nearbyint(nn1in);
    nn2in = nearbyint(nn2in);
    kkin  = nearbyint(kkin);

    if (nn1in < 0 || nn2in < 0 || kkin < 0 || kkin > nn1in + nn2in)
        return T.nan;
    if (nn1in >= INT_MAX || nn2in >= INT_MAX || kkin >= INT_MAX) {
        /* large n -- evade integer overflow (and inappropriate algorithms) */
        // FIXME: Much faster to give rbinom() approx when appropriate; -> see Kuensch(1989)
        // Johnson, Kotz,.. p.258 (top) mention the *four* different binomial approximations
        if (kkin == 1.) { // Bernoulli
            return rbinom!T(kkin, nn1in / (nn1in + nn2in));
        }
        // Slow, but safe: return  F^{-1}(U)  where F(.) = phyper(.) and  U ~ U[0,1]
        return qhyper!T(unif_rand!T(), nn1in, nn2in, kkin, 0, 0);
    }
    nn1 = cast(int) nn1in;
    nn2 = cast(int) nn2in;
    kk  = cast(int) kkin;

    /* if new parameter values, initialize */
    if (nn1 != n1s || nn2 != n2s) { // n1 or n2 changed: full setup
        setup1 = 1; setup2 = 1;
    } else if (kk != ks) {          // only k changed
        setup1 = 0; setup2 = 1;
    } else {                        // nothing changed: reuse the cache
        setup1 = 0; setup2 = 0;
    }
    if (setup1) {
        n1s = nn1;
        n2s = nn2;
        // Add in floating point: nn1 + nn2 can exceed INT_MAX, and D int
        // addition would silently wrap around.
        tn = cast(T) nn1 + nn2;
        // n1 := smaller colour count, n2 := larger one
        if (nn1 <= nn2) {
            n1 = nn1;
            n2 = nn2;
        } else {
            n1 = nn2;
            n2 = nn1;
        }
    }
    if (setup2) {
        ks = kk;
        // By symmetry sample k = min(kk, tn - kk); kk + kk is formed in
        // floating point because it can overflow int.
        if (cast(T) kk + kk >= tn) {
            k = cast(int)(tn - kk);
        } else {
            k = kk;
        }
    }
    if (setup1 || setup2) {
        m = cast(int) ((k + 1.) * (n1 + 1.) / (tn + 2.)); // floor of the adjusted mean
        minjx = imax2(0, k - n2);
        maxjx = imin2(n1, k);
    }

    /* generate random variate --- Three basic cases */

    if (minjx == maxjx) { /* I: degenerate distribution ---------------- */
        ix = maxjx;
        goto L_finis; // return appropriate variate

    } else if (m - minjx < 10) { // II: (Scaled) algorithm HIN (inverse transformation) ----
        const static T scale = 1e25; // scaling factor against (early) underflow
        const static T con = 57.5646273248511421;
                                     // 25*log(10) = log(scale) { <==> exp(con) == scale }
        if (setup1 || setup2) {
            T lw; // log(w);  w = exp(lw) * scale = exp(lw + log(scale)) = exp(lw + con)
            // tn == n1 + n2; afcWide() avoids int overflow of that sum.
            if (k < n2) {
                lw = afc!T(n2) + afcWide(tn - k) - afc!T(n2 - k) - afcWide(tn);
            } else {
                lw = afc!T(n1) + afc!T(     k     ) - afc!T(k - n2) - afcWide(tn);
            }
            w = exp(lw + con);
        }
        T p, u;
      L10:
        p = w;
        ix = minjx;
        u = unif_rand!T() * scale;
        while (u > p) {
            u -= p;
            p *= (cast(T) n1 - ix) * (k - ix);
            ix++;
            p = p / ix / (n2 - k + ix);
            if (ix > maxjx)
                goto L10;
            // FIXME  if(p == 0.)  we also "have lost"  => goto L10
        }
    } else { /* III : H2PE Algorithm --------------------------------------- */

        T u, v;

        if (setup1 || setup2) {
            s = sqrt((tn - k) * k * n1 * n2 / (tn - 1) / tn / tn);

            /* remark: d is defined in reference without int. */
            /* the truncation centers the cell boundaries at 0.5 */

            d = cast(int) (1.5 * s) + .5;
            xl = m - d + .5;
            xr = m + d + .5;
            a = afc!T(m) + afc!T(n1 - m) + afc!T(k - m) + afc!T(n2 - k + m);
            kl = exp(a - afc!T(cast(int) (xl)) - afc!T(cast(int) (n1 - xl))
                     - afc!T(cast(int) (k - xl))
                     - afc!T(cast(int) (n2 - k + xl)));
            kr = exp(a - afc!T(cast(int) (xr - 1))
                     - afc!T(cast(int) (n1 - xr + 1))
                     - afc!T(cast(int) (k - xr + 1))
                     - afc!T(cast(int) (n2 - k + xr - 1)));
            lamdl = -log(xl * (n2 - k + xl) / (n1 - xl + 1) / (k - xl + 1));
            lamdr = -log((n1 - xr + 1) * (k - xr + 1) / xr / (n2 - k + xr));
            p1 = d + d;
            p2 = p1 + kl / lamdl;
            p3 = p2 + kr / lamdr;
        }

        int n_uv = 0;
      L30:
        u = unif_rand!T() * p3;
        v = unif_rand!T();
        n_uv++;
        if (n_uv >= 10000) {
            // R additionally reports "rhyper() branch III: giving up after %d rejections".
            return T.nan;
        }

        if (u < p1) {           /* rectangular region */
            ix = cast(int) (xl + u);
        } else if (u <= p2) {   /* left tail */
            ix = cast(int) (xl + log(v) / lamdl);
            if (ix < minjx)
                goto L30;
            v = v * (u - p1) * lamdl;
        } else {                /* right tail */
            ix = cast(int) (xr - log(v) / lamdr);
            if (ix > maxjx)
                goto L30;
            v = v * (u - p2) * lamdr;
        }

        /* acceptance/rejection test */
        int reject = 1; // boolean

        if (m < 100 || ix <= 50) {
            /* explicit evaluation */
            /* The original algorithm (and TOMS 668) have
                   f = f * i * (n2 - k + i) / (n1 - i) / (k - i);
               in the (m > ix) case, but the definition of the
               recurrence relation on p134 shows that the +1 is
               needed. */
            int i;
            T f = 1.0;
            if (m < ix) {
                for (i = m + 1; i <= ix; i++)
                    f = f * (n1 - i + 1) * (k - i + 1) / (n2 - k + i) / i;
            } else if (m > ix) {
                for (i = ix + 1; i <= m; i++)
                    f = f * i * (n2 - k + i) / (n1 - i + 1) / (k - i + 1);
            }
            if (v <= f) {
                reject = 0;
            }
        } else {

            const static T deltal = 0.0078;
            const static T deltau = 0.0034;

            T e, g, r, t, y;
            T de, dg, dr, ds, dt, gl, gu, nk, nm, ub;
            T xk, xm, xn, y1, ym, yn, yk, alv;

            /* squeeze using upper and lower bounds */
            y = ix;
            y1 = y + 1.0;
            ym = y - m;
            yn = n1 - y + 1.0;
            yk = k - y + 1.0;
            nk = n2 - k + y1;
            r = -ym / y1;
            s = ym / yn; // reuses the static s as scratch; setup recomputes it before use
            t = ym / yk;
            e = -ym / nk;
            g = yn * yk / (y1 * nk) - 1.0;
            dg = 1.0;
            if (g < 0.0)
                dg = 1.0 + g;
            gu = g * (1.0 + g * (-0.5 + g / 3.0));
            gl = gu - .25 * (g * g * g * g) / dg;
            xm = m + 0.5;
            xn = n1 - m + 0.5;
            xk = k - m + 0.5;
            nm = n2 - k + xm;
            ub = y * gu - m * gl + deltau
                + xm * r * (1. + r * (-0.5 + r / 3.0))
                + xn * s * (1. + s * (-0.5 + s / 3.0))
                + xk * t * (1. + t * (-0.5 + t / 3.0))
                + nm * e * (1. + e * (-0.5 + e / 3.0));
            /* test against upper bound */
            alv = log(v);
            if (alv > ub) {
                reject = 1;
            } else {
                /* test against lower bound */
                dr = xm * (r * r * r * r);
                if (r < 0.0)
                    dr /= (1.0 + r);
                ds = xn * (s * s * s * s);
                if (s < 0.0)
                    ds /= (1.0 + s);
                dt = xk * (t * t * t * t);
                if (t < 0.0)
                    dt /= (1.0 + t);
                de = nm * (e * e * e * e);
                if (e < 0.0)
                    de /= (1.0 + e);
                if (alv < ub - 0.25 * (dr + ds + dt + de)
                    + (y + m) * (gl - gu) - deltal) {
                    reject = 0;
                } else {
                    /* Stirling's formula to machine accuracy */
                    if (alv <= (a - afc!T(ix) - afc!T(n1 - ix)
                                - afc!T(k - ix) - afc!T(n2 - k + ix))) {
                        reject = 0;
                    } else {
                        reject = 1;
                    }
                }
            }
        } // else
        if (reject)
            goto L30;
    }


L_finis:
    /* return appropriate variate */
    // Undo the symmetry reductions from the setup (k <-> tn - k, n1 <-> n2);
    // kk + kk is again formed in floating point to avoid int overflow.
    if (cast(T) kk + kk >= tn) {
        if (nn1 > nn2) {
            ix = kk - nn2 + ix;
        } else {
            ix = nn1 - ix;
        }
    } else {
        if (nn1 > nn2)
            ix = kk - ix;
    }
    return ix;
}
514 
515 
516 
/// Smoke test: prints one value from each hypergeometric routine.
void test_hyper()
{
    import std.stdio : writeln;

    auto density = dhyper(1., 6., 3., 2., 0);
    writeln("dhyper: ", density);

    auto cdf = phyper(1., 6., 3., 2., 1, 0);
    writeln("phyper: ", cdf);

    auto quantile = qhyper(.7, 6., 3., 3., 1, 0);
    writeln("qhyper: ", quantile);

    // Three draws, taken in the same left-to-right order as arguments of a single call.
    auto draw1 = rhyper(4., 5., 5.);
    auto draw2 = rhyper(4., 5., 5.);
    auto draw3 = rhyper(4., 5., 5.);
    writeln("rhyper: ", draw1, ", rhyper: ", draw2, ", rhyper: ", draw3);
}
525