libmp: add mpfield() function for fast field arithmetic

instead of testing for special field primes each time in mpmod(),
make it explicit with a mpfiled() function that tests a modulus N
to be of some special form that can be reduced more efficiently with
some precalculation, and replaces N with a Mfield* when it can. the
Mfield*'s are recognized by mpmod() as they have the MPfield flag
set and provide a function pointer that executes the fast reduction.
This commit is contained in:
cinap_lenrek 2015-12-16 21:18:20 +01:00
parent b6f04b77e3
commit efd3ac8a23
8 changed files with 337 additions and 107 deletions

View file

@ -8,7 +8,6 @@
* mpdigit must be an atomic type. mpdigit is defined * mpdigit must be an atomic type. mpdigit is defined
* in the architecture specific u.h * in the architecture specific u.h
*/ */
typedef struct mpint mpint; typedef struct mpint mpint;
struct mpint struct mpint
@ -25,6 +24,7 @@ enum
MPstatic= 0x01, /* static constant */ MPstatic= 0x01, /* static constant */
MPnorm= 0x02, /* normalization status */ MPnorm= 0x02, /* normalization status */
MPtimesafe= 0x04, /* request time invariant computation */ MPtimesafe= 0x04, /* request time invariant computation */
MPfield= 0x08, /* this mpint is a field modulus */
Dbytes= sizeof(mpdigit), /* bytes per digit */ Dbytes= sizeof(mpdigit), /* bytes per digit */
Dbits= Dbytes*8 /* bits per digit */ Dbits= Dbytes*8 /* bits per digit */
@ -165,5 +165,18 @@ void crtout(CRTpre*, CRTres*, mpint*); /* convert residues to mpint */
void crtprefree(CRTpre*); void crtprefree(CRTpre*);
void crtresfree(CRTres*); void crtresfree(CRTres*);
/* fast field arithmetic */
typedef struct Mfield Mfield;
struct Mfield
{
mpint;
int (*reduce)(Mfield*, mpint*, mpint*);
};
mpint *mpfield(mpint*);
Mfield *gmfield(mpint*);
Mfield *cnfield(mpint*);
#pragma varargck type "B" mpint* #pragma varargck type "B" mpint*

View file

@ -0,0 +1,114 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
/*
* fast reduction for crandall numbers of the form: 2^n - c
*/
enum {
MAXDIG = 1024 / Dbits,
};
typedef struct CNfield CNfield;
struct CNfield
{
Mfield;
mpint m[1];
int s;
mpdigit c;
};
static int
cnreduce(Mfield *m, mpint *a, mpint *r)
{
mpdigit q[MAXDIG-1], t[MAXDIG], d;
CNfield *f = (CNfield*)m;
int qn, tn, k;
k = f->top;
if((a->top - k) >= MAXDIG)
return -1;
mpleft(a, f->s, r);
if(r->top <= k)
mpbits(r, (k+1)*Dbits);
/* q = hi(r) */
qn = r->top - k;
memmove(q, r->p+k, qn*Dbytes);
/* r = lo(r) */
r->top = k;
r->sign = 1;
do {
/* t = q*c */
tn = qn+1;
memset(t, 0, tn*Dbytes);
mpvecdigmuladd(q, qn, f->c, t);
/* q = hi(t) */
qn = tn - k;
if(qn <= 0) qn = 0;
else memmove(q, t+k, qn*Dbytes);
/* r += lo(t) */
if(tn > k)
tn = k;
mpvecadd(r->p, k, t, tn, r->p);
/* if(r >= m) r -= m */
mpvecsub(r->p, k+1, f->m->p, k, t);
d = t[k];
for(tn = 0; tn < k; tn++)
r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
} while(qn > 0);
if(f->s != 0)
mpright(r, f->s, r);
mpnorm(r);
return 0;
}
Mfield*
cnfield(mpint *N)
{
mpint *M, *C;
CNfield *f;
mpdigit d;
int s;
if(N->top <= 2 || N->top >= MAXDIG)
return nil;
f = nil;
d = N->p[N->top-1];
for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
d <<= 1;
C = mpnew(0);
M = mpcopy(N);
mpleft(N, s, M);
mpleft(mpone, M->top*Dbits, C);
mpsub(C, M, C);
if(C->top != 1)
goto out;
f = mallocz(sizeof(CNfield) + M->top*sizeof(mpdigit), 1);
if(f == nil)
goto out;
f->s = s;
f->c = C->p[0];
f->m->size = M->top;
f->m->p = (mpdigit*)&f[1];
mpassign(M, f->m);
mpassign(N, f);
f->reduce = cnreduce;
f->flags |= MPfield;
out:
mpfree(M);
mpfree(C);
return f;
}

View file

