#include #include #include #include #include #include "dat.h" #include "fns.h" Node * icast(int sign, int size, Node *n) { Type *t; t = type(TYPINT, sign, size); return node(OCAST, t, n); } /* the type checker checks types. the result is an expression that is correct if evaluated with 64-bit operands all the way. to maintain c-like semantics, this means adding casts all over the place, which will get optimised later. note we use kencc, NOT ansi c, semantics for unsigned. */ Node * typecheck(Node *n) { int s1, s2, sign; switch(/*nodetype*/n->type){ case OSYM: switch(n->sym->type){ case SYMNONE: error("undeclared '%s'", n->sym->name); break; case SYMVAR: n->typ = n->sym->typ; break; default: sysfatal("typecheck: unknown symbol type %d", n->sym->type); } break; case ONUM: if((vlong)n->num >= -0x80000000LL && (vlong)n->num <= 0x7fffffffLL) n->typ = type(TYPINT, 4, 1); else n->typ = type(TYPINT, 8, 1); break; case OSTR: n->typ = type(TYPSTRING); break; case OBIN: n->n1 = typecheck(n->n1); n->n2 = typecheck(n->n2); if(n->n1->typ == nil || n->n2->typ == nil) break; if(n->n1->typ->type != TYPINT){ error("%τ not allowed in operation", n->n1->typ); break; } if(n->n2->typ->type != TYPINT){ error("%τ not allowed in operation", n->n2->typ); break; } s1 = n->n1->typ->size; s2 = n->n2->typ->size; sign = n->n1->typ->sign && n->n2->typ->sign; switch(n->op){ case OPADD: case OPSUB: case OPMUL: case OPDIV: case OPMOD: case OPAND: case OPOR: case OPXOR: case OPXNOR: n->typ = type(TYPINT, 8, sign); if(s1 > 4 || s2 > 4){ n->n1 = icast(8, sign, n->n1); n->n2 = icast(8, sign, n->n2); return n; }else{ n->n1 = icast(4, sign, n->n1); n->n2 = icast(4, sign, n->n2); return icast(4, sign, n); } case OPEQ: case OPNE: case OPLT: case OPLE: n->typ = type(TYPINT, 4, sign); if(s1 > 4 || s2 > 4){ n->n1 = icast(8, sign, n->n1); n->n2 = icast(8, sign, n->n2); return n; }else{ n->n1 = icast(4, sign, n->n1); n->n2 = icast(4, sign, n->n2); return n; } case OPLAND: case OPLOR: n->typ = type(TYPINT, 4, sign); return n; case OPLSH: case OPRSH: if(n->n1->typ->size <= 4) n->n1 = icast(4, n->n1->typ->sign, n->n1); n->typ = n->n1->typ; return icast(n->typ->size, n->typ->sign, n); default: sysfatal("typecheck: unknown op %d", n->op); } break; case OCAST: n->n1 = typecheck(n->n1); if(n->n1->typ == nil) break; if(n->typ->type == TYPINT && n->n1->typ->type == TYPINT){ }else if(n->typ == n->n1->typ){ }else if(n->typ->type == TYPSTRING && n->n1->typ->type == TYPINT){ }else error("can't cast from %τ to %τ", n->n1->typ, n->typ); break; case OLNOT: n->n1 = typecheck(n->n1); if(n->n1->typ == nil) break; if(n->n1->typ->type != TYPINT){ error("%τ not allowed in operation", n->n1->typ); break; } n->typ = type(TYPINT, 4, 1); break; case OTERN: n->n1 = typecheck(n->n1); n->n2 = typecheck(n->n2); n->n3 = typecheck(n->n3); if(n->n1->typ == nil || n->n2->typ == nil || n->n3->typ == nil) break; if(n->n1->typ->type != TYPINT){ error("%τ not allowed in operation", n->n1->typ); break; } if(n->n2->typ->type == TYPINT || n->n3->typ->type == TYPINT){ sign = n->n2->typ->sign && n->n3->typ->sign; s1 = n->n2->typ->size; s2 = n->n3->typ->size; if(s1 > 4 || s2 > 4){ n->n2 = icast(8, sign, n->n2); n->n3 = icast(8, sign, n->n3); n->typ = type(TYPINT, 8, sign); return n; }else{ n->n2 = icast(4, sign, n->n2); n->n3 = icast(4, sign, n->n3); n->typ = type(TYPINT, 4, sign); return n; } }else if(n->n2->typ == n->n3->typ){ n->typ = n->n2->typ; }else error("don't know how to do ternary with %τ and %τ", n->n2->typ, n->n3->typ); break; case ORECORD: default: sysfatal("typecheck: unknown node type %α", n->type); } return n; } vlong evalop(int op, int sign, vlong v1, vlong v2) { switch(/*oper*/op){ case OPADD: return v1 + v2; break; case OPSUB: return v1 - v2; break; case OPMUL: return v1 * v2; break; case OPDIV: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 / v2 : (uvlong)v1 / (uvlong)v2; break; case OPMOD: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 % v2 : (uvlong)v1 % (uvlong)v2; break; case OPAND: return v1 & v2; break; case OPOR: return v1 | v2; break; case OPXOR: return v1 ^ v2; break; case OPXNOR: return ~(v1 ^ v2); break; case OPLSH: if((u64int)v2 >= 64) return 0; else return v1 << v2; break; case OPRSH: if(sign){ if((u64int)v2 >= 64) return v1 >> 63; else return v1 >> v2; }else{ if((u64int)v2 >= 64) return 0; else return (u64int)v1 >> v2; } break; case OPEQ: return v1 == v2; break; case OPNE: return v1 != v2; break; case OPLT: return v1 < v2; break; case OPLE: return v1 <= v2; break; case OPLAND: return v1 && v2; break; case OPLOR: return v1 || v2; break; default: sysfatal("cfold: unknown op %.2x", op); return 0; } } Node * addtype(Type *t, Node *n) { n->typ = t; return n; } /* fold constants */ static Node * cfold(Node *n) { switch(/*nodetype*/n->type){ case ONUM: case OSYM: case OSTR: return n; case OBIN: n->n1 = cfold(n->n1); n->n2 = cfold(n->n2); if(n->n1->type != ONUM || n->n2->type != ONUM) return n; return addtype(n->typ, node(ONUM, evalop(n->op, n->typ->sign, n->n1->num, n->n2->num))); case OLNOT: n->n1 = cfold(n->n1); if(n->n1->type == ONUM) return addtype(n->typ, node(ONUM, !n->n1->num)); return n; case OTERN: n->n1 = cfold(n->n1); n->n2 = cfold(n->n2); n->n3 = cfold(n->n3); if(n->n1->type == ONUM) return n->n1->num ? n->n2 : n->n3; return n; case OCAST: n->n1 = cfold(n->n1); if(n->n1->type != ONUM || n->typ->type != TYPINT) return n; switch(n->typ->size << 4 | n->typ->sign){ case 0x10: return addtype(n->typ, node(ONUM, (vlong)(u8int)n->n1->num)); case 0x11: return addtype(n->typ, node(ONUM, (vlong)(s8int)n->n1->num)); case 0x20: return addtype(n->typ, node(ONUM, (vlong)(u16int)n->n1->num)); case 0x21: return addtype(n->typ, node(ONUM, (vlong)(s16int)n->n1->num)); case 0x40: return addtype(n->typ, node(ONUM, (vlong)(u32int)n->n1->num)); case 0x41: return addtype(n->typ, node(ONUM, (vlong)(s32int)n->n1->num)); case 0x80: return addtype(n->typ, node(ONUM, n->n1->num)); case 0x81: return addtype(n->typ, node(ONUM, n->n1->num)); } return n; case ORECORD: default: fprint(2, "cfold: unknown type %α\n", n->type); return n; } } /* calculate the minimum record size for each node of the expression */ static Node * calcrecsize(Node *n) { switch(/*nodetype*/n->type){ case ONUM: case OSTR: n->recsize = 0; break; case OSYM: switch(n->sym->type){ case SYMVAR: switch(n->sym->idx){ case DTV_TIME: case DTV_PROBE: n->recsize = 0; break; default: n->recsize = n->typ->size; break; } break; default: sysfatal("calcrecsize: unknown symbol type %d", n->sym->type); return nil; } break; case OBIN: n->n1 = calcrecsize(n->n1); n->n2 = calcrecsize(n->n2); n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize); break; case OLNOT: n->n1 = calcrecsize(n->n1); n->recsize = min(n->typ->size, n->n1->recsize); break; case OCAST: n->n1 = calcrecsize(n->n1); if(n->typ->type == TYPSTRING) n->recsize = n->typ->size; else n->recsize = min(n->typ->size, n->n1->recsize); break; case OTERN: n->n1 = calcrecsize(n->n1); n->n2 = calcrecsize(n->n2); n->n3 = calcrecsize(n->n3); n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize + n->n3->recsize); break; case ORECORD: default: sysfatal("calcrecsize: unknown type %α", n->type); return nil; } return n; } /* insert ORECORD nodes to mark the subexpression that we will pass to the kernel */ static Node * insrecord(Node *n) { if(n->recsize == 0) return n; if(n->typ->size == n->recsize) return addtype(n->typ, node(ORECORD, n)); switch(/*nodetype*/n->type){ case ONUM: case OSTR: case OSYM: break; case OBIN: n->n1 = insrecord(n->n1); n->n2 = insrecord(n->n2); break; case OLNOT: case OCAST: n->n1 = insrecord(n->n1); break; case OTERN: n->n1 = insrecord(n->n1); n->n2 = insrecord(n->n2); n->n3 = insrecord(n->n3); break; case ORECORD: default: sysfatal("insrecord: unknown type %α", n->type); return nil; } return n; } /* delete useless casts. going down we determine the number of bits (m) needed to be correct at each stage. going back up we determine the number of bits (n->databits) which can be either 0 or 1. all other bits are either zero (n->upper == UPZX) or sign-extended (n->upper == UPSX). note that by number of bits we always mean a consecutive block starting from the LSB. we can delete a cast if it either affects only bits not needed (according to m) or if it's a no-op (according to databits, upper). */ static Node * elidecasts(Node *n, int m) { switch(/*nodetype*/n->type){ case OSTR: return n; case ONUM: n->databits = n->typ->size * 8; n->upper = n->typ->sign ? UPSX : UPZX; break; case OSYM: /* TODO: make less pessimistic */ n->databits = 64; break; case OBIN: switch(/*oper*/n->op){ case OPADD: case OPSUB: n->n1 = elidecasts(n->n1, m); n->n2 = elidecasts(n->n2, m); n->databits = min(64, max(n->n1->databits, n->n2->databits) + 1); n->upper = n->n1->upper | n->n2->upper; break; case OPMUL: n->n1 = elidecasts(n->n1, m); n->n2 = elidecasts(n->n2, m); n->databits = min(64, n->n1->databits + n->n2->databits); n->upper = n->n1->upper | n->n2->upper; break; case OPAND: case OPOR: case OPXOR: case OPXNOR: n->n1 = elidecasts(n->n1, m); n->n2 = elidecasts(n->n2, m); if(n->op == OPAND && (n->n1->upper == UPZX || n->n2->upper == UPZX)){ n->upper = UPZX; if(n->n1->upper == UPZX && n->n2->upper == UPZX) n->databits = min(n->n1->databits, n->n2->databits); else if(n->n1->upper == UPZX) n->databits = n->n1->databits; else n->databits = n->n2->databits; }else{ n->databits = max(n->n1->databits, n->n2->databits); n->upper = n->n1->upper | n->n2->upper; } break; case OPLSH: n->n1 = elidecasts(n->n1, m); n->n2 = elidecasts(n->n2, 64); if(n->n2->type == ONUM && n->n2->num >= 0 && n->n1->databits + (uvlong)n->n2->num <= 64) n->databits = n->n1->databits + n->n2->num; else n->databits = 64; n->upper = n->n1->upper; break; case OPRSH: n->n1 = elidecasts(n->n1, 64); n->n2 = elidecasts(n->n2, 64); if(n->n1->upper == n->typ->sign){ n->databits = n->n1->databits; n->upper = n->n1->upper; }else{ n->databits = 64; n->upper = UPZX; } break; case OPEQ: case OPNE: case OPLT: case OPLE: case OPLAND: case OPLOR: n->n1 = elidecasts(n->n1, 64); n->n2 = elidecasts(n->n2, 64); n->databits = 1; n->upper = UPZX; break; case OPDIV: case OPMOD: default: n->n1 = elidecasts(n->n1, 64); n->n2 = elidecasts(n->n2, 64); n->databits = 64; n->upper = UPZX; break; } break; case OLNOT: n->n1 = elidecasts(n->n1, 64); n->databits = 1; n->upper = UPZX; break; case OCAST: switch(n->typ->type){ case TYPINT: n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m)); if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){ n->databits = n->n1->databits; n->upper = n->n1->upper; }else{ n->databits = n->typ->size * 8; n->upper = n->typ->sign ? UPSX : UPZX; } if(n->typ->size * 8 >= m) return n->n1; if(n->typ->size * 8 >= n->n1->databits && n->typ->sign == n->n1->upper) return n->n1; if(n->typ->size * 8 > n->n1->databits && n->typ->sign && !n->n1->upper) return n->n1; break; case TYPSTRING: n->n1 = elidecasts(n->n1, 64); break; default: sysfatal("elidecasts: don't know how to cast %τ to %τ", n->n1->typ, n->typ); } break; case ORECORD: n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m)); if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){ n->databits = n->n1->databits; n->upper = n->n1->upper; }else{ n->databits = n->typ->size * 8; n->upper = n->typ->sign ? UPSX : UPZX; } break; case OTERN: n->n1 = elidecasts(n->n1, 64); n->n2 = elidecasts(n->n2, m); n->n3 = elidecasts(n->n3, m); if(n->n2->upper == n->n3->upper){ n->databits = max(n->n2->databits, n->n3->databits); n->upper = n->n2->upper; }else{ if(n->n3->upper == UPSX) n->databits = max(min(64, n->n2->databits + 1), n->n3->databits); else n->databits = max(min(64, n->n3->databits + 1), n->n2->databits); n->upper = UPSX; } break; default: sysfatal("elidecasts: unknown type %α", n->type); } // print("need %d got %d%c %ε\n", n->needbits, n->databits, "ZS"[n->upper], n); return n; } Node * exprcheck(Node *n, int pred) { if(dflag) print("start %ε\n", n); n = typecheck(n); if(errors) return n; if(dflag) print("typecheck %ε\n", n); n = cfold(n); if(dflag) print("cfold %ε\n", n); if(!pred){ n = insrecord(calcrecsize(n)); if(dflag) print("insrecord %ε\n", n); } n = elidecasts(n, 64); if(dflag) print("elidecasts %ε\n", n); return n; }