Browse Source

add tests, fix some algorithms

master
zapashcanon 1 year ago
parent
commit
10eaeefb54
Signed by: zapashcanon GPG Key ID: 8981C3C62D1D28F1
3 changed files with 131 additions and 69 deletions
  1. +63
    -69
      src/bdd.ml
  2. +3
    -0
      test/dune
  3. +65
    -0
      test/test.ml

+ 63
- 69
src/bdd.ml View File

@ -37,30 +37,21 @@ module Hash = struct
let hash b = b.tag
end
module Hash2 = struct
type t = hidden * hidden
let equal (x1, y1) (x2, y2) = x1 == x2 && y1 == y2
let hash (x, y) = x.tag + (19 * y.tag)
end
let hc = Hbdd.hashcons
let view x = x.node
module Mem = Memo.MakeWeak (Hash)
module Mem2 = Memo.MakeWeak (Hash)
let true_bdd = hc True
let false_bdd = hc False
let get_order bdd =
match view bdd with True | False -> max_int | Node (v, _, _) -> v
match view bdd with True | False -> -1 | Node (v, _, _) -> v
let node v l h =
if v >= get_order l || v >= get_order h then invalid_arg "node" ;
if v <= get_order l || v <= get_order h then invalid_arg "node" ;
if Hash.equal l h then l else hc (Node (v, l, h))
let var_bdd v = node v false_bdd true_bdd
@ -72,7 +63,14 @@ let rec fprintf fmt bdd =
| False ->
Format.fprintf fmt "false"
| Node (v, l, h) ->
Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf l fprintf h
Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf h fprintf l
let to_string bdd =
let buff = Buffer.create 512 in
let fmt = Format.formatter_of_buffer buff in
fprintf fmt bdd ;
Format.pp_print_flush fmt () ;
Buffer.contents buff
let of_bool = function true -> true_bdd | false -> false_bdd
@ -88,55 +86,53 @@ let neg x =
node var (neg low) (neg high))
x
let comb_comm op x y =
Mem2.memo
(fun comb_comm n1 n2 ->
match (view n1, view n2) with
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
node v1 (comb_comm l1 l2) (comb_comm h1 h2)
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
node v1 (comb_comm l1 n2) (comb_comm h1 n2)
| Node (_, _, _), Node (v2, l2, h2) ->
node v2 (comb_comm n1 l2) (comb_comm n1 h2)
| True, Node (v, l, h) | Node (v, l, h), True ->
node v (comb_comm l true_bdd) (comb_comm h true_bdd)
| False, Node (v, l, h) | Node (v, l, h), False ->
node v (comb_comm l false_bdd) (comb_comm h false_bdd)
| False, False ->
of_bool (op false false)
| False, True | True, False ->
of_bool (op true false)
| True, True ->
of_bool (op true true))
x y
let comb op x y =
Mem2.memo
(fun comb n1 n2 ->
match (view n1, view n2) with
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
node v1 (comb l1 l2) (comb h1 h2)
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
node v1 (comb l1 n2) (comb h1 n2)
| Node (_, _, _), Node (v2, l2, h2) ->
node v2 (comb n1 l2) (comb n1 h2)
| True, Node (v, l, h) ->
node v (comb true_bdd l) (comb true_bdd h)
| Node (v, l, h), True ->
node v (comb l true_bdd) (comb h true_bdd)
| False, Node (v, l, h) ->
node v (comb false_bdd l) (comb false_bdd h)
| Node (v, l, h), False ->
node v (comb l false_bdd) (comb h false_bdd)
| False, False ->
of_bool (op false false)
| False, True ->
of_bool (op false true)
| True, False ->
of_bool (op true false)
| True, True ->
of_bool (op true true))
x y
(* TODO: memo 2 ? *)
let rec comb_comm op n1 n2 =
let comb_comm = comb_comm op in
match (view n1, view n2) with
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
node v1 (comb_comm l1 l2) (comb_comm h1 h2)
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
node v1 (comb_comm l1 n2) (comb_comm h1 n2)
| Node (_, _, _), Node (v2, l2, h2) ->
node v2 (comb_comm n1 l2) (comb_comm n1 h2)
| True, Node (v, l, h) | Node (v, l, h), True ->
node v (comb_comm l true_bdd) (comb_comm h true_bdd)
| False, Node (v, l, h) | Node (v, l, h), False ->
node v (comb_comm l false_bdd) (comb_comm h false_bdd)
| False, False ->
of_bool (op false false)
| False, True | True, False ->
of_bool (op true false)
| True, True ->
of_bool (op true true)
(* TODO: memo2 ? *)
let rec comb op n1 n2 =
let comb = comb op in
match (view n1, view n2) with
| Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
node v1 (comb l1 l2) (comb h1 h2)
| Node (v1, l1, h1), Node (v2, _, _) when v1 < v2 ->
node v1 (comb l1 n2) (comb h1 n2)
| Node (_, _, _), Node (v2, l2, h2) ->
node v2 (comb n1 l2) (comb n1 h2)
| True, Node (v, l, h) ->
node v (comb true_bdd l) (comb true_bdd h)
| Node (v, l, h), True ->
node v (comb l true_bdd) (comb h true_bdd)
| False, Node (v, l, h) ->
node v (comb false_bdd l) (comb false_bdd h)
| Node (v, l, h), False ->
node v (comb l false_bdd) (comb h false_bdd)
| False, False ->
of_bool (op false false)
| False, True ->
of_bool (op false true)
| True, False ->
of_bool (op true false)
| True, True ->
of_bool (op true true)
let conj = comb_comm (fun x y -> x && y)
@ -156,7 +152,7 @@ let compute tbl =
| Node (v, l, h) -> (
match Hashtbl.find tbl v with
| exception Not_found ->
node v (compute_aux h) (compute_aux l)
node v (compute_aux l) (compute_aux h)
| true ->
compute_aux h
| false ->
@ -188,10 +184,7 @@ let size =
let is_sat bdd = match view bdd with False -> false | _ -> true
let count_sat maxn =
let get_var bdd =
match view bdd with False | True -> maxn | Node (v, _, _) -> v
in
let count_sat card =
let count =
Mem.memo (fun count bdd ->
match view bdd with
@ -200,10 +193,11 @@ let count_sat maxn =
| True ->
1
| Node (v, l, h) ->
assert ((0 <= v && v < maxn) || (Printf.printf "v = %d\n" v ; false)) ;
(count l lsl (get_var l - v - 1)) + (count h lsl (get_var h - v - 1)))
assert (0 <= v && v < card) ;
let count_side s = count s lsl (v - get_order s - 1) in
count_side h + count_side l)
in
fun bdd -> count bdd lsl get_var bdd
fun bdd -> count bdd lsl (card - get_order bdd - 1)
let any_sat =
let rec aux assign bdd =


