mpc: constant expression folding

This commit is contained in:
cinap_lenrek 2016-02-01 19:27:57 +01:00
parent 340d83d49d
commit 0bfac109a4

View file

@ -28,6 +28,7 @@ struct Node
Node* l;
Node* r;
Sym* s;
mpint* m;
int n;
};
@ -238,11 +239,11 @@ expr:
{
$$ = new('e', $1, $2);
}
| expr LSH num
| expr LSH expr
{
$$ = new(LSH, $1, $3);
}
| expr RSH num
| expr RSH expr
{
$$ = new(RSH, $1, $3);
}
@ -390,6 +391,7 @@ new(int c, Node *l, Node *r)
n->l = l;
n->r = r;
n->s = nil;
n->m = nil;
n->n = lineno;
return n;
}
@ -561,7 +563,7 @@ complex(Node *n)
{
if(n->c == NAME)
return 0;
if(n->c == NUM && strlen(n->s->n) == 1 && atoi(n->s->n) < 3)
if(n->c == NUM && n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0)
return 0;
return 1;
}
@ -570,37 +572,131 @@ void
bcom(Node *n, Node *t);
Node*
ecom(Node *f, Node *t)
ccom(Node *f)
{
Node *l, *r, *t2;
mpint *m;
Node *l, *r;
if(f == nil)
return nil;
if(f->c == NUM){
m = strtomp(f->s->n, nil, 0, nil);
if(m == nil)
if(f->m != nil)
return f;
f->m = (void*)~0;
switch(f->c){
case NUM:
f->m = strtomp(f->s->n, nil, 0, nil);
if(f->m == nil)
diag(f, "bad constant");
if(mpcmp(m, mpzero) == 0){
goto out;
case LSH:
case RSH:
break;
case '+':
case '-':
case '*':
case '/':
case '%':
case '^':
if(modulo == nil || modulo->c == NUM)
break;
/* wet floor */
default:
return f;
}
f->l = l = ccom(f->l);
f->r = r = ccom(f->r);
if(l == nil || r == nil || l->c != NUM || r->c != NUM)
return f;
f->m = mpnew(0);
switch(f->c){
case LSH:
case RSH:
if(mpsignif(r->m) > 32)
diag(f, "bad shift");
if(f->c == LSH)
mpleft(l->m, mptoi(r->m), f->m);
else
mpright(l->m, mptoi(r->m), f->m);
goto out;
case '+':
mpadd(l->m, r->m, f->m);
break;
case '-':
mpsub(l->m, r->m, f->m);
break;
case '*':
mpmul(l->m, r->m, f->m);
break;
case '/':
if(modulo != nil){
mpinvert(r->m, modulo->m, f->m);
mpmul(f->m, l->m, f->m);
} else {
mpdiv(l->m, r->m, f->m, nil);
}
break;
case '%':
mpmod(l->m, r->m, f->m);
break;
case '^':
mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
goto out;
}
if(modulo != nil)
mpmod(f->m, modulo->m, f->m);
out:
f->l = nil;
f->r = nil;
f->s = nil;
f->c = NUM;
return f;
}
Node*
ecom(Node *f, Node *t)
{
Node *l, *r, *t2;
if(f == nil)
return nil;
f = ccom(f);
if(f->c == NUM){
if(f->m->sign < 0){
f->m->sign = 1;
t = ecom(f, t);
f->m->sign = -1;
if(isconst(t))
t = ecom(t, alloctmp());
cprint("%N->sign = -1;\n", t);
return t;
}
if(mpcmp(f->m, mpzero) == 0){
f->c = NAME;
f->s = sym("mpzero");
f->s->f = FSET;
return ecom(f, t);
}
if(mpcmp(m, mpone) == 0){
if(mpcmp(f->m, mpone) == 0){
f->c = NAME;
f->s = sym("mpone");
f->s->f = FSET;
return ecom(f, t);
}
if(mpcmp(m, mptwo) == 0){
if(mpcmp(f->m, mptwo) == 0){
f->c = NAME;
f->s = sym("mptwo");
f->s->f = FSET;
return ecom(f, t);
}
mpfree(m);
}
if(f->c == ','){
@ -645,24 +741,23 @@ ecom(Node *f, Node *t)
switch(f->c){
case NUM:
m = strtomp(f->s->n, nil, 0, nil);
if(m == nil)
diag(f, "bad constant");
if(mpsignif(m) <= 32)
cprint("uitomp(%udUL, %N);\n", mptoui(m), t);
else if(mpsignif(m) <= 64)
cprint("uvtomp(%lludULL, %N);\n", mptouv(m), t);
if(mpsignif(f->m) <= 32)
cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
else if(mpsignif(f->m) <= 64)
cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
else
cprint("strtomp(\"%.16B\", nil, 16, %N);\n", m, t);
mpfree(m);
cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
goto out;
case LSH:
l = f->l->c == NAME ? f->l : ecom(f->l, t);
cprint("mpleft(%N, %N, %N);\n", l, f->r, t);
goto out;
case RSH:
r = ccom(f->r);
if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
diag(f, "bad shift");
l = f->l->c == NAME ? f->l : ecom(f->l, t);
cprint("mpright(%N, %N, %N);\n", l, f->r, t);
if(f->c == LSH)
cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
else
cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
goto out;
case '*':
case '/':
@ -670,8 +765,10 @@ ecom(Node *f, Node *t)
r = ecom(f->r, nil);
break;
default:
l = ecom(f->l, complex(f->l) && !symref(f->r, t->s) ? t : nil);
r = ecom(f->r, complex(f->r) && l->s != t->s ? t : nil);
l = ccom(f->l);
r = ccom(f->r);
l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
r = ecom(r, complex(r) && l->s != t->s ? t : nil);
break;
}
@ -975,8 +1072,11 @@ Nfmt(Fmt *f)
return fmtprint(f, "%N, %N", n->l, n->r);
switch(n->c){
case NAME:
case NUM:
if(n->m != nil)
return fmtprint(f, "%B", n->m);
/* wet floor */
case NAME:
return fmtprint(f, "%s", n->s->n);
case EQ:
return fmtprint(f, "==");