module rmathd.hyper;

public import rmathd.common;
public import rmathd.choose;
public import rmathd.binomial;


/*
** normal.d poisson.d exponential.d
** dmd hyper.d common.d choose.d binomial.d && ./hyper
*/


/*
 * dhyper: point probability built from three dbinom_raw() terms,
 *         p1 * p2 / p3 (or p1 + p2 - p3 when give_log is non-zero).
 *
 * Params:
 *   x        = point of evaluation (non-integer x is handled by the
 *              R_D_nonint_check mixin)
 *   r, b, n  = distribution parameters; presumably r 'red', b 'black' and
 *              n drawn, by analogy with phyper() below -- TODO confirm
 *   give_log = non-zero: return the result on the log scale
 *
 * Returns:
 *   the sum of the arguments if any of them is NaN;
 *   T.nan if r, b or n is rejected by R_D_negInonint or if n > r + b;
 *   R_D__0 if x lies outside the support;
 *   otherwise the value described above.
 */
T dhyper(T: double)(T x, T r, T b, T n, int give_log)
{
    T p, q, p1, p2, p3;

    mixin R_D__0!give_log;
    mixin R_D__1!give_log;
    if (isNaN(x) || isNaN(r) || isNaN(b) || isNaN(n))
        return x + r + b + n;

    if (R_D_negInonint!T(r) || R_D_negInonint!T(b) || R_D_negInonint!T(n) || n > r + b)
        return T.nan;
    if (x < 0)
        return(R_D__0);
    mixin (R_D_nonint_check!(x));// incl warning

    x = nearbyint(x);
    r = nearbyint(r);
    b = nearbyint(b);
    n = nearbyint(n);

    if (n < x || r < x || n - x > b)
        return R_D__0;
    if (n == 0)
        return((x == 0) ? R_D__1 : R_D__0);

    p = (cast(T)n)/(cast(T)(r + b));
    q = (cast(T)(r + b - n))/(cast(T)(r + b));

    p1 = dbinom_raw!T(x, r, p, q, give_log);
    p2 = dbinom_raw!T(n - x, b, p, q, give_log);
    p3 = dbinom_raw!T(n, r + b, p, q, give_log);

    return( (give_log) ? p1 + p2 - p3 : p1*p2/p3 );
}


/*
 * pdhyper: series helper used by phyper().
 *
 * Returns 1 + sum (or log1p(sum) when log_p is non-zero), where the sum is
 * accumulated in `real` precision until x reaches 0 or the current term
 * drops below DBL_EPSILON * sum.
 */
static T pdhyper(T)(T x, T NR, T NB, T n, int log_p)
{
/*
 * Calculate
 *
 *	    phyper (x, NR, NB, n, TRUE, FALSE)
 *   [log]  ----------------------------------
 *	       dhyper (x, NR, NB, n, FALSE)
 *
 * without actually calling phyper.  This assumes that
 *
 *     x * (NR + NB) <= n * NR
 *
 */
    real sum = 0;
    real term = 1;

    while (x > 0 && term >= DBL_EPSILON * sum) {
        term *= x * (NB - n + x) / (n + 1 - x) / (NR + 1 - x);
        sum += term;
        x--;
    }

    T ss = cast(T) sum;
    return log_p ? log1p(ss) : 1 + ss;
}


/* FIXME: The old phyper() code was basically used in ./qhyper.c as well
 * -----  We need to sync this again!
 */
/*
 * phyper: cumulative probability, computed as dhyper() * pdhyper()
 *         (or their sum on the log scale).
 *
 * Params:
 *   x          = quantile; floored after adding 1e-7
 *   NR, NB, n  = sample of n balls from NR red and NB black ones
 *   lower_tail = non-zero: lower tail, zero: upper tail
 *   log_p      = non-zero: return the probability on the log scale
 *
 * Returns:
 *   the sum of the arguments if any is NaN; T.nan for invalid parameters;
 *   R_DT_0 / R_DT_1 outside the support; otherwise the tail probability.
 */
