libqbe/tools/mgen/fuzz.ml

414 lines
12 KiB
OCaml
Raw Normal View History

(* fuzz the tables and matchers generated *)
open Match
(* Minimal growable array.  Note that it shadows the
 * standard library's [Buffer] for the rest of the file. *)
module Buffer: sig
  type 'a t
  (* [create ?capacity ()] is an empty buffer; [capacity]
   * (default 10) is the number of slots reserved by the
   * first push.
   * Raises [Invalid_argument] if [capacity] is negative. *)
  val create: ?capacity:int -> unit -> 'a t
  (* forget all the elements; the storage is kept *)
  val reset: 'a t -> unit
  (* number of elements currently stored *)
  val size: 'a t -> int
  (* [get b n] is the [n]-th element (0-based).
   * Raises [Invalid_argument] unless 0 <= n < size b. *)
  val get: 'a t -> int -> 'a
  (* [set b n x] overwrites the [n]-th element with [x].
   * Raises [Invalid_argument] unless 0 <= n < size b. *)
  val set: 'a t -> int -> 'a -> unit
  (* [push b x] appends [x], growing the storage if needed *)
  val push: 'a t -> 'a -> unit
end = struct
  type 'a t =
    { mutable size: int
    ; mutable data: 'a array
    ; capacity: int }

  (* The storage is allocated lazily, on the first push,
   * so that its spare slots can be filled with a real
   * element instead of an [Obj.magic] placeholder. *)
  let create ?(capacity = 10) () =
    if capacity < 0 then invalid_arg "Buffer.create";
    {size = 0; data = [||]; capacity}

  let reset b = b.size <- 0

  let size b = b.size

  let get b n =
    (* check both bounds here so that the error always
     * names this module rather than the array primitive *)
    if n < 0 || n >= b.size then invalid_arg "Buffer.get";
    b.data.(n)

  let set b n x =
    if n < 0 || n >= b.size then invalid_arg "Buffer.set";
    b.data.(n) <- x

  let push b x =
    let cap = Array.length b.data in
    if b.size = cap then begin
      (* first allocation honours the requested capacity,
       * later ones roughly double the storage; spare
       * slots are filled with [x] *)
      let data = Array.make (max b.capacity (2 * cap + 1)) x in
      Array.blit b.data 0 data 0 cap;
      b.data <- data
    end;
    b.data.(b.size) <- x;
    b.size <- b.size + 1
end
(* State of the binary node [op] applied to operands in
 * states [s1] and [s2], as recorded in the state map of
 * the numberer [n].  When the map has no entry for that
 * combination, the state of the atom [Tmp] is returned. *)
let binop_state n op s1 s2 =
  match StateMap.find (K (op, s1, s2)) n.statemap with
  | st -> st
  | exception Not_found -> atom_state n Tmp
(* position of a term in a term array or in a term pool *)
type id = int
(* shape of a term: a binary operation whose operands
 * are referenced by id, or an atomic leaf *)
type term_data =
| Binop of op * id * id
| Leaf of atomic_pattern
(* a term: its own id, its shape, and the state that
 * was computed for it when it was created *)
type term =
{ id: id
; data: term_data
; state: p state }
(* Pretty-print on [fmt] the term rooted at index [id] of
 * the term array [ta].  Constant leaves print their
 * value, the other leaves print their id behind a sigil;
 * binary nodes print the operator, the node id and its
 * state id, followed by both operands. *)
let pp_term fmt (ta, id) =
  let rec pp out id =
    let node = ta.(id) in
    match node.data with
    | Leaf (Con c) -> Format.fprintf out "%Ld" c
    | Leaf AnyCon -> Format.fprintf out "$%d" id
    | Leaf Tmp -> Format.fprintf out "%%%d" id
    | Binop (op, lhs, rhs) ->
      Format.fprintf out "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
        (show_op op) id node.state.id
        pp lhs pp rhs
  in
  pp fmt id
(* A term pool is a deduplicated (hash-consed) set of
 * terms; each term it stores carries the state given
 * by the numberer passed at creation time *)
module TermPool = struct
  type t =
    { terms: term Buffer.t
    ; hcons: (term_data, id) Hashtbl.t
    ; numbr: numberer }

  (* empty pool numbering its terms with [numbr] *)
  let create numbr =
    let terms = Buffer.create () in
    let hcons = Hashtbl.create 100 in
    {terms; hcons; numbr}

  (* drop every term; the numberer is kept *)
  let reset tp =
    Buffer.reset tp.terms;
    Hashtbl.clear tp.hcons

  (* number of distinct terms in the pool *)
  let size tp = Buffer.size tp.terms

  (* term stored under [id] *)
  let term tp id = Buffer.get tp.terms id

  (* Term whose shape is [data]: the one already in the
   * pool if there is one, otherwise a fresh term that is
   * appended with the state returned by [mk_state ()]. *)
  let intern tp data mk_state =
    match Hashtbl.find_opt tp.hcons data with
    | Some id -> term tp id
    | None ->
      let id = Buffer.size tp.terms in
      let state = mk_state () in
      Buffer.push tp.terms {id; data; state};
      Hashtbl.add tp.hcons data id;
      term tp id

  (* leaf term for the atomic pattern [atm] *)
  let mk_leaf tp atm =
    intern tp (Leaf atm)
      (fun () -> atom_state tp.numbr atm)

  (* binary term [op] over the pooled terms [t1], [t2] *)
  let mk_binop tp op t1 t2 =
    intern tp (Binop (op, t1.id, t2.id))
      (fun () -> binop_state tp.numbr op t1.state t2.state)

  (* Add a whole pattern to the pool, left operand first,
   * and return the term for its root.  Variables are
   * added as the atom they range over. *)
  let rec add_pattern tp pat =
    match pat with
    | Atm atm | Var (_, atm) -> mk_leaf tp atm
    | Bnr (op, lhs, rhs) ->
      let tl = add_pattern tp lhs in
      let tr = add_pattern tp rhs in
      mk_binop tp op tl tr

  (* Unshare the term [id] into a standalone array where
   * every occurrence of a sub-term has its own slot.
   * Slots are numbered in post-order (left operand, right
   * operand, node), so each term's [id] is its index in
   * the array.  Returns the array and the root's index. *)
  let explode_term tp id =
    let rec walk acc next id =
      let t = term tp id in
      match t.data with
      | Leaf _ -> (next, {t with id = next} :: acc)
      | Binop (op, l, r) ->
        let lroot, acc = walk acc next l in
        let rroot, acc = walk acc (lroot + 1) r in
        let self = rroot + 1 in
        let node =
          {t with id = self; data = Binop (op, lroot, rroot)}
        in
        (self, node :: acc)
    in
    let root, rev_terms = walk [] 0 id in
    (Array.of_list (List.rev rev_terms), root)
end
module R = Random
(* Pick one element of a non-empty list uniformly at
 * random in a single pass: the k-th element replaces the
 * current choice with probability 1/k.
 * Raises [Invalid_argument] on the empty list. *)
let list_pick l =
  match l with
  | [] -> invalid_arg "list_pick"
  | first :: rest ->
    let _, chosen =
      List.fold_left (fun (seen, keep) cand ->
          let keep =
            if R.int (seen + 1) = 0 then cand else keep
          in
          (seen + 1, keep))
        (1, first) rest
    in
    chosen
(* [term_pick ~numbr ~depth] generates a random pattern
 * and returns it paired with the state the numberer
 * [numbr] assigns to it.  The set of operators is read
 * from [numbr.ops]; when that list is empty it is first
 * rebuilt from the keys of the state map and stored
 * back into [numbr.ops]. *)
let term_pick ~numbr =
  let ops =
    match numbr.ops with
    | [] ->
      let from_map =
        StateMap.fold (fun k _ acc ->
            match k with
            | K (op, _, _) -> op :: acc)
          numbr.statemap []
      in
      numbr.ops <- setify from_map;
      numbr.ops
    | cached -> cached
  in
  let rec gen depth =
    (* a leaf is chosen with probability 0.75^depth: it
     * is certain once the depth budget reaches 0 and
     * gets less likely the more budget remains *)
    let leaf_prob = 0.75 ** float_of_int depth in
    let roll = R.float 1.0 in
    if roll <= leaf_prob || ops = [] then begin
      let atom, st = list_pick numbr.atoms in
      (st, Atm atom)
    end else begin
      let op = list_pick ops in
      let s1, t1 = gen (depth - 1) in
      let s2, t2 = gen (depth - 1) in
      let st = binop_state numbr op s1 s2 in
      (st, Bnr (op, t1, t2))
    end
  in
  fun ~depth -> gen depth
(* raised by the fuzzing passes of this file right after
 * they have printed a diagnostic for a discrepancy *)
exception FuzzError
(* Height of a pattern: atoms and variables have depth 0,
 * a binary node is one deeper than its deepest operand. *)
let rec pattern_depth p =
  match p with
  | Atm _ | Var _ -> 0
  | Bnr (_, lhs, rhs) ->
    let dl = pattern_depth lhs in
    let dr = pattern_depth rhs in
    1 + max dl dr
(* [a %% b] is the ratio of the integers [a] and [b]
 * expressed as a floating-point percentage *)
let ( %% ) a b =
  let num = float_of_int a
  and den = float_of_int b in
  (1e2 *. num) /. den
(* [progress ?width msg args... pct-first] redraws a
 * one-line progress report on stderr: the current line
 * is erased, the message [msg] is printed with its
 * format arguments, then a bar of [width] cells
 * (default 50) filled according to the percentage
 * [pct], then [pct] itself.  The format arguments of
 * [msg] are passed after [pct]. *)
let progress ?(width = 50) msg pct =
  (* ANSI "erase line", then back to column 0 *)
  Format.eprintf "\x1b[2K\r%!";
  let progress_bar fmt =
    let filled =
      let fwidth = float_of_int width in
      let n = 1 + int_of_float (pct *. fwidth /. 1e2) in
      (* keep the bar exactly [width] cells wide, even
       * for percentages outside of 0..100 *)
      max 0 (min width n)
    in
    (* the filled cells must use a visible glyph: an
     * empty string here leaves that part of the bar
     * blank *)
    Format.fprintf fmt " %s%s %.0f%%@?"
      (String.make filled '#')
      (String.make (max 0 (width - filled)) '-')
      pct
  in
  (* the bar is drawn by the continuation, once the
   * message has been formatted *)
  Format.kfprintf progress_bar
    Format.err_formatter msg
(* Fuzz the numberer [numbr] against [rules]: random
 * terms are generated and, for each of them, the rule
 * names found in the state computed by the numberer
 * must be exactly the names of the rules whose pattern
 * matches the term.  On a mismatch a diagnostic is
 * printed and [FuzzError] is raised.  Terms that match
 * at least one rule are collected in a term pool, which
 * is returned. *)
let fuzz_numberer rules numbr =
  (* generate terms up to twice as deep as the deepest
   * rule pattern, to have a chance to hit non-trivial
   * numbers for the atomic patterns of the rules *)
  let depth =
    let deepest =
      List.fold_left
        (fun acc r -> max acc (pattern_depth r.pattern))
        0 rules
    in
    deepest * 2
  in
  (* stop when the pool no longer grows fast enough, or
   * after sufficiently many iterations *)
  let max_iter = 1_000_000 in
  let low_insert_rate = 1e-2 in
  let tp = TermPool.create numbr in
  (* print the diagnostic for a mismatch and bail out *)
  let report_mismatch term state_matched rule_matched =
    let open Format in
    let pp_str_list =
      let pp_sep fmt () = fprintf fmt ",@ " in
      pp_print_list ~pp_sep pp_print_string
    in
    eprintf "@.@[<v2>fuzz error for %s"
      (show_pattern term);
    eprintf "@ @[state matched: %a@]"
      pp_str_list state_matched;
    eprintf "@ @[rule matched: %a@]"
      pp_str_list rule_matched;
    eprintf "@]@.";
    raise FuzzError
  in
  (* [stats] is (num, cnt, rate): [num] counts the pool
   * insertion attempts (plus one per refresh), [cnt] the
   * attempts that grew the pool since the last refresh,
   * and [rate] is the smoothed insert rate *)
  let rec loop (num, cnt, rate as stats) i =
    if rate <= low_insert_rate || i >= max_iter then ()
    else begin
      (* periodically refresh the insert rate *)
      let (num, cnt, rate) =
        if num land 1023 <> 0 then stats
        else begin
          let rate =
            0.5 *. (rate +. float_of_int cnt /. 1023.)
          in
          progress " insert_rate=%.1f%%"
            (i %% max_iter) (rate *. 1e2);
          (num + 1, 0, rate)
        end
      in
      (* draw a term and compare what its state claims
       * with what the rules actually match *)
      let st, term = term_pick ~numbr ~depth in
      let state_matched =
        List.filter_map (function
            | Top ("$" | "%") -> None
            | Top name -> Some name
            | _ -> None)
          st.point |> setify
      in
      let rule_matched =
        List.filter_map (fun r ->
            if pattern_match r.pattern term
            then Some r.name
            else None)
          rules |> setify
      in
      if state_matched <> rule_matched then
        report_mismatch term state_matched rule_matched;
      let stats =
        if state_matched = [] then (num, cnt, rate)
        else begin
          (* add to the term pool and note whether the
           * pool actually grew *)
          let old_size = TermPool.size tp in
          let _ = TermPool.add_pattern tp term in
          let grew = TermPool.size tp <> old_size in
          (num + 1, (if grew then cnt + 1 else cnt), rate)
        end
      in
      loop stats (i + 1)
    end
  in
  loop (1, 0, 1.0) 0;
  let words = Obj.reachable_words (Obj.repr tp) in
  Format.eprintf
    "@.@[ generated %.3fMiB of test terms@]@."
    (float_of_int words /. 128. /. 1024.);
  tp
(* Interpret the matcher [m] on the term rooted at index
 * [id] of the term array [ta]; [stk] is the stack of
 * term indices set aside by [Push] actions.  Returns the
 * list of (variable, term index) pairs recorded by the
 * [Set] actions, in the order they were executed.
 * Raises [Failure] when the matcher cannot proceed. *)
let rec run_matcher stk m (ta, id as t) =
  let state_of i = ta.(i).state.id in
  match m.Action.node with
  | Action.Stop -> []
  | Action.Set (v, next) ->
    let binding = (v, id) in
    binding :: run_matcher stk next t
  | Action.Pop next ->
    (match stk with
     | [] -> failwith "pop on empty stack"
     | top :: rest -> run_matcher rest next (ta, top))
  | Action.Push (sym, next) ->
    (match ta.(id).data with
     | Leaf _ -> failwith "push on leaf"
     | Binop (_, l, r) ->
       (* for a symmetric push, continue with the operand
        * that has the smaller state id when the left
        * one is strictly larger *)
       let swap = sym && state_of l > state_of r in
       if swap
       then run_matcher (l :: stk) next (ta, r)
       else run_matcher (r :: stk) next (ta, l))
  | Action.Switch cases ->
    (* dispatch on the state id of the current term *)
    (match List.assoc (state_of id) cases with
     | next -> run_matcher stk next t
     | exception Not_found -> failwith "no switch case")
(* Reference matcher: match the pattern [p] against the
 * term rooted at index [id] of the term array [ta].
 * Returns [Some bindings], the (variable, term index)
 * pairs for the variables of [p] from left to right, or
 * [None] when the term does not match.  In atom
 * position a binary term is treated like [Tmp]. *)
let rec term_match p (ta, id) =
  let node = ta.(id).data in
  let atom_ok a =
    let actual =
      match node with
      | Leaf a' -> a'
      | Binop _ -> Tmp
    in
    pattern_match (Atm a) (Atm actual)
  in
  match p with
  | Atm a -> if atom_ok a then Some [] else None
  | Var (v, a) -> if atom_ok a then Some [(v, id)] else None
  | Bnr (op, pl, pr) ->
    (match node with
     | Binop (op', idl, idr) when op' = op ->
       (match term_match pl (ta, idl) with
        | None -> None
        | Some l1 ->
          (match term_match pr (ta, idr) with
           | None -> None
           | Some l2 -> Some (l1 @ l2)))
     | Binop _ | Leaf _ -> None)
(* Check the matchers built for [rules] on every term of
 * the pool [tp].  For each term, all the matchers
 * registered for the state of the term are run; the
 * assignment a matcher returns must be equal, up to
 * order, to the one [term_match] computes for at least
 * one pattern of the matcher's rule.  Progress and
 * pattern coverage are reported with [progress]; on the
 * first discrepancy a diagnostic is printed and
 * [FuzzError] is raised. *)
let test_matchers tp numbr rules =
let {statemap = sm; states = sa; _} = numbr in
(* number of patterns summed over the groups below;
 * used as the denominator of the coverage figure *)
let total = ref 0 in
(* multi-map from state id to (patterns, matcher): the
 * pair of a rule name [r] is registered for each state
 * whose point contains [Top r].
 * NOTE(review): [group_by_fst] presumably gathers the
 * patterns sharing a rule name, and [lr_matcher]
 * presumably builds the matcher for rule [r] -- both
 * are defined outside of this file, confirm there *)
let matchers =
let htbl = Hashtbl.create (Array.length sa) in
List.map (fun r -> (r.name, r.pattern)) rules |>
group_by_fst |>
List.iter (fun (r, ps) ->
total := !total + List.length ps;
let pm = (ps, lr_matcher sm sa rules r) in
sa |> Array.iter (fun s ->
if List.mem (Top r) s.point then
Hashtbl.add htbl s.id pm));
htbl
in
(* patterns for which a matcher run was confirmed by
 * [term_match] at least once *)
let seen = Hashtbl.create !total in
for id = 0 to TermPool.size tp - 1 do
(* refresh the progress line every 1024 terms and on
 * the last term *)
if id land 1023 = 0 ||
id = TermPool.size tp - 1 then begin
progress
" coverage=%.1f%%"
(id %% TermPool.size tp)
(Hashtbl.length seen %% !total)
end;
(* standalone (term array, root index) copy of the term *)
let t = TermPool.explode_term tp id in
Hashtbl.find_all matchers
(TermPool.term tp id).state.id |>
List.iter (fun (ps, m) ->
(* assignments are compared sorted, so the order in
 * which bindings were produced does not matter *)
let norm = List.fast_sort compare in
(* `Match b: the matcher ran and [b] tells whether some
 * pattern of [ps] yields the same assignment;
 * `RunFailure e: the matcher raised [e] *)
let ok =
match norm (run_matcher [] m t) with
| asn -> `Match (List.exists (fun p ->
match term_match p t with
| None -> false
| Some asn' ->
if asn = norm asn' then begin
Hashtbl.replace seen p ();
true
end else false) ps)
| exception e -> `RunFailure e
in
if ok <> `Match true then begin
let open Format in
(* printer for an assignment, as "var←index" items *)
let pp_asn fmt asn =
fprintf fmt "@[<h>";
pp_print_list
~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
(fun fmt (v, d) ->
fprintf fmt "@[%s←%d@]" v d)
fmt asn;
fprintf fmt "@]"
in
eprintf "@.@[<v2>matcher error for";
eprintf "@ @[%a@]" pp_term t;
begin match ok with
| `RunFailure e ->
eprintf "@ @[exception: %s@]"
(Printexc.to_string e)
| `Match (* false *) _ ->
(* run the matcher again to show the (unsorted)
 * assignment it produced *)
let asn = run_matcher [] m t in
eprintf "@ @[assignment: %a@]"
pp_asn asn;
eprintf "@ @[<v2>could not match";
List.iter (fun p ->
eprintf "@ + @[%s@]"
(show_pattern p)) ps;
eprintf "@]"
end;
eprintf "@]@.";
raise FuzzError
end)
done;
Format.eprintf "@."