diff --git a/src/bdd.ml b/src/bdd.ml index 332e942..9b30377 100644 --- a/src/bdd.ml +++ b/src/bdd.ml @@ -37,30 +37,21 @@ module Hash = struct let hash b = b.tag end -module Hash2 = struct - type t = hidden * hidden - - let equal (x1, y1) (x2, y2) = x1 == x2 && y1 == y2 - - let hash (x, y) = x.tag + (19 * y.tag) -end - let hc = Hbdd.hashcons let view x = x.node module Mem = Memo.MakeWeak (Hash) -module Mem2 = Memo.MakeWeak (Hash) let true_bdd = hc True let false_bdd = hc False let get_order bdd = - match view bdd with True | False -> max_int | Node (v, _, _) -> v + match view bdd with True | False -> -1 | Node (v, _, _) -> v let node v l h = - if v >= get_order l || v >= get_order h then invalid_arg "node" ; + if v <= get_order l || v <= get_order h then invalid_arg "node" ; if Hash.equal l h then l else hc (Node (v, l, h)) let var_bdd v = node v false_bdd true_bdd @@ -72,7 +63,14 @@ let rec fprintf fmt bdd = | False -> Format.fprintf fmt "false" | Node (v, l, h) -> - Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf l fprintf h + Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf h fprintf l + +let to_string bdd = + let buff = Buffer.create 512 in + let fmt = Format.formatter_of_buffer buff in + fprintf fmt bdd ; + Format.pp_print_flush fmt () ; + Buffer.contents buff let of_bool = function true -> true_bdd | false -> false_bdd @@ -88,55 +86,53 @@ let neg x = node var (neg low) (neg high)) x -let comb_comm op x y = - Mem2.memo - (fun comb_comm n1 n2 -> - match (view n1, view n2) with - | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 -> - node v1 (comb_comm l1 l2) (comb_comm h1 h2) - | Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 -> - node v1 (comb_comm l1 n2) (comb_comm h1 n2) - | Node (_, _, _), Node (v2, l2, h2) -> - node v2 (comb_comm n1 l2) (comb_comm n1 h2) - | True, Node (v, l, h) | Node (v, l, h), True -> - node v (comb_comm l true_bdd) (comb_comm h true_bdd) - | False, Node (v, l, h) | Node (v, l, h), False -> - node v (comb_comm l false_bdd) (comb_comm h false_bdd) - | False, False -> - of_bool (op false false) - | False, True | True, False -> - of_bool (op true false) - | True, True -> - of_bool (op true true)) - x y +(* TODO: memo 2 ? *) +let rec comb_comm op n1 n2 = + let comb_comm = comb_comm op in + match (view n1, view n2) with + | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 -> + node v1 (comb_comm l1 l2) (comb_comm h1 h2) + | Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 -> + node v1 (comb_comm l1 n2) (comb_comm h1 n2) + | Node (_, _, _), Node (v2, l2, h2) -> + node v2 (comb_comm n1 l2) (comb_comm n1 h2) + | True, Node (v, l, h) | Node (v, l, h), True -> + node v (comb_comm l true_bdd) (comb_comm h true_bdd) + | False, Node (v, l, h) | Node (v, l, h), False -> + node v (comb_comm l false_bdd) (comb_comm h false_bdd) + | False, False -> + of_bool (op false false) + | False, True | True, False -> + of_bool (op true false) + | True, True -> + of_bool (op true true) -let comb op x y = - Mem2.memo - (fun comb n1 n2 -> - match (view n1, view n2) with - | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 -> - node v1 (comb l1 l2) (comb h1 h2) - | Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 -> - node v1 (comb l1 n2) (comb h1 n2) - | Node (_, _, _), Node (v2, l2, h2) -> - node v2 (comb n1 l2) (comb n1 h2) - | True, Node (v, l, h) -> - node v (comb true_bdd l) (comb true_bdd h) - | Node (v, l, h), True -> - node v (comb l true_bdd) (comb h true_bdd) - | False, Node (v, l, h) -> - node v (comb false_bdd l) (comb false_bdd h) - | Node (v, l, h), False -> - node v (comb l false_bdd) (comb h false_bdd) - | False, False -> - of_bool (op false false) - | False, True -> - of_bool (op false true) - | True, False -> - of_bool (op true false) - | True, True -> - of_bool (op true true)) - x y +(* TODO: memo2 ? *) +let rec comb op n1 n2 = + let comb = comb op in + match (view n1, view n2) with + | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 -> + node v1 (comb l1 l2) (comb h1 h2) + | Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 -> + node v1 (comb l1 n2) (comb h1 n2) + | Node (_, _, _), Node (v2, l2, h2) -> + node v2 (comb n1 l2) (comb n1 h2) + | True, Node (v, l, h) -> + node v (comb true_bdd l) (comb true_bdd h) + | Node (v, l, h), True -> + node v (comb l true_bdd) (comb h true_bdd) + | False, Node (v, l, h) -> + node v (comb false_bdd l) (comb false_bdd h) + | Node (v, l, h), False -> + node v (comb l false_bdd) (comb h false_bdd) + | False, False -> + of_bool (op false false) + | False, True -> + of_bool (op false true) + | True, False -> + of_bool (op true false) + | True, True -> + of_bool (op true true) let conj = comb_comm (fun x y -> x && y) @@ -156,7 +152,7 @@ let compute tbl = | Node (v, l, h) -> ( match Hashtbl.find tbl v with | exception Not_found -> - node v (compute_aux h) (compute_aux l) + node v (compute_aux l) (compute_aux h) | true -> compute_aux h | false -> @@ -188,10 +184,7 @@ let size = let is_sat bdd = match view bdd with False -> false | _ -> true -let count_sat maxn = - let get_var bdd = - match view bdd with False | True -> maxn | Node (v, _, _) -> v - in +let count_sat card = let count = Mem.memo (fun count bdd -> match view bdd with @@ -200,10 +193,11 @@ let count_sat maxn = | True -> 1 | Node (v, l, h) -> - assert ((0 <= v && v < maxn) || (Printf.printf "v = %d\n" v ; false)) ; - (count l lsl (get_var l - v - 1)) + (count h lsl (get_var h - v - 1))) + assert (0 <= v && v < card) ; + let count_side s = count s lsl (v - get_order s - 1) in + count_side h + count_side l) in - fun bdd -> count bdd lsl get_var bdd + fun bdd -> count bdd lsl (card - get_order bdd - 1) let any_sat = let rec aux assign bdd = diff --git a/test/dune b/test/dune new file mode 100644 index 0000000..4f1ca19 --- /dev/null +++ b/test/dune @@ -0,0 +1,3 @@ +(test + (name test) + (libraries bdd)) diff --git a/test/test.ml b/test/test.ml new file mode 100644 index 0000000..299d13f --- /dev/null +++ b/test/test.ml @@ -0,0 +1,65 @@ +open Bdd + +let _ = + let order = get_order true_bdd in + assert (order = -1) ; + let order = get_order false_bdd in + assert (order = -1) ; + let one_var = var_bdd 0 in + assert (get_order one_var = 0) ; + assert (to_string one_var = "0 ? (true) : (false)") ; + assert (of_bool true == true_bdd) ; + assert (of_bool false == false_bdd) ; + assert (neg true_bdd == false_bdd) ; + assert (neg false_bdd == true_bdd) ; + assert (to_string (neg one_var) = "0 ? (false) : (true)") ; + let bdd = conj true_bdd true_bdd in + assert (bdd == true_bdd) ; + let bdd = disj false_bdd true_bdd in + assert (bdd == true_bdd) ; + let bdd = disj false_bdd one_var in + assert (bdd == one_var) ; + let bdd = imp false_bdd true_bdd in + assert (bdd == true_bdd) ; + let bdd = imp true_bdd false_bdd in + assert (bdd == false_bdd) ; + let bdd = eq false_bdd true_bdd in + assert (bdd == false_bdd) ; + let tbl = Hashtbl.create 32 in + Hashtbl.add tbl 0 true ; + assert (compute tbl one_var == true_bdd) ; + Hashtbl.replace tbl 0 false ; + assert (compute tbl one_var == false_bdd) ; + Hashtbl.clear tbl ; + assert (compute tbl one_var == one_var) ; + assert (size true_bdd = 0) ; + assert (size false_bdd = 0) ; + assert (size one_var = 1) ; + let one_var' = var_bdd 1 in + let bdd = node 2 one_var' one_var in + assert (size bdd = 2) ; + assert (is_sat true_bdd = true) ; + assert (is_sat false_bdd = false) ; + assert (is_sat bdd = true) ; + assert (count_sat 42 false_bdd = 0) ; + assert (count_sat 0 true_bdd = 1) ; + assert (count_sat 1 true_bdd = 2) ; + assert (count_sat 2 true_bdd = 4) ; + assert (count_sat 1 one_var = 1) ; + assert (count_sat 2 one_var' = 2) ; + assert (count_sat 3 bdd = 4) ; + assert (count_sat 4 bdd = 8) ; + assert (any_sat false_bdd = None) ; + assert (any_sat true_bdd = Some []) ; + assert (any_sat one_var = Some [(0, true)]) ; + ( match any_sat bdd with + | None -> + assert false + | Some assign -> + assert ( + List.sort compare assign = [(0, true); (2, true)] + || List.sort compare assign = [(1, true); (2, false)] ) ) ; + assert ( + List.sort compare (List.map (fun el -> List.sort compare el) (all_sat bdd)) + = [[(0, true); (2, true)]; [(1, true); (2, false)]] ) ; + Format.printf "Tests are OK !@."