diff --git a/sys/src/libmp/ntest.c b/sys/src/libmp/ntest.c index 820eaa8a7..688c9df7d 100644 --- a/sys/src/libmp/ntest.c +++ b/sys/src/libmp/ntest.c @@ -1,3 +1,42 @@ +/* + +tests missing for: + + mpint* strtomp(char *buf, char **rptr, int base, mpint *b) + char* mptoa(mpint *b, int base, char *buf, int blen) + mpint* betomp(uchar *buf, uint blen, mpint *b) + int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp) + void mptober(mpint *b, uchar *buf, int blen) + mpint* letomp(uchar *buf, uint blen, mpint *b) + int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp) + void mptolel(mpint *b, uchar *buf, int blen) + uint mptoui(mpint*) + mpint* uitomp(uint, mpint*) + int mptoi(mpint*) + mpint* itomp(int, mpint*) + mpint* vtomp(vlong, mpint*) + vlong mptov(mpint*) + mpint* uvtomp(uvlong, mpint*) + uvlong mptouv(mpint*) + mpint* dtomp(double, mpint*) + double mptod(mpint*) + void mpexp(mpint *b, mpint *e, mpint *m, mpint *res) + void mpmod(mpint *b, mpint *m, mpint *remainder) + void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) + void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) + void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) + void mpsel(int s, mpint *b1, mpint *b2, mpint *res) + void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y) + void mpinvert(mpint *b, mpint *m, mpint *res) + void mpdigdiv(mpdigit *dividend, mpdigit divisor, mpdigit *quotient) + void mpvecadd(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *sum) + void mpvecsub(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *diff) + void mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p) + int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p) + void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen,mpdigit *p) + int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen) +*/ + #include #include #include @@ -40,6 +79,9 @@ ldnorm(ldint *a) if(a->b[i] != a->b[a->n-1]) break; ldbits(a, i + 2); + }else{ + ldbits(a, 1); + a->b[0] = 0; } return a; } @@ -55,6 +97,10 @@ ldneg(ldint *a) a->b[i] = c & 1; c >>= 1; } + if(c != a->b[a->n - 1]){ + ldbits(a, a->n + 1); + a->b[a->n - 1] = c; + } } static int @@ -69,6 +115,7 @@ ldnew(int n) ldint *a; a = malloc(sizeof(ldint)); + if(n <= 0) n = 1; a->b = malloc(n); a->n = n; return a; @@ -87,6 +134,7 @@ ldsanity(ldint *a) { int i; + assert(a->n > 0); for(i = 0; i < a->n; i++) assert(a->b[i] < 2); } @@ -177,6 +225,7 @@ pow2told(int n, ldint *a) memset(a->b, 0, k+2); a->b[k] = 1; if(n < 0) ldneg(a); + ldnorm(a); return a; } @@ -206,6 +255,74 @@ ldfmt(Fmt *f) return 0; } +static int +mpdetfmt(Fmt *f) +{ + mpint *a; + int i, j; + + a = va_arg(f->args, mpint *); + fmtprint(f, "(sign=%d,top=%d,size=%d,", a->sign, a->top, a->size); + for(i=0;itop;){ + fmtprint(f, "%ullx", (uvlong)a->p[i]); + if(++i == a->top) break; + fmtrune(f, ','); + for(j = i+1; j < a->top; j++) + if(a->p[i] != a->p[j]) + goto next; + fmtprint(f, "..."); + break; + next:; + } + fmtrune(f, '|'); + for(i=a->top;isize;){ + fmtprint(f, "%ullx", (uvlong)a->p[i]); + if(++i == a->size) break; + fmtrune(f, ','); + for(j = i+1; j < a->top; j++) + if(a->p[i] != a->p[j]) + goto next2; + fmtprint(f, "..."); + break; + next2:; + } + fmtrune(f, ')'); + return 0; +} + +static int +ldcmp(ldint *a, ldint *b) +{ + int x, y; + int i, r; + + r = max(a->n, b->n); + if(a->b[a->n-1] != b->b[b->n-1]) + return b->b[b->n - 1] - a->b[a->n - 1]; + for(i = r - 1; --i >= 0; ){ + x = ldget(a, i); + y = ldget(b, i); + if(x != y) + return x - y; + } + return 0; +} + +static int +ldmagcmp(ldint *a, ldint *b) +{ + int s1, s2, r; + + s1 = a->b[a->n - 1]; + s2 = b->b[b->n - 1]; + if(s1) ldneg(a); + if(s2) ldneg(b); + r = ldcmp(a, b); + if(s1) ldneg(a); + if(s2) ldneg(b); + return r; +} + static int ldmpeq(ldint *a, mpint *b) { @@ -265,6 +382,50 @@ ldadd(ldint *a, ldint *b, ldint *q) ldnorm(q); } +static void +ldmagadd(ldint *a, ldint *b, ldint *q) +{ + int i, r, s1, s2, c1, c2, co; + + r = max(a->n, b->n) + 2; + ldbits(q, r); + co = 0; + s1 = c1 = a->b[a->n - 1] & 1; + s2 = c2 = b->b[b->n - 1] & 1; + for(i = 0; i < r; i++){ + c1 += s1 ^ ldget(a, i) & 1; + c2 += s2 ^ ldget(b, i) & 1; + co += (c1 & 1) + (c2 & 1); + q->b[i] = co & 1; + co >>= 1; + c1 >>= 1; + c2 >>= 1; + } + ldnorm(q); +} + +static void +ldmagsub(ldint *a, ldint *b, ldint *q) +{ + int i, r, s1, s2, c1, c2, co; + + r = max(a->n, b->n) + 2; + ldbits(q, r); + co = 0; + s1 = c1 = a->b[a->n - 1] & 1; + s2 = c2 = 1 ^ b->b[b->n - 1] & 1; + for(i = 0; i < r; i++){ + c1 += s1 ^ ldget(a, i) & 1; + c2 += s2 ^ ldget(b, i) & 1; + co += (c1 & 1) + (c2 & 1); + q->b[i] = co & 1; + co >>= 1; + c1 >>= 1; + c2 >>= 1; + } + ldnorm(q); +} + static void ldsub(ldint *a, ldint *b, ldint *q) { @@ -474,6 +635,108 @@ ldxor(ldint *a, ldint *b, ldint *q) ldnorm(q); } +static void +ldleft(ldint *a, int n, ldint *b) +{ + int i, c; + + if(n < 0){ + if(a->n <= -n){ + b->n = 0; + ldnorm(b); + return; + } + c = 0; + if(a->b[a->n - 1]) + for(i = 0; i < -n; i++) + if(a->b[i]){ + c = 1; + break; + } + ldbits(b, a->n + n); + for(i = 0; i < a->n + n; i++){ + c += a->b[i - n] & 1; + b->b[i] = c & 1; + c >>= 1; + } + }else{ + ldbits(b, a->n + n); + memmove(b->b + n, a->b, a->n); + memset(b->b, 0, n); + } + ldnorm(b); +} + +static void +ldasr(ldint *a, int n, ldint *b) +{ + if(n < 0){ + ldleft(a, -n, b); + return; + } + if(a->n <= n){ + ldbits(b, 1); + b->b[0] = a->b[a->n - 1]; + return; + } + ldbits(b, a->n - n); + memmove(b->b, a->b + n, a->n - n); + ldnorm(b); +} + +static void +ldtrunc(ldint *a, int n, ldint *b) +{ + ldbits(b, n+1); + b->b[n] = 0; + if(a->n >= n) + memmove(b->b, a->b, n); + else{ + memmove(b->b, a->b, a->n); + memset(b->b + a->n, a->b[a->n - 1], n - a->n); + } + ldnorm(b); +} + +static void +ldxtend(ldint *a, int n, ldint *b) +{ + ldbits(b, n); + if(a->n >= n) + memmove(b->b, a->b, n); + else{ + memmove(b->b, a->b, a->n); + memset(b->b + a->n, a->b[a->n - 1], n - a->n); + } + ldnorm(b); +} + +static void +mpnot_(mpint *a, int, mpint *b) +{ + mpnot(a, b); +} + +static void +ldnot(ldint *a, int, ldint *b) +{ + int i; + + ldbits(b, a->n); + for(i = 0; i < a->n; i++) + b->b[i] = a->b[i] ^ 1; +} + +enum { NTEST = 2*257 }; +static void +testgen(int i, ldint *a) +{ + if(i < 257) + itold(i-128, a); + else + pow2told(i-385, a); +} + typedef struct Test2 Test2; struct Test2 { char *name; @@ -481,6 +744,14 @@ struct Test2 { void (*ref)(ldint *, ldint *, ldint *); }; +typedef struct Test1i Test1i; +struct Test1i { + char *name; + enum { NONEG = 1 } flags; + void (*dut)(mpint *, int, mpint *); + void (*ref)(ldint *, int, ldint *); +}; + int validate(char *name, ldint *ex, mpint *res, char *str) { @@ -564,31 +835,21 @@ run2(Test2 *t) b = ldnew(32); c = ldnew(32); ok = 1; - for(i = -128; i <= 128; i++) - for(j = -128; j <= 128; j++){ - itold(i, a); - itold(j, b); + for(i = 0; i < NTEST; i++){ + for(j = 0; j < NTEST; j++){ + testgen(i, a); + testgen(j, b); ok &= test2(t, a, b); - pow2told(i, a); - itold(j, b); - ok &= test2(t, a, b); - ok &= test2(t, b, a); - pow2told(i, a); - pow2told(j, b); - ok &= test2(t, a, b); } + itold(i, a); + ok &= test2x(t, a); + } for(i = 1; i <= 4; i++) for(j = 1; j <= 4; j++){ ldrand(i * Dbits, a); ldrand(j * Dbits, b); ok &= test2(t, a, b); } - for(i = -128; i <= 128; i++){ - itold(i, a); - ok &= test2x(t, a); - pow2told(i, a); - ok &= test2x(t, a); - } ldfree(a); ldfree(b); if(ok) @@ -596,15 +857,17 @@ run2(Test2 *t) } Test2 tests2[] = { - "mpdiv(q)", mpdivq, lddivq, - "mpdiv(r)", mpdivr, lddivr, - "mpmul", mpmul, ldmul, "mpadd", mpadd, ldadd, + "mpmagadd", mpmagadd, ldmagadd, "mpsub", mpsub, ldsub, + "mpmagsub", mpmagsub, ldmagsub, "mpand", mpand, ldand, "mpor", mpor, ldor, "mpbic", mpbic, ldbic, "mpxor", mpxor, ldxor, + "mpmul", mpmul, ldmul, + "mpdiv(q)", mpdivq, lddivq, + "mpdiv(r)", mpdivr, lddivr, }; void @@ -616,10 +879,168 @@ all2(void) run2(t); } +int +test1i(Test1i *t, ldint *a, int b) +{ + ldint *c; + mpint *ma, *rc; + int rv; + + c = ldnew(0); + t->ref(a, b, c); + ldsanity(a); + ldsanity(c); + ma = ldtomp(a, nil); + rc = mptarget(); + t->dut(ma, b, rc); + rv = validate(t->name, c, rc, smprint("%L and %d", a, b)); + ldtomp(a, ma); + t->dut(ma, b, ma); + rv = validate(t->name, c, ma, smprint("%L (aliased to result) and %d", a, b)); + ldfree(c); + mpfree(rc); + mpfree(ma); + return rv; +} + +void +run1i(Test1i *t) +{ + int i, j, ok; + ldint *a, *c; + + a = ldnew(32); + c = ldnew(32); + ok = 1; + for(i = 0; i < NTEST; i++) + for(j = (t->flags & NONEG) != 0 ? 0 : -128; j <= 128; j++){ + testgen(i, a); + ok &= test1i(t, a, j); + } + ldfree(a); + ldfree(c); + if(ok) + fprint(2, "%s: passed\n", t->name); +} + + +Test1i tests1i[] = { + "mpleft", 0, mpleft, ldleft, + "mpasr", 0, mpasr, ldasr, + "mptrunc", NONEG, mptrunc, ldtrunc, + "mpxtend", NONEG, mpxtend, ldxtend, + "mpnot", NONEG, mpnot_, ldnot, /* hack */ +}; + +void +all1i(void) +{ + Test1i *t; + + for(t = tests1i; t < tests1i + nelem(tests1i); t++) + run1i(t); +} + +void +siglo(void) +{ + int i, j, k; + ldint *a; + mpint *ma; + int sigok, lowok0; + + a = ldnew(32); + ma = mpnew(0); + sigok = 1; + lowok0 = 1; + for(i = 0; i < NTEST; i++){ + testgen(i, a); + for(j = 0; j < a->n; j++) + if(a->b[j] != 0) + break; + if(j == a->n) j = 0; + ldtomp(a, ma); + k = mplowbits0(ma); + if(k != j){ + fprint(2, "FAIL: mplowbits0: %#B: got %d, expected %d\n", ma, k, j); + lowok0 = 0; + } + for(j = a->n - 2; j >= 0; j--) + if(a->b[j] != a->b[a->n-1]) + break; + for(k = j-1; k >= 0; k--) + if(a->b[k] != 0) + break; + if(a->b[a->n - 1] && k < 0) j++; + j++; + ldtomp(a, ma); + k = mpsignif(ma); + if(k != j){ + fprint(2, "FAIL: mpsignif: %#B: got %d, expected %d\n", ma, k, j); + sigok = 0; + } + } + if(sigok) fprint(2, "mpsignif: passed\n"); + if(lowok0) fprint(2, "mplowbits0: passed\n"); + ldfree(a); + mpfree(ma); +} + +void +cmptest(void) +{ + int i, j, k, l; + ldint *a, *b; + mpint *ma, *mb; + int cmpok, magcmpok; + + a = ldnew(32); + b = ldnew(32); + ma = mpnew(0); + mb = mpnew(0); + cmpok = 1; + magcmpok = 1; + for(i = 0; i < NTEST; i++) + for(j = 0; j < NTEST; j++){ + testgen(i, a); + testgen(j, b); + ldtomp(a, ma); + ldtomp(b, mb); + l = ldcmp(a, b); + k = mpcmp(ma, mb); + if(k < 0) k = -1; + if(k > 0) k = 1; + if(k != l){ + fprint(2, "FAIL: mpcmp: %L and %L: got %d, expected %d\n", a, b, k, l); + cmpok = 1; + } + ldtomp(a, ma); + ldtomp(b, mb); + l = ldmagcmp(a, b); + k = mpmagcmp(ma, mb); + if(k < 0) k = -1; + if(k > 0) k = 1; + if(k != l){ + fprint(2, "FAIL: mpmagcmp: %L and %L: got %d, expected %d\n", a, b, k, l); + magcmpok = 1; + } + } + ldfree(a); + ldfree(b); + mpfree(ma); + mpfree(mb); + if(cmpok) fprint(2, "mpcmp: passed\n"); + if(magcmpok) fprint(2, "mpmagcmp: passed\n"); +} + void main() { fmtinstall('B', mpfmt); + fmtinstall(L'β', mpdetfmt); fmtinstall('L', ldfmt); + siglo(); + cmptest(); + all1i(); all2(); }