+ 3
- 0
test/dune View File

@ -0,0 +1,3 @@
(test
(name test)
(libraries bdd))

+ 65
- 0
test/test.ml View File

@ -0,0 +1,65 @@
open Bdd
let _ =
let order = get_order true_bdd in
assert (order = -1) ;
let order = get_order false_bdd in
assert (order = -1) ;
let one_var = var_bdd 0 in
assert (get_order one_var = 0) ;
assert (to_string one_var = "0 ? (true) : (false)") ;
assert (of_bool true == true_bdd) ;
assert (of_bool false == false_bdd) ;
assert (neg true_bdd == false_bdd) ;
assert (neg false_bdd == true_bdd) ;
assert (to_string (neg one_var) = "0 ? (false) : (true)") ;
let bdd = conj true_bdd true_bdd in
assert (bdd == true_bdd) ;
let bdd = disj false_bdd true_bdd in
assert (bdd == true_bdd) ;
let bdd = disj false_bdd one_var in
assert (bdd == one_var) ;
let bdd = imp false_bdd true_bdd in
assert (bdd == true_bdd) ;
let bdd = imp true_bdd false_bdd in
assert (bdd == false_bdd) ;
let bdd = eq false_bdd true_bdd in
assert (bdd == false_bdd) ;
let tbl = Hashtbl.create 32 in
Hashtbl.add tbl 0 true ;
assert (compute tbl one_var == true_bdd) ;
Hashtbl.replace tbl 0 false ;
assert (compute tbl one_var == false_bdd) ;
Hashtbl.clear tbl ;
assert (compute tbl one_var == one_var) ;
assert (size true_bdd = 0) ;
assert (size false_bdd = 0) ;
assert (size one_var = 1) ;
let one_var' = var_bdd 1 in
let bdd = node 2 one_var' one_var in
assert (size bdd = 2) ;
assert (is_sat true_bdd = true) ;
assert (is_sat false_bdd = false) ;
assert (is_sat bdd = true) ;
assert (count_sat 42 false_bdd = 0) ;
assert (count_sat 0 true_bdd = 1) ;
assert (count_sat 1 true_bdd = 2) ;
assert (count_sat 2 true_bdd = 4) ;
assert (count_sat 1 one_var = 1) ;
assert (count_sat 2 one_var' = 2) ;
assert (count_sat 3 bdd = 4) ;
assert (count_sat 4 bdd = 8) ;
assert (any_sat false_bdd = None) ;
assert (any_sat true_bdd = Some []) ;
assert (any_sat one_var = Some [(0, true)]) ;
( match any_sat bdd with
| None ->
assert false
| Some assign ->
assert (
List.sort compare assign = [(0, true); (2, true)]
|| List.sort compare assign = [(1, true); (2, false)] ) ) ;
assert (
List.sort compare (List.map (fun el -> List.sort compare el) (all_sat bdd))
= [[(0, true); (2, true)]; [(1, true); (2, false)]] ) ;
Format.printf "Tests are OK !@."

Loading…
Cancel
Save