@ -0,0 +1,170 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
/*
* fast reduction for generalized mersenne numbers (GM)
* using a series of additions and subtractions.
*/
enum {
MAXDIG = 1024/Dbits,
};
typedef struct GMfield GMfield;
struct GMfield
{
Mfield;
mpint m2[1];
int nadd;
int nsub;
int indx[256];
};
static int
gmreduce(Mfield *m, mpint *a, mpint *r)
{
GMfield *g = (GMfield*)m;
mpdigit d0, t[MAXDIG];
int i, j, d, *x;
if(mpmagcmp(a, g->m2) >= 0)
return -1;
if(a != r)
mpassign(a, r);
d = g->top;
mpbits(r, (d+1)*Dbits*2);
memmove(t+d, r->p+d, d*Dbytes);
r->sign = 1;
r->top = d;
r->p[d] = 0;
if(g->nsub > 0)
mpvecdigmuladd(g->p, d, g->nsub, r->p);
x = g->indx;
for(i=0; i<g->nadd; i++){
t[0] = 0;
d0 = t[*x++];
for(j=1; j<d; j++)
t[j] = t[*x++];
t[0] = d0;
mpvecadd(r->p, d+1, t, d, r->p);
}
for(i=0; i<g->nsub; i++){
t[0] = 0;
d0 = t[*x++];
for(j=1; j<d; j++)
t[j] = t[*x++];
t[0] = d0;
mpvecsub(r->p, d+1, t, d, r->p);
}
mpvecdigmulsub(g->p, d, r->p[d], r->p);
r->p[d] = 0;
mpvecsub(r->p, d+1, g->p, d, r->p+d+1);
d0 = r->p[2*d+1];
for(j=0; j<d; j++)
r->p[j] = (r->p[j] & d0) | (r->p[j+d+1] & ~d0);
mpnorm(r);
return 0;
}
Mfield*
gmfield(mpint *N)
{
int i,j,d, s, *C, *X, *x, *e;
mpint *M, *T;
GMfield *g;
d = N->top;
if(d <= 2 || d > MAXDIG/2 || (mpsignif(N) % Dbits) != 0)
return nil;
g = nil;
T = mpnew(0);
M = mpcopy(N);
C = malloc(sizeof(int)*(d+1));
X = malloc(sizeof(int)*(d*d));
for(i=0; i<=d; i++){
if((M->p[i]>>8) != 0 && (~M->p[i]>>8) != 0)
goto out;
j = M->p[i];
C[d - i] = -j;
itomp(j, T);
mpleft(T, i*Dbits, T);
mpsub(M, T, M);
}
for(j=0; j<d; j++)
X[j] = C[d-j];
for(i=1; i<d; i++){
X[d*i] = X[d*(i-1) + d-1]*C[d];
for(j=1; j<d; j++)
X[d*i + j] = X[d*(i-1) + j-1] + X[d*(i-1) + d-1]*C[d-j];
}
g = mallocz(sizeof(GMfield) + (d+1)*sizeof(mpdigit)*2, 1);
if(g == nil)
goto out;
g->m2->p = (mpdigit*)&g[1];
g->m2->size = d*2+1;
mpmul(N, N, g->m2);
mpassign(N, g);
g->reduce = gmreduce;
g->flags |= MPfield;
s = 0;
x = g->indx;
e = x + nelem(g->indx) - d;
for(g->nadd=0; x <= e; x += d, g->nadd++){
s = 0;
for(i=0; i<d; i++){
for(j=0; j<d; j++){
if(X[d*i+j] > 0 && x[j] == 0){
X[d*i+j]--;
x[j] = d+i;
s = 1;
break;
}
}
}
if(s == 0)
break;
}
for(g->nsub=0; x <= e; x += d, g->nsub++){
s = 0;
for(i=0; i<d; i++){
for(j=0; j<d; j++){
if(X[d*i+j] < 0 && x[j] == 0){
X[d*i+j]++;
x[j] = d+i;
s = 1;
break;
}
}
}
if(s == 0)
break;
}
if(s != 0){
mpfree(g);
g = nil;
}
out:
free(C);
free(X);
mpfree(M);
mpfree(T);
return g;
}

View file

@ -38,6 +38,9 @@ FILES=\
mptoui\ mptoui\
mptov\ mptov\
mptouv\ mptouv\
mpfield\
cnfield\
gmfield\
mplogic\ mplogic\
ALLOFILES=${FILES:%=%.$O} ALLOFILES=${FILES:%=%.$O}

View file

