add tests, fix some algorithms
This commit is contained in:
parent
3f144c5c3b
commit
10eaeefb54
130
src/bdd.ml
130
src/bdd.ml
@ -37,30 +37,21 @@ module Hash = struct
|
||||
let hash b = b.tag
|
||||
end
|
||||
|
||||
module Hash2 = struct
|
||||
type t = hidden * hidden
|
||||
|
||||
let equal (x1, y1) (x2, y2) = x1 == x2 && y1 == y2
|
||||
|
||||
let hash (x, y) = x.tag + (19 * y.tag)
|
||||
end
|
||||
|
||||
let hc = Hbdd.hashcons
|
||||
|
||||
let view x = x.node
|
||||
|
||||
module Mem = Memo.MakeWeak (Hash)
|
||||
module Mem2 = Memo.MakeWeak (Hash)
|
||||
|
||||
let true_bdd = hc True
|
||||
|
||||
let false_bdd = hc False
|
||||
|
||||
let get_order bdd =
|
||||
match view bdd with True | False -> max_int | Node (v, _, _) -> v
|
||||
match view bdd with True | False -> -1 | Node (v, _, _) -> v
|
||||
|
||||
let node v l h =
|
||||
if v >= get_order l || v >= get_order h then invalid_arg "node" ;
|
||||
if v <= get_order l || v <= get_order h then invalid_arg "node" ;
|
||||
if Hash.equal l h then l else hc (Node (v, l, h))
|
||||
|
||||
let var_bdd v = node v false_bdd true_bdd
|
||||
@ -72,7 +63,14 @@ let rec fprintf fmt bdd =
|
||||
| False ->
|
||||
Format.fprintf fmt "false"
|
||||
| Node (v, l, h) ->
|
||||
Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf l fprintf h
|
||||
Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf h fprintf l
|
||||
|
||||
let to_string bdd =
|
||||
let buff = Buffer.create 512 in
|
||||
let fmt = Format.formatter_of_buffer buff in
|
||||
fprintf fmt bdd ;
|
||||
Format.pp_print_flush fmt () ;
|
||||
Buffer.contents buff
|
||||
|
||||
let of_bool = function true -> true_bdd | false -> false_bdd
|
||||
|
||||
@ -88,55 +86,53 @@ let neg x =
|
||||
node var (neg low) (neg high))
|
||||
x
|
||||
|
||||
let comb_comm op x y =
|
||||
Mem2.memo
|
||||
(fun comb_comm n1 n2 ->
|
||||
match (view n1, view n2) with
|
||||
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
|
||||
node v1 (comb_comm l1 l2) (comb_comm h1 h2)
|
||||
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
|
||||
node v1 (comb_comm l1 n2) (comb_comm h1 n2)
|
||||
| Node (_, _, _), Node (v2, l2, h2) ->
|
||||
node v2 (comb_comm n1 l2) (comb_comm n1 h2)
|
||||
| True, Node (v, l, h) | Node (v, l, h), True ->
|
||||
node v (comb_comm l true_bdd) (comb_comm h true_bdd)
|
||||
| False, Node (v, l, h) | Node (v, l, h), False ->
|
||||
node v (comb_comm l false_bdd) (comb_comm h false_bdd)
|
||||
| False, False ->
|
||||
of_bool (op false false)
|
||||
| False, True | True, False ->
|
||||
of_bool (op true false)
|
||||
| True, True ->
|
||||
of_bool (op true true))
|
||||
x y
|
||||
(* TODO: memo 2 ? *)
|
||||
let rec comb_comm op n1 n2 =
|
||||
let comb_comm = comb_comm op in
|
||||
match (view n1, view n2) with
|
||||
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
|
||||
node v1 (comb_comm l1 l2) (comb_comm h1 h2)
|
||||
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
|
||||
node v1 (comb_comm l1 n2) (comb_comm h1 n2)
|
||||
| Node (_, _, _), Node (v2, l2, h2) ->
|
||||
node v2 (comb_comm n1 l2) (comb_comm n1 h2)
|
||||
| True, Node (v, l, h) | Node (v, l, h), True ->
|
||||
node v (comb_comm l true_bdd) (comb_comm h true_bdd)
|
||||
| False, Node (v, l, h) | Node (v, l, h), False ->
|
||||
node v (comb_comm l false_bdd) (comb_comm h false_bdd)
|
||||
| False, False ->
|
||||
of_bool (op false false)
|
||||
| False, True | True, False ->
|
||||
of_bool (op true false)
|
||||
| True, True ->
|
||||
of_bool (op true true)
|
||||
|
||||
let comb op x y =
|
||||
Mem2.memo
|
||||
(fun comb n1 n2 ->
|
||||
match (view n1, view n2) with
|
||||
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
|
||||
node v1 (comb l1 l2) (comb h1 h2)
|
||||
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
|
||||
node v1 (comb l1 n2) (comb h1 n2)
|
||||
| Node (_, _, _), Node (v2, l2, h2) ->
|
||||
node v2 (comb n1 l2) (comb n1 h2)
|
||||
| True, Node (v, l, h) ->
|
||||
node v (comb true_bdd l) (comb true_bdd h)
|
||||
| Node (v, l, h), True ->
|
||||
node v (comb l true_bdd) (comb h true_bdd)
|
||||
| False, Node (v, l, h) ->
|
||||
node v (comb false_bdd l) (comb false_bdd h)
|
||||
| Node (v, l, h), False ->
|
||||
node v (comb l false_bdd) (comb h false_bdd)
|
||||
| False, False ->
|
||||
of_bool (op false false)
|
||||
| False, True ->
|
||||
of_bool (op false true)
|
||||
| True, False ->
|
||||
of_bool (op true false)
|
||||
| True, True ->
|
||||
of_bool (op true true))
|
||||
x y
|
||||
(* TODO: memo2 ? *)
|
||||
let rec comb op n1 n2 =
|
||||
let comb = comb op in
|
||||
match (view n1, view n2) with
|
||||
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
|
||||
node v1 (comb l1 l2) (comb h1 h2)
|
||||
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
|
||||
node v1 (comb l1 n2) (comb h1 n2)
|
||||
| Node (_, _, _), Node (v2, l2, h2) ->
|
||||
node v2 (comb n1 l2) (comb n1 h2)
|
||||
| True, Node (v, l, h) ->
|
||||
node v (comb true_bdd l) (comb true_bdd h)
|
||||
| Node (v, l, h), True ->
|
||||
node v (comb l true_bdd) (comb h true_bdd)
|
||||
| False, Node (v, l, h) ->
|
||||
node v (comb false_bdd l) (comb false_bdd h)
|
||||
| Node (v, l, h), False ->
|
||||
node v (comb l false_bdd) (comb h false_bdd)
|
||||
| False, False ->
|
||||
of_bool (op false false)
|
||||
| False, True ->
|
||||
of_bool (op false true)
|
||||
| True, False ->
|
||||
of_bool (op true false)
|
||||
| True, True ->
|
||||
of_bool (op true true)
|
||||
|
||||
let conj = comb_comm (fun x y -> x && y)
|
||||
|
||||
@ -156,7 +152,7 @@ let compute tbl =
|
||||
| Node (v, l, h) -> (
|
||||
match Hashtbl.find tbl v with
|
||||
| exception Not_found ->
|
||||
node v (compute_aux h) (compute_aux l)
|
||||
node v (compute_aux l) (compute_aux h)
|
||||
| true ->
|
||||
compute_aux h
|
||||
| false ->
|
||||
@ -188,10 +184,7 @@ let size =
|
||||
|
||||
let is_sat bdd = match view bdd with False -> false | _ -> true
|
||||
|
||||
let count_sat maxn =
|
||||
let get_var bdd =
|
||||
match view bdd with False | True -> maxn | Node (v, _, _) -> v
|
||||
in
|
||||
let count_sat card =
|
||||
let count =
|
||||
Mem.memo (fun count bdd ->
|
||||
match view bdd with
|
||||
@ -200,10 +193,11 @@ let count_sat maxn =
|
||||
| True ->
|
||||
1
|
||||
| Node (v, l, h) ->
|
||||
assert ((0 <= v && v < maxn) || (Printf.printf "v = %d\n" v ; false)) ;
|
||||
(count l lsl (get_var l - v - 1)) + (count h lsl (get_var h - v - 1)))
|
||||
assert (0 <= v && v < card) ;
|
||||
let count_side s = count s lsl (v - get_order s - 1) in
|
||||
count_side h + count_side l)
|
||||
in
|
||||
fun bdd -> count bdd lsl get_var bdd
|
||||
fun bdd -> count bdd lsl (card - get_order bdd - 1)
|
||||
|
||||
let any_sat =
|
||||
let rec aux assign bdd =
|
||||
|
65
test/test.ml
Normal file
65
test/test.ml
Normal file
@ -0,0 +1,65 @@
|
||||
open Bdd
|
||||
|
||||
let _ =
|
||||
let order = get_order true_bdd in
|
||||
assert (order = -1) ;
|
||||
let order = get_order false_bdd in
|
||||
assert (order = -1) ;
|
||||
let one_var = var_bdd 0 in
|
||||
assert (get_order one_var = 0) ;
|
||||
assert (to_string one_var = "0 ? (true) : (false)") ;
|
||||
assert (of_bool true == true_bdd) ;
|
||||
assert (of_bool false == false_bdd) ;
|
||||
assert (neg true_bdd == false_bdd) ;
|
||||
assert (neg false_bdd == true_bdd) ;
|
||||
assert (to_string (neg one_var) = "0 ? (false) : (true)") ;
|
||||
let bdd = conj true_bdd true_bdd in
|
||||
assert (bdd == true_bdd) ;
|
||||
let bdd = disj false_bdd true_bdd in
|
||||
assert (bdd == true_bdd) ;
|
||||
let bdd = disj false_bdd one_var in
|
||||
assert (bdd == one_var) ;
|
||||
let bdd = imp false_bdd true_bdd in
|
||||
assert (bdd == true_bdd) ;
|
||||
let bdd = imp true_bdd false_bdd in
|
||||
assert (bdd == false_bdd) ;
|
||||
let bdd = eq false_bdd true_bdd in
|
||||
assert (bdd == false_bdd) ;
|
||||
let tbl = Hashtbl.create 32 in
|
||||
Hashtbl.add tbl 0 true ;
|
||||
assert (compute tbl one_var == true_bdd) ;
|
||||
Hashtbl.replace tbl 0 false ;
|
||||
assert (compute tbl one_var == false_bdd) ;
|
||||
Hashtbl.clear tbl ;
|
||||
assert (compute tbl one_var == one_var) ;
|
||||
assert (size true_bdd = 0) ;
|
||||
assert (size false_bdd = 0) ;
|
||||
assert (size one_var = 1) ;
|
||||
let one_var' = var_bdd 1 in
|
||||
let bdd = node 2 one_var' one_var in
|
||||
assert (size bdd = 2) ;
|
||||
assert (is_sat true_bdd = true) ;
|
||||
assert (is_sat false_bdd = false) ;
|
||||
assert (is_sat bdd = true) ;
|
||||
assert (count_sat 42 false_bdd = 0) ;
|
||||
assert (count_sat 0 true_bdd = 1) ;
|
||||
assert (count_sat 1 true_bdd = 2) ;
|
||||
assert (count_sat 2 true_bdd = 4) ;
|
||||
assert (count_sat 1 one_var = 1) ;
|
||||
assert (count_sat 2 one_var' = 2) ;
|
||||
assert (count_sat 3 bdd = 4) ;
|
||||
assert (count_sat 4 bdd = 8) ;
|
||||
assert (any_sat false_bdd = None) ;
|
||||
assert (any_sat true_bdd = Some []) ;
|
||||
assert (any_sat one_var = Some [(0, true)]) ;
|
||||
( match any_sat bdd with
|
||||
| None ->
|
||||
assert false
|
||||
| Some assign ->
|
||||
assert (
|
||||
List.sort compare assign = [(0, true); (2, true)]
|
||||
|| List.sort compare assign = [(1, true); (2, false)] ) ) ;
|
||||
assert (
|
||||
List.sort compare (List.map (fun el -> List.sort compare el) (all_sat bdd))
|
||||
= [[(0, true); (2, true)]; [(1, true); (2, false)]] ) ;
|
||||
Format.printf "Tests are OK !@."
|
Loading…
x
Reference in New Issue
Block a user