T phyper(T: double)(T x, T NR, T NB, T n, int lower_tail, int log_p)
{
/* Sample of  n balls from  NR red  and	 NB black ones;	 x are red */

    T d, pd;

    if (isNaN(x) || isNaN(NR) || isNaN(NB) || isNaN(n))
        return x + NR + NB + n;

    x = floor (x + 1e-7);
    NR = nearbyint(NR);
    NB = nearbyint(NB);
    n = nearbyint(n);

    if (NR < 0 || NB < 0 || !isFinite(NR + NB) || n < 0 || n > NR + NB)
        return T.nan;

    if (x * (NR + NB) > n * NR) {
        /* Swap tails.	*/
        T oldNB = NB;
        NB = NR;
        NR = oldNB;
        x = n - x - 1;
        lower_tail = !lower_tail;
    }

    /* x below the support.  The (x < n - NB) part matters: there dhyper()
     * is 0 and the pdhyper() series turns negative, so on the log scale
     * log1p() of a value below -1 would give NaN instead of log(0). */
    if (x < 0 || x < n - NB)
        return R_DT_0!T(lower_tail, log_p);
    if (x >= NR || x >= n)
        return R_DT_1!T(lower_tail, log_p);

    d = dhyper!T(x, NR, NB, n, log_p);
    pd = pdhyper!T(x, NR, NB, n, log_p);

    return log_p ? R_DT_Log!T(d + pd, lower_tail) : R_D_Lval!T(d*pd, lower_tail);
}


/*
 * qhyper: quantile function, found by summing point probabilities upwards
 *         from the lower end of the support.
 *
 * Params:
 *   p          = probability (interpreted according to lower_tail / log_p)
 *   NR, NB, n  = distribution parameters, rounded with nearbyint()
 *   lower_tail = non-zero: p is a lower-tail probability
 *   log_p      = non-zero: p is given on the log scale
 *
 * Returns:
 *   the sum of the arguments if any is NaN; T.nan for non-finite or invalid
 *   parameters; otherwise the first xr in [max(0, n - NB), min(n, NR)] at
 *   which the running sum reaches the (slightly shrunk) target p.
 */
T qhyper(T: double)(T p, T NR, T NB, T n, int lower_tail, int log_p)
{
/* This is basically the same code as  ./phyper.c  *used* to be --> FIXME! */
    T N, xstart, xend, xr, xb, sum, term;
    int small_N;

    if (isNaN(p) || isNaN(NR) || isNaN(NB) || isNaN(n))
        return p + NR + NB + n;

    if (!isFinite(p) || !isFinite(NR) || !isFinite(NB) || !isFinite(n))
        return T.nan;

    NR = nearbyint(NR);
    NB = nearbyint(NB);
    N = NR + NB;
    n = nearbyint(n);
    if (NR < 0 || NB < 0 || n < 0 || n > N)
        return T.nan;

    /* Goal:  Find  xr (= #{red balls in sample}) such that
     *   phyper(xr,  NR,NB, n) >= p > phyper(xr - 1,  NR,NB, n)
     */

    xstart = fmax2!T(0, n - NB);
    xend = fmin2!T(n, NR);

    mixin (R_Q_P01_boundaries!(p, xstart, xend));

    xr = xstart;
    xb = n - xr;/* always ( = #{black balls in sample} ) */

    small_N = (N < 1000); /* won't have underflow in product below */
    /* if N is small,  term := product.ratio( bin.coef );
       otherwise work with its logarithm to protect against underflow */
    term = lfastchoose!T(NR, xr) + lfastchoose!T(NB, xb) - lfastchoose!T(N, n);
    if (small_N)
        term = exp(term);
    NR -= xr;
    NB -= xb;

    mixin R_DT_qIv!p;
    if (!lower_tail || log_p) {
        p = R_DT_qIv;
    }
    p *= 1 - 1000*DBL_EPSILON; /* was 64, but failed on FreeBSD sometimes */
    sum = small_N ? term : exp(term);

    while (sum < p && xr < xend) {
        xr++;
        NB++;
        if (small_N) term *= (NR / xr) * (xb / NB);
        else term += log((NR / xr) * (xb / NB));
        sum += small_N ? term : exp(term);
        xb--;
        NR--;
    }
    return xr;
}


/*
 * afc(i): ln(i!) -- table lookup for 0 <= i <= 7, Stirling-type
 *         approximation for i >= 8.  Returns -1 for i < 0.
 */
static T afc(T)(int i)
{
    // If (i > 7), use Stirling's approximation, otherwise use table lookup.
    const static T[8] al = [
        0.0,/*ln(0!)=ln(1)*/
        0.0,/*ln(1!)=ln(1)*/
        0.69314718055994530941723212145817,/*ln(2) */
        1.79175946922805500081247735838070,/*ln(6) */
        3.17805383034794561964694160129705,/*ln(24)*/
        4.78749174278204599424770093452324,
        6.57925121201010099506017829290394,
        8.52516136106541430016553103634712
        /* 10.60460290274525022841722740072165, approx. value below =
           10.6046028788027; rel.error = 2.26 10^{-9}

          FIXME: Use constants and if(n > ..) decisions from ./stirlerr.c
          -----  will be even *faster* for n > 500 (or so)
        */
    ];

    if (i < 0) {
        //MATHLIB_WARNING(("rhyper.c: afc(i), i=%d < 0 -- SHOULD NOT HAPPEN!\n"), i);
        return -1; // unreached
    }
    if (i <= 7)
        return al[i];
    // else i >= 8 :
    T di = i, i2 = di*di;
    return (di + 0.5) * log(di) - di + M_LN_SQRT_2PI + (0.0833333333333333 - 0.00277777777777778 / i2) / di;
}