@ -137,7 +137,7 @@ mpcopy(mpint *old)
setmalloctag(new, getcallerpc(&old)); setmalloctag(new, getcallerpc(&old));
new->sign = old->sign; new->sign = old->sign;
new->top = old->top; new->top = old->top;
new->flags = old->flags & ~MPstatic; new->flags = old->flags & ~(MPstatic|MPfield);
memmove(new->p, old->p, Dbytes*old->top); memmove(new->p, old->p, Dbytes*old->top);
return new; return new;
} }
@ -152,7 +152,7 @@ mpassign(mpint *old, mpint *new)
new->sign = old->sign; new->sign = old->sign;
new->top = old->top; new->top = old->top;
new->flags &= ~MPnorm; new->flags &= ~MPnorm;
new->flags |= old->flags & ~MPstatic; new->flags |= old->flags & ~(MPstatic|MPfield);
memmove(new->p, old->p, Dbytes*old->top); memmove(new->p, old->p, Dbytes*old->top);
} }

View file

@ -61,25 +61,23 @@ mpexp(mpint *b, mpint *e, mpint *m, mpint *res)
j = 0; j = 0;
for(;;){ for(;;){
for(; bit != 0; bit >>= 1){ for(; bit != 0; bit >>= 1){
mpmul(t[j], t[j], t[j^1]); if(m != nil)
if(bit & d) mpmodmul(t[j], t[j], m, t[j^1]);
mpmul(t[j^1], b, t[j]);
else else
mpmul(t[j], t[j], t[j^1]);
if(bit & d) {
if(m != nil)
mpmodmul(t[j^1], b, m, t[j]);
else
mpmul(t[j^1], b, t[j]);
} else
j ^= 1; j ^= 1;
if(m != nil && t[j]->top > m->top){
mpmod(t[j], m, t[j^1]);
j ^= 1;
}
} }
if(--i < 0) if(--i < 0)
break; break;
bit = mpdighi; bit = mpdighi;
d = e->p[i]; d = e->p[i];
} }
if(m != nil){
mpmod(t[j], m, t[j^1]);
j ^= 1;
}
if(t[j] == res){ if(t[j] == res){
mpfree(t[j^1]); mpfree(t[j^1]);
} else { } else {

View file

@ -0,0 +1,21 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
mpint*
mpfield(mpint *N)
{
Mfield *f;
if(N == nil || N->flags & (MPfield|MPstatic))
return N;
if((f = cnfield(N)) != nil)
goto Exchange;
if((f = gmfield(N)) != nil)
goto Exchange;
return N;
Exchange:
setmalloctag(f, getcallerpc(&N));
mpfree(N);
return f;
}

View file

@ -5,101 +5,12 @@
void void
mpmod(mpint *x, mpint *n, mpint *r) mpmod(mpint *x, mpint *n, mpint *r)
{ {
static int busy; int sign;
static mpint *p, *m, *c, *v;
mpdigit q[32], t[64], d;
int sign, k, s, qn, tn;
sign = x->sign; sign = x->sign;
if((n->flags & MPfield) == 0
assert(n->flags & MPnorm); || ((Mfield*)n)->reduce((Mfield*)n, x, r) != 0)
if(n->top <= 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q)) mpdiv(x, n, nil, r);
goto hard;
/*
* check if n = 2**k - c where c has few power of two factors
* above the lowest digit.
*/
for(k = n->top-1; k > 0; k--){
d = n->p[k] >> 1;
if((d+1 & d) != 0)
goto hard;
}
d = n->p[n->top-1];
for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
d <<= 1;
/* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
k = n->top;
while(_tas(&busy))
;
if(p == nil || mpmagcmp(n, p) != 0){
if(m == nil){
m = mpnew(0);
c = mpnew(0);
p = mpnew(0);
}
mpleft(n, s, m);
mpleft(mpone, k*Dbits, c);
mpsub(c, m, c);
if(c->top >= k){
mpassign(mpzero, p);
busy = 0;
goto hard;
}
mpassign(n, p);
}
mpleft(x, s, r);
if(r->top <= k){
mpbits(r, (k+1)*Dbits);
r->top = k+1;
}
/* q = hi(r) */
qn = r->top - k;
memmove(q, r->p+k, qn*Dbytes);
/* r = lo(r) */
r->top = k;
do {
/* t = q*c */
tn = qn + c->top;
memset(t, 0, tn*Dbytes);
mpvecmul(q, qn, c->p, c->top, t);
/* q = hi(t) */
qn = tn - k;
if(qn <= 0) qn = 0;
else memmove(q, t+k, qn*Dbytes);
/* r += lo(t) */
if(tn > k)
tn = k;
mpvecadd(r->p, k, t, tn, r->p);
/* if(r >= m) r -= m */
mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
for(tn = 0; tn < k; tn++)
r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
} while(qn > 0);
busy = 0;
if(s != 0)
mpright(r, s, r);
else
mpnorm(r);
goto done;
hard:
mpdiv(x, n, nil, r);
done:
if(sign < 0) if(sign < 0)
mpmagsub(n, r, r); mpmagsub(n, r, r);
} }