plan9fox/sys/src/libsat/satadd.c

247 lines
4.8 KiB
C

#include <u.h>
#include <libc.h>
#include <sat.h>
#include "impl.h"
static SATBlock *
newblock(SATSolve *s, int learned)
{
SATBlock *b;
b = calloc(1, SATBLOCKSZ);
if(b == nil)
saterror(s, "malloc: %r");
b->prev = s->bl[learned].prev;
b->next = &s->bl[learned];
b->next->prev = b;
b->prev->next = b;
b->end = (void*) b->data;
return b;
}
SATClause *
satnewclause(SATSolve *s, int n, int learned)
{
SATBlock *b;
SATClause *c;
int f, sz;
sz = sizeof(SATClause) + (n - 1) * sizeof(int);
assert(sz <= SATBLOCKSZ);
if(learned)
b = s->lastbl;
else
b = s->bl[0].prev;
for(;;){
f = (uchar*)b + SATBLOCKSZ - (uchar*)b->end;
if(f >= sz) break;
b = b->next;
if(b == &s->bl[learned])
b = newblock(s, learned);
}
c = b->end;
memset(c, 0, sizeof(SATClause));
b->end = (void *)((uintptr)b->end + sz + CLAUSEALIGN - 1 & -CLAUSEALIGN);
b->last = c;
if(learned){
if(s->lastp[1] == &s->learncl)
*s->lastp[0] = c;
s->lastbl = b;
}else
c->next = s->learncl;
*s->lastp[learned] = c;
s->lastp[learned] = &c->next;
s->ncl++;
return c;
}
/* this is currently only used to subsume clauses, i.e. n is guaranteed to be less than the last n */
SATClause *
satreplclause(SATSolve *s, int n)
{
SATBlock *b;
SATClause *c, **wp;
int f, sz, i, l;
assert(s->lastbl != nil && s->lastbl->last != nil);
b = s->lastbl;
c = b->last;
f = (uchar*)b + SATBLOCKSZ - (uchar*)c;
sz = sizeof(SATClause) + (n - 1) * sizeof(int);
assert(f >= sz);
b->end = (void *)((uintptr)c + sz + CLAUSEALIGN - 1 & -CLAUSEALIGN);
for(i = 0; i < 2; i++){
l = c->l[i];
for(wp = &s->lit[l].watch; *wp != nil && *wp != c; wp = &(*wp)->watch[(*wp)->l[1] == l])
;
assert(*wp != nil);
*wp = c->watch[i];
}
memset(c, 0, sizeof(SATClause));
return c;
}
static int
litconv(SATSolve *s, int l)
{
int v, m, n;
SATVar *vp;
SATLit *lp;
m = l >> 31;
v = (l + m ^ m) - 1;
if(v >= s->nvaralloc){
n = -(-(v+1) & -SATVARALLOC);
s->var = vp = satrealloc(s, s->var, n * sizeof(SATVar));
s->lit = lp = satrealloc(s, s->lit, 2 * n * sizeof(SATLit));
memset(vp += s->nvaralloc, 0, (n - s->nvaralloc) * sizeof(SATVar));
memset(lp += 2*s->nvaralloc, 0, 2 * (n - s->nvaralloc) * sizeof(SATLit));
for(; vp < s->var + n; vp++){
vp->lvl = -1;
vp->flags = VARPHASE;
}
for(; lp < s->lit + 2 * n; lp++)
lp->val = -1;
s->nvaralloc = n;
}
if(v >= s->nvar)
s->nvar = v + 1;
return v << 1 | m & 1;
}
static void
addbimp(SATSolve *s, int l0, int l1)
{
SATLit *lp;
lp = &s->lit[NOT(l0)];
lp->bimp = satrealloc(s, lp->bimp, (lp->nbimp + 1) * sizeof(int));
lp->bimp[lp->nbimp++] = l1;
}
static SATSolve *
satadd1special(SATSolve *s, int *a, int n)
{
int i, l0, l1;
if(n == 0){
s->unsat = 1;
return s;
}
l0 = a[0];
l1 = 0;
for(i = 1; i < n; i++)
if(a[i] != l0){
l1 = a[i];
break;
}
if(l1 == 0){
l0 = litconv(s, l0);
assert(s->lvl == 0);
switch(s->lit[l0].val){
case 0:
s->unsat = 1;
return s;
case -1:
s->trail = satrealloc(s, s->trail, sizeof(int) * s->nvar);
memmove(&s->trail[1], s->trail, sizeof(int) * s->ntrail);
s->trail[0] = l0;
s->ntrail++;
s->var[VAR(l0)].flags |= VARUSER;
s->var[VAR(l0)].lvl = 0;
s->lit[l0].val = 1;
s->lit[NOT(l0)].val = 0;
}
return s;
}
if(l0 + l1 == 0) return s;
l0 = litconv(s, l0);
l1 = litconv(s, l1);
addbimp(s, l0, l1);
addbimp(s, l1, l0);
return s;
}
SATSolve *
satadd1(SATSolve *s, int *a, int n)
{
SATClause *c;
int i, j, l, u;
SATVar *v;
if(s == nil){
s = satnew();
if(s == nil)
saterror(nil, "satnew: %r");
}
if(n < 0)
for(n = 0; a[n] != 0; n++)
;
for(i = 0; i < n; i++)
if(a[i] == 0)
saterror(s, "satadd1(%p, %p, %d): a[%d]==0, callerpc=%p", s, a, n, i, getcallerpc(&s));
satbackjump(s, 0);
if(n <= 2)
return satadd1special(s, a, n);
/* use stamps to detect repeated literals and tautological clauses */
if(s->stamp >= (uint)-6){
for(i = 0; i < s->nvar; i++)
s->var[i].stamp = 0;
s->stamp = 1;
}else
s->stamp += 3;
u = 0;
for(i = 0; i < n; i++){
l = litconv(s, a[i]);
v = &s->var[VAR(l)];
if(v->stamp < s->stamp) u++;
if(v->stamp == s->stamp + (~l & 1))
return s; /* tautological */
v->stamp = s->stamp + (l & 1);
}
if(u <= 2)
return satadd1special(s, a, n);
s->stamp += 3;
c = satnewclause(s, u, 0);
c->n = u;
for(i = 0, j = 0; i < n; i++){
l = litconv(s, a[i]);
v = &s->var[VAR(l)];
if(v->stamp < s->stamp){
c->l[j++] = l;
v->stamp = s->stamp;
}
}
assert(j == u);
s->ncl0++;
return s;
}
void
satvafix(va_list va)
{
int *d;
uintptr *s;
if(sizeof(int)==sizeof(uintptr)) return;
d = (int *) va;
s = (uintptr *) va;
do
*d++ = *s;
while((int)*s++ != 0);
}
SATSolve *
sataddv(SATSolve *s, ...)
{
va_list va;
va_start(va, s);
/* horrible hack */
satvafix(va);
s = satadd1(s, (int*)va, -1);
va_end(va);
return s;
}