//     rhyper(NR, NB, n) -- NR 'red', NB 'blue', n drawn, how many are 'red'
/*
 * rhyper: random variate generation.
 *
 * Params:
 *   nn1in = number of 'red' items  (rounded with nearbyint())
 *   nn2in = number of 'blue' items (rounded with nearbyint())
 *   kkin  = number drawn           (rounded with nearbyint())
 *
 * Returns:
 *   T.nan for non-finite or invalid parameters, or when branch III gives up
 *   after 10000 draws; otherwise the number of 'red' items drawn.
 *
 * Setup results are cached in function-local statics and only recomputed
 * when the parameters differ from those of the previous call.
 */
T rhyper(T: double)(T nn1in, T nn2in, T kkin)
{
    /* extern double afc(int); */

    int nn1, nn2, kk;
    int ix; // return value (coerced to double at the very end)
    //Rboolean
    int setup1, setup2;

    /* These should become 'thread_local globals' : */
    static int ks = -1, n1s = -1, n2s = -1;
    static int m, minjx, maxjx;
    static int k, n1, n2; // <- not allowing larger integer par
    static T tn;

    // II :
    static T w;
    // III:
    static T a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;

    /* check parameter validity */

    if (!isFinite(nn1in) || !isFinite(nn2in) || !isFinite(kkin))
        return T.nan;

    nn1in = nearbyint(nn1in);
    nn2in = nearbyint(nn2in);
    kkin = nearbyint(kkin);

    if (nn1in < 0 || nn2in < 0 || kkin < 0 || kkin > nn1in + nn2in)
        return T.nan;
    if (nn1in >= INT_MAX || nn2in >= INT_MAX || kkin >= INT_MAX) {
        /* large n -- evade integer overflow (and inappropriate algorithms)
           -------- */
        // FIXME: Much faster to give rbinom() approx when appropriate; -> see Kuensch(1989)
        // Johnson, Kotz,.. p.258 (top) mention the *four* different binomial approximations
        if (kkin == 1.) { // Bernoulli
            return rbinom!T(kkin, nn1in / (nn1in + nn2in));
        }
        // Slow, but safe: return  F^{-1}(U)  where F(.) = phyper(.) and  U ~ U[0,1]
        return qhyper!T(unif_rand!T(), nn1in, nn2in, kkin, 0, 0);
    }
    nn1 = cast(int)nn1in;
    nn2 = cast(int)nn2in;
    kk = cast(int)kkin;

    /* if new parameter values, initialize */
    if (nn1 != n1s || nn2 != n2s) {
        setup1 = 1; setup2 = 1;
    } else if (kk != ks) {
        setup1 = 0; setup2 = 1;
    } else {
        setup1 = 0; setup2 = 0;
    }
    if (setup1) {
        n1s = nn1;
        n2s = nn2;
        // Add in floating point: nn1 and nn2 are each only known to be
        // < INT_MAX, so an int sum could wrap around.
        tn = cast(T) nn1 + cast(T) nn2;
        if (nn1 <= nn2) {
            n1 = nn1;
            n2 = nn2;
        } else {
            n1 = nn2;
            n2 = nn1;
        }
    }
    if (setup2) {
        ks = kk;
        // Promote before doubling: kk + kk can exceed the int range.
        if (cast(T) kk + kk >= tn) {
            k = cast(int)(tn - kk);
        } else {
            k = kk;
        }
    }
    if (setup1 || setup2) {
        m = cast(int) ((k + 1.) * (n1 + 1.) / (tn + 2.));
        minjx = imax2(0, k - n2);
        maxjx = imin2(n1, k);
        //#ifdef DEBUG_rhyper
        //	REprintf("rhyper(nn1=%d, nn2=%d, kk=%d), setup: floor(mean)= m=%d, jx in (%d..%d)\n",
        //		 nn1, nn2, kk, m, minjx, maxjx);
        //#endif
    }
    /* generate random variate --- Three basic cases */

    if (minjx == maxjx) { /* I: degenerate distribution ---------------- */
        //#ifdef DEBUG_rhyper
        //	REprintf("rhyper(), branch I (degenerate)\n");
        //#endif
        ix = maxjx;
        goto L_finis; // return appropriate variate

    } else if (m - minjx < 10) { // II: (Scaled) algorithm HIN (inverse transformation) ----
        const static T scale = 1e25; // scaling factor against (early) underflow
        const static T con = 57.5646273248511421;
        // 25*log(10) = log(scale) { <==> exp(con) == scale }
        if (setup1 || setup2) {
            T lw; // log(w);  w = exp(lw) * scale = exp(lw + log(scale)) = exp(lw + con)
            // NOTE(review): n1 + n2 is an int sum here and could wrap for very
            // large parameters; afc() takes an int, so it is left as is.
            if (k < n2) {
                lw = afc!T(n2) + afc!T(n1 + n2 - k) - afc!T(n2 - k) - afc!T(n1 + n2);
            } else {
                lw = afc!T(n1) + afc!T( k ) - afc!T(k - n2) - afc!T(n1 + n2);
            }
            w = exp(lw + con);
        }
        T p, u;
        //#ifdef DEBUG_rhyper
        //	REprintf("rhyper(), branch II; w = %g > 0\n", w);
        //#endif
      L10:
        p = w;
        ix = minjx;
        u = unif_rand!T() * scale;
        //#ifdef DEBUG_rhyper
        //	REprintf("  _new_ u = %g\n", u);
        //#endif
        while (u > p) {
            u -= p;
            p *= (cast(T) n1 - ix) * (k - ix);
            ix++;
            p = p / ix / (n2 - k + ix);
            //#ifdef DEBUG_rhyper
            //	    REprintf("       ix=%3d, u=%11g, p=%20.14g (u-p=%g)\n", ix, u, p, u-p);
            //#endif
            if (ix > maxjx)
                goto L10;
            // FIXME  if(p == 0.)  we also "have lost"  => goto L10
        }
    } else { /* III : H2PE Algorithm --------------------------------------- */

        T u,v;

        if (setup1 || setup2) {
            s = sqrt((tn - k) * k * n1 * n2 / (tn - 1) / tn / tn);

            /* remark: d is defined in reference without int. */
            /* the truncation centers the cell boundaries at 0.5 */

            d = cast(int) (1.5 * s) + .5;
            xl = m - d + .5;
            xr = m + d + .5;
            a = afc!T(m) + afc!T(n1 - m) + afc!T(k - m) + afc!T(n2 - k + m);
            kl = exp(a - afc!T(cast(int) (xl)) - afc!T(cast(int) (n1 - xl))
                     - afc!T(cast(int) (k - xl))
                     - afc!T(cast(int) (n2 - k + xl)));
            kr = exp(a - afc!T(cast(int) (xr - 1))
                     - afc!T(cast(int) (n1 - xr + 1))
                     - afc!T(cast(int) (k - xr + 1))
                     - afc!T(cast(int) (n2 - k + xr - 1)));
            lamdl = -log(xl * (n2 - k + xl) / (n1 - xl + 1) / (k - xl + 1));
            lamdr = -log((n1 - xr + 1) * (k - xr + 1) / xr / (n2 - k + xr));
            p1 = d + d;
            p2 = p1 + kl / lamdl;
            p3 = p2 + kr / lamdr;
        }
        //#ifdef DEBUG_rhyper
        //	REprintf("rhyper(), branch III {accept/reject}: (xl,xr)= (%g,%g); (lamdl,lamdr)= (%g,%g)\n",
        //		 xl, xr, lamdl,lamdr);
        //	REprintf("-------- p123= c(%g,%g,%g)\n", p1,p2, p3);
        //#endif
        int n_uv = 0;
      L30:
        u = unif_rand!T() * p3;
        v = unif_rand!T();
        n_uv++;
        if (n_uv >= 10000) {
            //REprintf("rhyper() branch III: giving up after %d rejections", n_uv);
            return T.nan;
        }
        //#ifdef DEBUG_rhyper
        //	REprintf(" ... L30: new (u=%g, v ~ U[0,1])[%d]\n", u, n_uv);
        //#endif

        if (u < p1) { /* rectangular region */
            ix = cast(int) (xl + u);
        } else if (u <= p2) { /* left tail */
            ix = cast(int) (xl + log(v) / lamdl);
            if (ix < minjx)
                goto L30;
            v = v * (u - p1) * lamdl;
        } else { /* right tail */
            ix = cast(int) (xr - log(v) / lamdr);
            if (ix > maxjx)
                goto L30;
            v = v * (u - p2) * lamdr;
        }

        /* acceptance/rejection test */
        //Rboolean
        int reject = 1;

        if (m < 100 || ix <= 50) {
            /* explicit evaluation */
            /* The original algorithm (and TOMS 668) have
                   f = f * i * (n2 - k + i) / (n1 - i) / (k - i);
               in the (m > ix) case, but the definition of the
               recurrence relation on p134 shows that the +1 is
               needed. */
            int i;
            T f = 1.0;
            if (m < ix) {
                for (i = m + 1; i <= ix; i++)
                    f = f * (n1 - i + 1) * (k - i + 1) / (n2 - k + i) / i;
            } else if (m > ix) {
                for (i = ix + 1; i <= m; i++)
                    f = f * i * (n2 - k + i) / (n1 - i + 1) / (k - i + 1);
            }
            if (v <= f) {
                reject = 0;
            }
        } else {

            const static T deltal = 0.0078;
            const static T deltau = 0.0034;

            T e, g, r, t, y;
            T de, dg, dr, ds, dt, gl, gu, nk, nm, ub;
            T xk, xm, xn, y1, ym, yn, yk, alv;

            //#ifdef DEBUG_rhyper
            //	    REprintf(" ... accept/reject 'large' case v=%g\n", v);
            //#endif
            /* squeeze using upper and lower bounds */
            y = ix;
            y1 = y + 1.0;
            ym = y - m;
            yn = n1 - y + 1.0;
            yk = k - y + 1.0;
            nk = n2 - k + y1;
            r = -ym / y1;
            s = ym / yn;
            t = ym / yk;
            e = -ym / nk;
            g = yn * yk / (y1 * nk) - 1.0;
            dg = 1.0;
            if (g < 0.0)
                dg = 1.0 + g;
            gu = g * (1.0 + g * (-0.5 + g / 3.0));
            gl = gu - .25 * (g * g * g * g) / dg;
            xm = m + 0.5;
            xn = n1 - m + 0.5;
            xk = k - m + 0.5;
            nm = n2 - k + xm;
            ub = y * gu - m * gl + deltau
                + xm * r * (1. + r * (-0.5 + r / 3.0))
                + xn * s * (1. + s * (-0.5 + s / 3.0))
                + xk * t * (1. + t * (-0.5 + t / 3.0))
                + nm * e * (1. + e * (-0.5 + e / 3.0));
            /* test against upper bound */
            alv = log(v);
            if (alv > ub) {
                reject = 1;
            } else {
                /* test against lower bound */
                dr = xm * (r * r * r * r);
                if (r < 0.0)
                    dr /= (1.0 + r);
                ds = xn * (s * s * s * s);
                if (s < 0.0)
                    ds /= (1.0 + s);
                dt = xk * (t * t * t * t);
                if (t < 0.0)
                    dt /= (1.0 + t);
                de = nm * (e * e * e * e);
                if (e < 0.0)
                    de /= (1.0 + e);
                if (alv < ub - 0.25 * (dr + ds + dt + de)
                    + (y + m) * (gl - gu) - deltal) {
                    reject = 0;
                }
                else {
                    /* * Stirling's formula to machine accuracy
                     */
                    if (alv <= (a - afc!T(ix) - afc!T(n1 - ix)
                                - afc!T(k - ix) - afc!T(n2 - k + ix))) {
                        reject = 0;
                    } else {
                        reject = 1;
                    }
                }
            }
        } // else
        if (reject)
            goto L30;
    }


L_finis:
    /* return appropriate variate */

    // Same promotion as in the setup above, so both tests agree.
    if (cast(T) kk + kk >= tn) {
        if (nn1 > nn2) {
            ix = kk - nn2 + ix;
        } else {
            ix = nn1 - ix;
        }
    } else {
        if (nn1 > nn2)
            ix = kk - ix;
    }
    return ix;
}



/* Smoke test: prints one value from each of dhyper, phyper, qhyper and
 * three draws from rhyper. */
void test_hyper()
{
    import std.stdio: writeln;
    writeln("dhyper: ", dhyper(1., 6., 3., 2., 0));
    writeln("phyper: ", phyper(1., 6., 3., 2., 1, 0));
    writeln("qhyper: ", qhyper(.7, 6., 3., 3., 1, 0));
    writeln("rhyper: ", rhyper(4., 5., 5.), ", rhyper: ", rhyper(4., 5., 5.), ", rhyper: ", rhyper(4., 5., 5.));
}