mgen: match automatons and C generation
The algorithm to generate matchers took a long time to be discovered and refined to its present version. The rest of mgen is mostly boring engineering. Extensive fuzzing ensures that the two core components of mgen (tables and matchers generation) are correct on specific problem instances.
This commit is contained in:
parent
56e2263ca4
commit
a609527752
11 changed files with 2144 additions and 444 deletions
347
tools/match.ml
347
tools/match.ml
|
@ -1,347 +0,0 @@
|
|||
type cls = Kw | Kl | Ks | Kd
|
||||
type op_base =
|
||||
| Oadd
|
||||
| Osub
|
||||
| Omul
|
||||
type op = cls * op_base
|
||||
|
||||
let commutative = function
|
||||
| (_, (Oadd | Omul)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
let associative = function
|
||||
| (_, (Oadd | Omul)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
type atomic_pattern =
|
||||
| Tmp
|
||||
| AnyCon
|
||||
| Con of int64
|
||||
|
||||
type pattern =
|
||||
| Bnr of op * pattern * pattern
|
||||
| Atm of atomic_pattern
|
||||
| Var of string * atomic_pattern
|
||||
|
||||
let show_op (k, o) =
|
||||
(match o with
|
||||
| Oadd -> "add"
|
||||
| Osub -> "sub"
|
||||
| Omul -> "mul") ^
|
||||
(match k with
|
||||
| Kw -> "w"
|
||||
| Kl -> "l"
|
||||
| Ks -> "s"
|
||||
| Kd -> "d")
|
||||
|
||||
let rec show_pattern p =
|
||||
match p with
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm Tmp -> "%"
|
||||
| Atm AnyCon -> "$"
|
||||
| Atm (Con n) -> Int64.to_string n
|
||||
| Bnr (o, pl, pr) ->
|
||||
"(" ^ show_op o ^
|
||||
" " ^ show_pattern pl ^
|
||||
" " ^ show_pattern pr ^ ")"
|
||||
|
||||
let rec pattern_match p w =
|
||||
match p with
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm Tmp ->
|
||||
begin match w with
|
||||
| Atm (Con _ | AnyCon) -> false
|
||||
| _ -> true
|
||||
end
|
||||
| Atm (Con _) -> w = p
|
||||
| Atm (AnyCon) ->
|
||||
not (pattern_match (Atm Tmp) w)
|
||||
| Bnr (o, pl, pr) ->
|
||||
begin match w with
|
||||
| Bnr (o', wl, wr) ->
|
||||
o' = o &&
|
||||
pattern_match pl wl &&
|
||||
pattern_match pr wr
|
||||
| _ -> false
|
||||
end
|
||||
|
||||
type 'a cursor = (* a position inside a pattern *)
|
||||
| Bnrl of op * 'a cursor * pattern
|
||||
| Bnrr of op * pattern * 'a cursor
|
||||
| Top of 'a
|
||||
|
||||
let rec fold_cursor c p =
|
||||
match c with
|
||||
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
|
||||
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
|
||||
| Top _ -> p
|
||||
|
||||
let peel p x =
|
||||
let once out (p, c) =
|
||||
match p with
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm _ -> (p, c) :: out
|
||||
| Bnr (o, pl, pr) ->
|
||||
(pl, Bnrl (o, c, pr)) ::
|
||||
(pr, Bnrr (o, pl, c)) :: out
|
||||
in
|
||||
let rec go l =
|
||||
let l' = List.fold_left once [] l in
|
||||
if List.length l' = List.length l
|
||||
then l
|
||||
else go l'
|
||||
in go [(p, Top x)]
|
||||
|
||||
let fold_pairs l1 l2 ini f =
|
||||
let rec go acc = function
|
||||
| [] -> acc
|
||||
| a :: l1' ->
|
||||
go (List.fold_left
|
||||
(fun acc b -> f (a, b) acc)
|
||||
acc l2) l1'
|
||||
in go ini l1
|
||||
|
||||
let iter_pairs l f =
|
||||
fold_pairs l l () (fun x () -> f x)
|
||||
|
||||
type 'a state =
|
||||
{ id: int
|
||||
; seen: pattern
|
||||
; point: ('a cursor) list }
|
||||
|
||||
let rec binops side {point; _} =
|
||||
List.filter_map (fun c ->
|
||||
match c, side with
|
||||
| Bnrl (o, c, r), `L -> Some ((o, c), r)
|
||||
| Bnrr (o, l, c), `R -> Some ((o, c), l)
|
||||
| _ -> None)
|
||||
point
|
||||
|
||||
let group_by_fst l =
|
||||
List.fast_sort (fun (a, _) (b, _) ->
|
||||
compare a b) l |>
|
||||
List.fold_left (fun (oo, l, res) (o', c) ->
|
||||
match oo with
|
||||
| None -> (Some o', [c], [])
|
||||
| Some o when o = o' -> (oo, c :: l, res)
|
||||
| Some o -> (Some o', [c], (o, l) :: res))
|
||||
(None, [], []) |>
|
||||
(function
|
||||
| (None, _, _) -> []
|
||||
| (Some o, l, res) -> (o, l) :: res)
|
||||
|
||||
let sort_uniq cmp l =
|
||||
List.fast_sort cmp l |>
|
||||
List.fold_left (fun (eo, l) e' ->
|
||||
match eo with
|
||||
| None -> (Some e', l)
|
||||
| Some e when cmp e e' = 0 -> (eo, l)
|
||||
| Some e -> (Some e', e :: l))
|
||||
(None, []) |>
|
||||
(function
|
||||
| (None, _) -> []
|
||||
| (Some e, l) -> List.rev (e :: l))
|
||||
|
||||
let normalize (point: ('a cursor) list) =
|
||||
sort_uniq compare point
|
||||
|
||||
let next_binary tmp s1 s2 =
|
||||
let pm w (_, p) = pattern_match p w in
|
||||
let o1 = binops `L s1 |>
|
||||
List.filter (pm s2.seen) |>
|
||||
List.map fst in
|
||||
let o2 = binops `R s2 |>
|
||||
List.filter (pm s1.seen) |>
|
||||
List.map fst in
|
||||
List.map (fun (o, l) ->
|
||||
o,
|
||||
{ id = 0
|
||||
; seen = Bnr (o, s1.seen, s2.seen)
|
||||
; point = normalize (l @ tmp)
|
||||
}) (group_by_fst (o1 @ o2))
|
||||
|
||||
type p = string
|
||||
|
||||
module StateSet : sig
|
||||
type t
|
||||
val create: unit -> t
|
||||
val add: t -> p state ->
|
||||
[> `Added | `Found ] * p state
|
||||
val iter: t -> (p state -> unit) -> unit
|
||||
val elems: t -> (p state) list
|
||||
end = struct
|
||||
open Hashtbl.Make(struct
|
||||
type t = p state
|
||||
let equal s1 s2 = s1.point = s2.point
|
||||
let hash s = Hashtbl.hash s.point
|
||||
end)
|
||||
type nonrec t =
|
||||
{ h: int t
|
||||
; mutable next_id: int }
|
||||
let create () =
|
||||
{ h = create 500; next_id = 1 }
|
||||
let add set s =
|
||||
assert (s.point = normalize s.point);
|
||||
try
|
||||
let id = find set.h s in
|
||||
`Found, {s with id}
|
||||
with Not_found -> begin
|
||||
let id = set.next_id in
|
||||
set.next_id <- id + 1;
|
||||
Printf.printf "adding: %d [%s]\n"
|
||||
id (show_pattern s.seen);
|
||||
add set.h s id;
|
||||
`Added, {s with id}
|
||||
end
|
||||
let iter set f =
|
||||
let f s id = f {s with id} in
|
||||
iter f set.h
|
||||
let elems set =
|
||||
let res = ref [] in
|
||||
iter set (fun s -> res := s :: !res);
|
||||
!res
|
||||
end
|
||||
|
||||
type table_key =
|
||||
| K of op * p state * p state
|
||||
|
||||
module StateMap = Map.Make(struct
|
||||
type t = table_key
|
||||
let compare ka kb =
|
||||
match ka, kb with
|
||||
| K (o, sl, sr), K (o', sl', sr') ->
|
||||
compare (o, sl.id, sr.id)
|
||||
(o', sl'.id, sr'.id)
|
||||
end)
|
||||
|
||||
type rule =
|
||||
{ name: string
|
||||
; pattern: pattern
|
||||
}
|
||||
|
||||
let generate_table rl =
|
||||
let states = StateSet.create () in
|
||||
(* initialize states *)
|
||||
let ground =
|
||||
List.concat_map
|
||||
(fun r -> peel r.pattern r.name) rl |>
|
||||
group_by_fst
|
||||
in
|
||||
let find x d l =
|
||||
try List.assoc x l with Not_found -> d in
|
||||
let tmp = find (Atm Tmp) [] ground in
|
||||
let con = find (Atm AnyCon) [] ground in
|
||||
let () =
|
||||
List.iter (fun (seen, l) ->
|
||||
let point =
|
||||
if pattern_match (Atm Tmp) seen
|
||||
then normalize (tmp @ l)
|
||||
else normalize (con @ l)
|
||||
in
|
||||
let s = {id = 0; seen; point} in
|
||||
let flag, _ = StateSet.add states s in
|
||||
assert (flag = `Added)
|
||||
) ground
|
||||
in
|
||||
(* setup loop state *)
|
||||
let map = ref StateMap.empty in
|
||||
let map_add k s' =
|
||||
map := StateMap.add k s' !map
|
||||
in
|
||||
let flag = ref `Added in
|
||||
let flagmerge = function
|
||||
| `Added -> flag := `Added
|
||||
| _ -> ()
|
||||
in
|
||||
(* iterate until fixpoint *)
|
||||
while !flag = `Added do
|
||||
flag := `Stop;
|
||||
let statel = StateSet.elems states in
|
||||
iter_pairs statel (fun (sl, sr) ->
|
||||
next_binary tmp sl sr |>
|
||||
List.iter (fun (o, s') ->
|
||||
let flag', s' =
|
||||
StateSet.add states s' in
|
||||
flagmerge flag';
|
||||
map_add (K (o, sl, sr)) s';
|
||||
));
|
||||
done;
|
||||
(StateSet.elems states, !map)
|
||||
|
||||
let intersperse x l =
|
||||
let rec go left right out =
|
||||
let out =
|
||||
(List.rev left @ [x] @ right) ::
|
||||
out in
|
||||
match right with
|
||||
| x :: right' ->
|
||||
go (x :: left) right' out
|
||||
| [] -> out
|
||||
in go [] l []
|
||||
|
||||
let rec permute = function
|
||||
| [] -> [[]]
|
||||
| x :: l ->
|
||||
List.concat (List.map
|
||||
(intersperse x) (permute l))
|
||||
|
||||
(* build all binary trees with ordered
|
||||
* leaves l *)
|
||||
let rec bins build l =
|
||||
let rec go l r out =
|
||||
match r with
|
||||
| [] -> out
|
||||
| x :: r' ->
|
||||
go (l @ [x]) r'
|
||||
(fold_pairs
|
||||
(bins build l)
|
||||
(bins build r)
|
||||
out (fun (l, r) out ->
|
||||
build l r :: out))
|
||||
in
|
||||
match l with
|
||||
| [] -> []
|
||||
| [x] -> [x]
|
||||
| x :: l -> go [x] l []
|
||||
|
||||
let products l ini f =
|
||||
let rec go acc la = function
|
||||
| [] -> f (List.rev la) acc
|
||||
| xs :: l ->
|
||||
List.fold_left (fun acc x ->
|
||||
go acc (x :: la) l)
|
||||
acc xs
|
||||
in go ini [] l
|
||||
|
||||
(* combinatorial nuke... *)
|
||||
let rec ac_equiv =
|
||||
let rec alevel o = function
|
||||
| Bnr (o', l, r) when o' = o ->
|
||||
alevel o l @ alevel o r
|
||||
| x -> [x]
|
||||
in function
|
||||
| Bnr (o, _, _) as p
|
||||
when associative o ->
|
||||
products
|
||||
(List.map ac_equiv (alevel o p)) []
|
||||
(fun choice out ->
|
||||
List.map
|
||||
(bins (fun l r -> Bnr (o, l, r)))
|
||||
(if commutative o
|
||||
then permute choice
|
||||
else [choice]) |>
|
||||
List.concat |>
|
||||
(fun l -> List.rev_append l out))
|
||||
| Bnr (o, l, r)
|
||||
when commutative o ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) ::
|
||||
Bnr (o, r, l) :: out)
|
||||
| Bnr (o, l, r) ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) :: out)
|
||||
| x -> [x]
|
|
@ -1,97 +0,0 @@
|
|||
#use "match.ml"
|
||||
|
||||
let test_pattern_match =
|
||||
let pm = pattern_match
|
||||
and nm = fun x y -> not (pattern_match x y) in
|
||||
begin
|
||||
assert (nm (Atm Tmp) (Atm (Con 42L)));
|
||||
assert (pm (Atm AnyCon) (Atm (Con 42L)));
|
||||
assert (nm (Atm (Con 42L)) (Atm AnyCon));
|
||||
assert (nm (Atm (Con 42L)) (Atm Tmp));
|
||||
end
|
||||
|
||||
let test_peel =
|
||||
let o = Kw, Oadd in
|
||||
let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
|
||||
Atm (Con 42L)) in
|
||||
let l = peel p () in
|
||||
let () = assert (List.length l = 3) in
|
||||
let atomic_p (p, _) =
|
||||
match p with Atm _ -> true | _ -> false in
|
||||
let () = assert (List.for_all atomic_p l) in
|
||||
let l = List.map (fun (p, c) -> fold_cursor c p) l in
|
||||
let () = assert (List.for_all ((=) p) l) in
|
||||
()
|
||||
|
||||
let test_fold_pairs =
|
||||
let l = [1; 2; 3; 4; 5] in
|
||||
let p = fold_pairs l l [] (fun a b -> a :: b) in
|
||||
let () = assert (List.length p = 25) in
|
||||
let p = sort_uniq compare p in
|
||||
let () = assert (List.length p = 25) in
|
||||
()
|
||||
|
||||
(* test pattern & state *)
|
||||
let tp =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
|
||||
Atm (Con 0L))
|
||||
let ts =
|
||||
{ id = 0
|
||||
; seen = Atm Tmp
|
||||
; point =
|
||||
List.map snd
|
||||
(List.filter (fun (p, _) -> p = Atm Tmp)
|
||||
(peel tp ()))
|
||||
}
|
||||
|
||||
let print_sm =
|
||||
StateMap.iter (fun k s' ->
|
||||
match k with
|
||||
| K (o, sl, sr) ->
|
||||
let top =
|
||||
List.fold_left (fun top c ->
|
||||
match c with
|
||||
| Top r -> top ^ " " ^ r
|
||||
| _ -> top) "" s'.point
|
||||
in
|
||||
Printf.printf
|
||||
"(%s %d %d) -> %d%s\n"
|
||||
(show_op o)
|
||||
sl.id sr.id s'.id top)
|
||||
|
||||
let rules =
|
||||
let oa = Kl, Oadd in
|
||||
let om = Kl, Omul in
|
||||
match `X64Addr with
|
||||
(* ------------------------------- *)
|
||||
| `X64Addr ->
|
||||
let rule name pattern =
|
||||
List.mapi (fun i pattern ->
|
||||
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
|
||||
; pattern })
|
||||
(ac_equiv pattern) in
|
||||
(* o + b *)
|
||||
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
||||
@ (* b + s * i *)
|
||||
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
@ (* o + s * i *)
|
||||
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
@ (* b + o + s * i *)
|
||||
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
(* ------------------------------- *)
|
||||
| `Add3 ->
|
||||
[ { name = "add"
|
||||
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
|
||||
[ { name = "add"
|
||||
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
|
||||
[ { name = "mul"
|
||||
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
|
||||
Atm Tmp),
|
||||
Bnr (oa, Atm Tmp,
|
||||
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
|
||||
|
||||
|
||||
let sl, sm = generate_table rules
|
||||
let s n = List.find (fun {id; _} -> id = n) sl
|
||||
let () = print_sm sm
|
3
tools/mgen/.gitignore
vendored
Normal file
3
tools/mgen/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
*.cm[iox]
|
||||
*.o
|
||||
mgen
|
1
tools/mgen/.ocp-indent
Normal file
1
tools/mgen/.ocp-indent
Normal file
|
@ -0,0 +1 @@
|
|||
match_clause=4
|
16
tools/mgen/Makefile
Normal file
16
tools/mgen/Makefile
Normal file
|
@ -0,0 +1,16 @@
|
|||
BIN = mgen
|
||||
SRC = \
|
||||
match.ml \
|
||||
fuzz.ml \
|
||||
cgen.ml \
|
||||
sexp.ml \
|
||||
test.ml \
|
||||
main.ml
|
||||
|
||||
$(BIN): $(SRC)
|
||||
ocamlopt -o $(BIN) -g str.cmxa $(SRC)
|
||||
|
||||
clean:
|
||||
rm -f *.cm? *.o $(BIN)
|
||||
|
||||
.PHONY: clean
|
420
tools/mgen/cgen.ml
Normal file
420
tools/mgen/cgen.ml
Normal file
|
@ -0,0 +1,420 @@
|
|||
open Match
|
||||
|
||||
type options =
|
||||
{ pfx: string
|
||||
; static: bool
|
||||
; oc: out_channel }
|
||||
|
||||
type side = L | R
|
||||
|
||||
type id_pred =
|
||||
| InBitSet of Int64.t
|
||||
| Ge of int
|
||||
| Eq of int
|
||||
|
||||
and id_test =
|
||||
| Pred of (side * id_pred)
|
||||
| And of id_test * id_test
|
||||
|
||||
type case_code =
|
||||
| Table of ((int * int) * int) list
|
||||
| IfThen of
|
||||
{ test: id_test
|
||||
; cif: case_code
|
||||
; cthen: case_code option }
|
||||
| Return of int
|
||||
|
||||
type case =
|
||||
{ swap: bool
|
||||
; code: case_code }
|
||||
|
||||
let cgen_case tmp nstates map =
|
||||
let cgen_test ids =
|
||||
match ids with
|
||||
| [id] -> Eq id
|
||||
| _ ->
|
||||
let min_id =
|
||||
List.fold_left min max_int ids in
|
||||
if List.length ids = nstates - min_id
|
||||
then Ge min_id
|
||||
else begin
|
||||
assert (nstates <= 64);
|
||||
InBitSet
|
||||
(List.fold_left (fun bs id ->
|
||||
Int64.logor bs
|
||||
(Int64.shift_left 1L id))
|
||||
0L ids)
|
||||
end
|
||||
in
|
||||
let symmetric =
|
||||
let inverse ((l, r), x) = ((r, l), x) in
|
||||
setify map = setify (List.map inverse map) in
|
||||
let map =
|
||||
let ordered ((l, r), _) = r <= l in
|
||||
if symmetric then
|
||||
List.filter ordered map
|
||||
else map
|
||||
in
|
||||
let exception BailToTable in
|
||||
try
|
||||
let st =
|
||||
match setify (List.map snd map) with
|
||||
| [st] -> st
|
||||
| _ -> raise BailToTable
|
||||
in
|
||||
(* the operation considered can only
|
||||
* generate a single state *)
|
||||
let pairs = List.map fst map in
|
||||
let ls, rs = List.split pairs in
|
||||
let ls = setify ls and rs = setify rs in
|
||||
if List.length ls > 1 && List.length rs > 1 then
|
||||
raise BailToTable;
|
||||
{ swap = symmetric
|
||||
; code =
|
||||
let pl = Pred (L, cgen_test ls)
|
||||
and pr = Pred (R, cgen_test rs) in
|
||||
IfThen
|
||||
{ test = And (pl, pr)
|
||||
; cif = Return st
|
||||
; cthen = Some (Return tmp) } }
|
||||
with BailToTable ->
|
||||
{ swap = symmetric
|
||||
; code = Table map }
|
||||
|
||||
let show_op (_cls, op) =
|
||||
"O" ^ show_op_base op
|
||||
|
||||
let indent oc i =
|
||||
Printf.fprintf oc "%s" (String.sub "\t\t\t\t\t" 0 i)
|
||||
|
||||
let emit_swap oc i =
|
||||
let pf m = Printf.fprintf oc m in
|
||||
let pfi n m = indent oc n; pf m in
|
||||
pfi i "if (l < r)\n";
|
||||
pfi (i+1) "t = l, l = r, r = t;\n"
|
||||
|
||||
let gen_tables oc tmp pfx nstates (op, c) =
|
||||
let i = 1 in
|
||||
let pf m = Printf.fprintf oc m in
|
||||
let pfi n m = indent oc n; pf m in
|
||||
let ntables = ref 0 in
|
||||
(* we must follow the order in which
|
||||
* we visit code in emit_case, or
|
||||
* else ntables goes out of sync *)
|
||||
let base = pfx ^ show_op op in
|
||||
let swap = c.swap in
|
||||
let rec gen c =
|
||||
match c with
|
||||
| Table map ->
|
||||
let name =
|
||||
if !ntables = 0 then base else
|
||||
base ^ string_of_int !ntables
|
||||
in
|
||||
assert (nstates <= 256);
|
||||
if swap then
|
||||
let n = nstates * (nstates + 1) / 2 in
|
||||
pfi i "static uchar %stbl[%d] = {\n" name n
|
||||
else
|
||||
pfi i "static uchar %stbl[%d][%d] = {\n"
|
||||
name nstates nstates;
|
||||
for l = 0 to nstates - 1 do
|
||||
pfi (i+1) "";
|
||||
for r = 0 to nstates - 1 do
|
||||
if not swap || r <= l then
|
||||
begin
|
||||
pf "%d"
|
||||
(try List.assoc (l,r) map
|
||||
with Not_found -> tmp);
|
||||
pf ",";
|
||||
end
|
||||
done;
|
||||
pf "\n";
|
||||
done;
|
||||
pfi i "};\n"
|
||||
| IfThen {cif; cthen} ->
|
||||
gen cif;
|
||||
Option.iter gen cthen
|
||||
| Return _ -> ()
|
||||
in
|
||||
gen c.code
|
||||
|
||||
let emit_case oc pfx no_swap (op, c) =
|
||||
let fpf = Printf.fprintf in
|
||||
let pf m = fpf oc m in
|
||||
let pfi n m = indent oc n; pf m in
|
||||
let rec side oc = function
|
||||
| L -> fpf oc "l"
|
||||
| R -> fpf oc "r"
|
||||
in
|
||||
let pred oc (s, pred) =
|
||||
match pred with
|
||||
| InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
|
||||
| Eq id -> fpf oc "%a == %d" side s id
|
||||
| Ge id -> fpf oc "%d <= %a" id side s
|
||||
in
|
||||
let base = pfx ^ show_op op in
|
||||
let swap = c.swap in
|
||||
let ntables = ref 0 in
|
||||
let rec code i c =
|
||||
match c with
|
||||
| Return id -> pfi i "return %d;\n" id
|
||||
| Table map ->
|
||||
let name =
|
||||
if !ntables = 0 then base else
|
||||
base ^ string_of_int !ntables
|
||||
in
|
||||
incr ntables;
|
||||
if swap then
|
||||
pfi i "return %stbl[(l + l*l)/2 + r];\n" name
|
||||
else pfi i "return %stbl[l][r];\n" name
|
||||
| IfThen ({test = And (And (t1, t2), t3)} as r) ->
|
||||
code i @@ IfThen
|
||||
{r with test = And (t1, And (t2, t3))}
|
||||
| IfThen {test = And (Pred p, t); cif; cthen} ->
|
||||
pfi i "if (%a)\n" pred p;
|
||||
code i (IfThen {test = t; cif; cthen})
|
||||
| IfThen {test = Pred p; cif; cthen} ->
|
||||
pfi i "if (%a) {\n" pred p;
|
||||
code (i+1) cif;
|
||||
pfi i "}\n";
|
||||
Option.iter (code i) cthen
|
||||
in
|
||||
pfi 1 "case %s:\n" (show_op op);
|
||||
if not no_swap && c.swap then
|
||||
emit_swap oc 2;
|
||||
code 2 c.code
|
||||
|
||||
let emit_list
|
||||
?(limit=60) ?(cut_before_sep=false)
|
||||
~col ~indent:i ~sep ~f oc l =
|
||||
let sl = String.length sep in
|
||||
let rstripped_sep, rssl =
|
||||
if sep.[sl - 1] = ' ' then
|
||||
String.sub sep 0 (sl - 1), sl - 1
|
||||
else sep, sl
|
||||
in
|
||||
let lstripped_sep, lssl =
|
||||
if sep.[0] = ' ' then
|
||||
String.sub sep 1 (sl - 1), sl - 1
|
||||
else sep, sl
|
||||
in
|
||||
let rec line col acc = function
|
||||
| [] -> (List.rev acc, [])
|
||||
| s :: l ->
|
||||
let col = col + sl + String.length s in
|
||||
let no_space =
|
||||
if cut_before_sep || l = [] then
|
||||
col > limit
|
||||
else
|
||||
col + rssl > limit
|
||||
in
|
||||
if no_space then
|
||||
(List.rev acc, s :: l)
|
||||
else
|
||||
line col (s :: acc) l
|
||||
in
|
||||
let rec go col l =
|
||||
if l = [] then () else
|
||||
let ll, l = line col [] l in
|
||||
Printf.fprintf oc "%s" (String.concat sep ll);
|
||||
if l <> [] && cut_before_sep then begin
|
||||
Printf.fprintf oc "\n";
|
||||
indent oc i;
|
||||
Printf.fprintf oc "%s" lstripped_sep;
|
||||
go (8*i + lssl) l
|
||||
end else if l <> [] then begin
|
||||
Printf.fprintf oc "%s\n" rstripped_sep;
|
||||
indent oc i;
|
||||
go (8*i) l
|
||||
end else ()
|
||||
in
|
||||
go col (List.map f l)
|
||||
|
||||
let emit_numberer opts n =
|
||||
let pf m = Printf.fprintf opts.oc m in
|
||||
let tmp = (atom_state n Tmp).id in
|
||||
let con = (atom_state n AnyCon).id in
|
||||
let nst = Array.length n.states in
|
||||
let cases =
|
||||
StateMap.by_ops n.statemap |>
|
||||
List.map (fun (op, map) ->
|
||||
(op, cgen_case tmp nst map))
|
||||
in
|
||||
let all_swap =
|
||||
List.for_all (fun (_, c) -> c.swap) cases in
|
||||
(* opn() *)
|
||||
if opts.static then pf "static ";
|
||||
pf "int\n";
|
||||
pf "%sopn(int op, int l, int r)\n" opts.pfx;
|
||||
pf "{\n";
|
||||
cases |> List.iter
|
||||
(gen_tables opts.oc tmp opts.pfx nst);
|
||||
if List.exists (fun (_, c) -> c.swap) cases then
|
||||
pf "\tint t;\n\n";
|
||||
if all_swap then emit_swap opts.oc 1;
|
||||
pf "\tswitch (op) {\n";
|
||||
cases |> List.iter
|
||||
(emit_case opts.oc opts.pfx all_swap);
|
||||
pf "\tdefault:\n";
|
||||
pf "\t\treturn %d;\n" tmp;
|
||||
pf "\t}\n";
|
||||
pf "}\n\n";
|
||||
(* refn() *)
|
||||
if opts.static then pf "static ";
|
||||
pf "int\n";
|
||||
pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
|
||||
pf "{\n";
|
||||
let cons =
|
||||
List.filter_map (function
|
||||
| (Con c, s) -> Some (c, s.id)
|
||||
| _ -> None)
|
||||
n.atoms
|
||||
in
|
||||
if cons <> [] then
|
||||
pf "\tint64_t n;\n\n";
|
||||
pf "\tswitch (rtype(r)) {\n";
|
||||
pf "\tcase RTmp:\n";
|
||||
if tmp <> 0 then begin
|
||||
assert
|
||||
(List.exists (fun (_, s) ->
|
||||
s.id = 0
|
||||
) n.atoms &&
|
||||
(* no temp should ever get state 0 *)
|
||||
List.for_all (fun (a, s) ->
|
||||
s.id <> 0 ||
|
||||
match a with
|
||||
| AnyCon | Con _ -> true
|
||||
| _ -> false
|
||||
) n.atoms);
|
||||
pf "\t\tif (!tn[r.val].n)\n";
|
||||
pf "\t\t\ttn[r.val].n = %d;\n" tmp;
|
||||
end;
|
||||
pf "\t\treturn tn[r.val].n;\n";
|
||||
pf "\tcase RCon:\n";
|
||||
if cons <> [] then begin
|
||||
pf "\t\tif (con[r.val].type != CBits)\n";
|
||||
pf "\t\t\treturn %d;\n" con;
|
||||
pf "\t\tn = con[r.val].bits.i;\n";
|
||||
cons |> inverse |> group_by_fst
|
||||
|> List.iter (fun (id, cs) ->
|
||||
pf "\t\tif (";
|
||||
emit_list ~cut_before_sep:true
|
||||
~col:20 ~indent:2 ~sep:" || "
|
||||
~f:(fun c -> "n == " ^ Int64.to_string c)
|
||||
opts.oc cs;
|
||||
pf ")\n";
|
||||
pf "\t\t\treturn %d;\n" id
|
||||
);
|
||||
end;
|
||||
pf "\t\treturn %d;\n" con;
|
||||
pf "\tdefault:\n";
|
||||
pf "\t\treturn INT_MIN;\n";
|
||||
pf "\t}\n";
|
||||
pf "}\n\n";
|
||||
(* match[]: patterns per state *)
|
||||
if opts.static then pf "static ";
|
||||
pf "bits %smatch[%d] = {\n" opts.pfx nst;
|
||||
n.states |> Array.iteri (fun sn s ->
|
||||
let tops =
|
||||
List.filter_map (function
|
||||
| Top ("$" | "%") -> None
|
||||
| Top r -> Some ("BIT(P" ^ r ^ ")")
|
||||
| _ -> None) s.point |> setify
|
||||
in
|
||||
if tops <> [] then
|
||||
pf "\t[%d] = %s,\n"
|
||||
sn (String.concat " | " tops);
|
||||
);
|
||||
pf "};\n\n"
|
||||
|
||||
let var_id vars f =
|
||||
List.mapi (fun i x -> (x, i)) vars |>
|
||||
List.assoc f
|
||||
|
||||
let compile_action vars act =
|
||||
let pcs = Hashtbl.create 100 in
|
||||
let rec gen pc (act: Action.t) =
|
||||
try
|
||||
[10 + Hashtbl.find pcs act.id]
|
||||
with Not_found ->
|
||||
let code =
|
||||
match act.node with
|
||||
| Action.Stop ->
|
||||
[0]
|
||||
| Action.Push (sym, k) ->
|
||||
let c = if sym then 1 else 2 in
|
||||
[c] @ gen (pc + 1) k
|
||||
| Action.Set (v, {node = Action.Pop k; _})
|
||||
| Action.Set (v, ({node = Action.Stop; _} as k)) ->
|
||||
let v = var_id vars v in
|
||||
[3; v] @ gen (pc + 2) k
|
||||
| Action.Set _ ->
|
||||
(* for now, only atomic patterns can be
|
||||
* tied to a variable, so Set must be
|
||||
* followed by either Pop or Stop *)
|
||||
assert false
|
||||
| Action.Pop k ->
|
||||
[4] @ gen (pc + 1) k
|
||||
| Action.Switch cases ->
|
||||
let cases =
|
||||
inverse cases |> group_by_fst |>
|
||||
List.sort (fun (_, cs1) (_, cs2) ->
|
||||
let n1 = List.length cs1
|
||||
and n2 = List.length cs2 in
|
||||
compare n2 n1)
|
||||
in
|
||||
(* the last case is the one with
|
||||
* the max number of entries *)
|
||||
let cases = List.rev (List.tl cases)
|
||||
and last = fst (List.hd cases) in
|
||||
let ncases =
|
||||
List.fold_left (fun n (_, cs) ->
|
||||
List.length cs + n)
|
||||
0 cases
|
||||
in
|
||||
let body_off = 2 + 2 * ncases + 1 in
|
||||
let pc, tbl, body =
|
||||
List.fold_left
|
||||
(fun (pc, tbl, body) (a, cs) ->
|
||||
let ofs = body_off + List.length body in
|
||||
let case = gen pc a in
|
||||
let pc = pc + List.length case in
|
||||
let body = body @ case in
|
||||
let tbl =
|
||||
List.fold_left (fun tbl c ->
|
||||
tbl @ [c; ofs]
|
||||
) tbl cs
|
||||
in
|
||||
(pc, tbl, body))
|
||||
(pc + body_off, [], [])
|
||||
cases
|
||||
in
|
||||
let ofs = body_off + List.length body in
|
||||
let tbl = tbl @ [ofs] in
|
||||
assert (2 + List.length tbl = body_off);
|
||||
[5; ncases] @ tbl @ body @ gen pc last
|
||||
in
|
||||
if act.node <> Action.Stop then
|
||||
Hashtbl.replace pcs act.id pc;
|
||||
code
|
||||
in
|
||||
gen 0 act
|
||||
|
||||
let emit_matchers opts ms =
|
||||
let pf m = Printf.fprintf opts.oc m in
|
||||
if opts.static then pf "static ";
|
||||
pf "uchar *%smatcher[] = {\n" opts.pfx;
|
||||
List.iter (fun (vars, pname, m) ->
|
||||
pf "\t[P%s] = (uchar[]){\n" pname;
|
||||
pf "\t\t";
|
||||
let bytes = compile_action vars m in
|
||||
emit_list
|
||||
~col:16 ~indent:2 ~sep:","
|
||||
~f:string_of_int opts.oc bytes;
|
||||
pf "\n";
|
||||
pf "\t},\n")
|
||||
ms;
|
||||
pf "};\n\n"
|
||||
|
||||
let emit_c opts n =
|
||||
emit_numberer opts n
|
413
tools/mgen/fuzz.ml
Normal file
413
tools/mgen/fuzz.ml
Normal file
|
@ -0,0 +1,413 @@
|
|||
(* fuzz the tables and matchers generated *)
|
||||
open Match
|
||||
|
||||
module Buffer: sig
|
||||
type 'a t
|
||||
val create: ?capacity:int -> unit -> 'a t
|
||||
val reset: 'a t -> unit
|
||||
val size: 'a t -> int
|
||||
val get: 'a t -> int -> 'a
|
||||
val set: 'a t -> int -> 'a -> unit
|
||||
val push: 'a t -> 'a -> unit
|
||||
end = struct
|
||||
type 'a t =
|
||||
{ mutable size: int
|
||||
; mutable data: 'a array }
|
||||
let mk_array n = Array.make n (Obj.magic 0)
|
||||
let create ?(capacity = 10) () =
|
||||
if capacity < 0 then invalid_arg "Buffer.make";
|
||||
{size = 0; data = mk_array capacity}
|
||||
let reset b = b.size <- 0
|
||||
let size b = b.size
|
||||
let get b n =
|
||||
if n >= size b then invalid_arg "Buffer.get";
|
||||
b.data.(n)
|
||||
let set b n x =
|
||||
if n >= size b then invalid_arg "Buffer.set";
|
||||
b.data.(n) <- x
|
||||
let push b x =
|
||||
let cap = Array.length b.data in
|
||||
if size b = cap then begin
|
||||
let data = mk_array (2 * cap + 1) in
|
||||
Array.blit b.data 0 data 0 cap;
|
||||
b.data <- data
|
||||
end;
|
||||
let sz = size b in
|
||||
b.size <- sz + 1;
|
||||
set b sz x
|
||||
end
|
||||
|
||||
let binop_state n op s1 s2 =
|
||||
let key = K (op, s1, s2) in
|
||||
try StateMap.find key n.statemap
|
||||
with Not_found -> atom_state n Tmp
|
||||
|
||||
type id = int
|
||||
type term_data =
|
||||
| Binop of op * id * id
|
||||
| Leaf of atomic_pattern
|
||||
type term =
|
||||
{ id: id
|
||||
; data: term_data
|
||||
; state: p state }
|
||||
|
||||
let pp_term fmt (ta, id) =
|
||||
let fpf x = Format.fprintf fmt x in
|
||||
let rec pp _fmt id =
|
||||
match ta.(id).data with
|
||||
| Leaf (Con c) -> fpf "%Ld" c
|
||||
| Leaf AnyCon -> fpf "$%d" id
|
||||
| Leaf Tmp -> fpf "%%%d" id
|
||||
| Binop (op, id1, id2) ->
|
||||
fpf "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
|
||||
(show_op op) id ta.(id).state.id
|
||||
pp id1 pp id2
|
||||
in pp fmt id
|
||||
|
||||
(* A term pool is a deduplicated set of term
|
||||
* that maintains nodes numbering using the
|
||||
* statemap passed at creation time *)
|
||||
module TermPool = struct
|
||||
type t =
|
||||
{ terms: term Buffer.t
|
||||
; hcons: (term_data, id) Hashtbl.t
|
||||
; numbr: numberer }
|
||||
|
||||
let create numbr =
|
||||
{ terms = Buffer.create ()
|
||||
; hcons = Hashtbl.create 100
|
||||
; numbr }
|
||||
let reset tp =
|
||||
Buffer.reset tp.terms;
|
||||
Hashtbl.clear tp.hcons
|
||||
|
||||
let size tp = Buffer.size tp.terms
|
||||
let term tp id = Buffer.get tp.terms id
|
||||
|
||||
let mk_leaf tp atm =
|
||||
let data = Leaf atm in
|
||||
match Hashtbl.find tp.hcons data with
|
||||
| id -> term tp id
|
||||
| exception Not_found ->
|
||||
let id = Buffer.size tp.terms in
|
||||
let state = atom_state tp.numbr atm in
|
||||
Buffer.push tp.terms {id; data; state};
|
||||
Hashtbl.add tp.hcons data id;
|
||||
term tp id
|
||||
let mk_binop tp op t1 t2 =
|
||||
let data = Binop (op, t1.id, t2.id) in
|
||||
match Hashtbl.find tp.hcons data with
|
||||
| id -> term tp id
|
||||
| exception Not_found ->
|
||||
let id = Buffer.size tp.terms in
|
||||
let state =
|
||||
binop_state tp.numbr op t1.state t2.state
|
||||
in
|
||||
Buffer.push tp.terms {id; data; state};
|
||||
Hashtbl.add tp.hcons data id;
|
||||
term tp id
|
||||
|
||||
let rec add_pattern tp = function
|
||||
| Bnr (op, p1, p2) ->
|
||||
let t1 = add_pattern tp p1 in
|
||||
let t2 = add_pattern tp p2 in
|
||||
mk_binop tp op t1 t2
|
||||
| Atm atm -> mk_leaf tp atm
|
||||
| Var (_, atm) -> add_pattern tp (Atm atm)
|
||||
|
||||
let explode_term tp id =
|
||||
let rec aux tms n id =
|
||||
let t = term tp id in
|
||||
match t.data with
|
||||
| Leaf _ -> (n, {t with id = n} :: tms)
|
||||
| Binop (op, id1, id2) ->
|
||||
let n1, tms = aux tms n id1 in
|
||||
let n = n1 + 1 in
|
||||
let n2, tms = aux tms n id2 in
|
||||
let n = n2 + 1 in
|
||||
(n, { t with data = Binop (op, n1, n2)
|
||||
; id = n } :: tms)
|
||||
in
|
||||
let n, tms = aux [] 0 id in
|
||||
Array.of_list (List.rev tms), n
|
||||
end
|
||||
|
||||
module R = Random
|
||||
|
||||
(* uniform pick in a list *)
|
||||
let list_pick l =
|
||||
let rec aux n l x =
|
||||
match l with
|
||||
| [] -> x
|
||||
| y :: l ->
|
||||
if R.int (n + 1) = 0 then
|
||||
aux (n + 1) l y
|
||||
else
|
||||
aux (n + 1) l x
|
||||
in
|
||||
match l with
|
||||
| [] -> invalid_arg "list_pick"
|
||||
| x :: l -> aux 1 l x
|
||||
|
||||
let term_pick ~numbr =
|
||||
let ops =
|
||||
if numbr.ops = [] then
|
||||
numbr.ops <-
|
||||
(StateMap.fold (fun k _ ops ->
|
||||
match k with
|
||||
| K (op, _, _) -> op :: ops)
|
||||
numbr.statemap [] |> setify);
|
||||
numbr.ops
|
||||
in
|
||||
let rec gen depth =
|
||||
(* exponential probability for leaves to
|
||||
* avoid skewing towards shallow terms *)
|
||||
let atm_prob = 0.75 ** float_of_int depth in
|
||||
if R.float 1.0 <= atm_prob || ops = [] then
|
||||
let atom, st = list_pick numbr.atoms in
|
||||
(st, Atm atom)
|
||||
else
|
||||
let op = list_pick ops in
|
||||
let s1, t1 = gen (depth - 1) in
|
||||
let s2, t2 = gen (depth - 1) in
|
||||
( binop_state numbr op s1 s2
|
||||
, Bnr (op, t1, t2) )
|
||||
in fun ~depth -> gen depth
|
||||
|
||||
exception FuzzError
|
||||
|
||||
let rec pattern_depth = function
|
||||
| Bnr (_, p1, p2) ->
|
||||
1 + max (pattern_depth p1) (pattern_depth p2)
|
||||
| Atm _ -> 0
|
||||
| Var (_, atm) -> pattern_depth (Atm atm)
|
||||
|
||||
let ( %% ) a b =
|
||||
1e2 *. float_of_int a /. float_of_int b
|
||||
|
||||
let progress ?(width = 50) msg pct =
|
||||
Format.eprintf "\x1b[2K\r%!";
|
||||
let progress_bar fmt =
|
||||
let n =
|
||||
let fwidth = float_of_int width in
|
||||
1 + int_of_float (pct *. fwidth /. 1e2)
|
||||
in
|
||||
Format.fprintf fmt " %s%s %.0f%%@?"
|
||||
(String.concat "" (List.init n (fun _ -> "▒")))
|
||||
(String.make (max 0 (width - n)) '-')
|
||||
pct
|
||||
in
|
||||
Format.kfprintf progress_bar
|
||||
Format.err_formatter msg
|
||||
|
||||
let fuzz_numberer rules numbr =
|
||||
(* pick twice the max pattern depth so we
|
||||
* have a chance to find non-trivial numbers
|
||||
* for the atomic patterns in the rules *)
|
||||
let depth =
|
||||
List.fold_left (fun depth r ->
|
||||
max depth (pattern_depth r.pattern))
|
||||
0 rules * 2
|
||||
in
|
||||
(* fuzz until the term pool we are constructing
|
||||
* is no longer growing fast enough; or we just
|
||||
* went through sufficiently many iterations *)
|
||||
let max_iter = 1_000_000 in
|
||||
let low_insert_rate = 1e-2 in
|
||||
let tp = TermPool.create numbr in
|
||||
let rec loop new_stats i =
|
||||
let (_, _, insert_rate) = new_stats in
|
||||
if insert_rate <= low_insert_rate then () else
|
||||
if i >= max_iter then () else
|
||||
(* periodically update stats *)
|
||||
let new_stats =
|
||||
let (num, cnt, rate) = new_stats in
|
||||
if num land 1023 = 0 then
|
||||
let rate =
|
||||
0.5 *. (rate +. float_of_int cnt /. 1023.)
|
||||
in
|
||||
progress " insert_rate=%.1f%%"
|
||||
(i %% max_iter) (rate *. 1e2);
|
||||
(num + 1, 0, rate)
|
||||
else new_stats
|
||||
in
|
||||
(* create a term and check that its number is
|
||||
* accurate wrt the rules *)
|
||||
let st, term = term_pick ~numbr ~depth in
|
||||
let state_matched =
|
||||
List.filter_map (fun cu ->
|
||||
match cu with
|
||||
| Top ("$" | "%") -> None
|
||||
| Top name -> Some name
|
||||
| _ -> None)
|
||||
st.point |> setify
|
||||
in
|
||||
let rule_matched =
|
||||
List.filter_map (fun r ->
|
||||
if pattern_match r.pattern term then
|
||||
Some r.name
|
||||
else None)
|
||||
rules |> setify
|
||||
in
|
||||
if state_matched <> rule_matched then begin
|
||||
let open Format in
|
||||
let pp_str_list =
|
||||
let pp_sep fmt () = fprintf fmt ",@ " in
|
||||
pp_print_list ~pp_sep pp_print_string
|
||||
in
|
||||
eprintf "@.@[<v2>fuzz error for %s"
|
||||
(show_pattern term);
|
||||
eprintf "@ @[state matched: %a@]"
|
||||
pp_str_list state_matched;
|
||||
eprintf "@ @[rule matched: %a@]"
|
||||
pp_str_list rule_matched;
|
||||
eprintf "@]@.";
|
||||
raise FuzzError;
|
||||
end;
|
||||
if state_matched = [] then
|
||||
loop new_stats (i + 1)
|
||||
else
|
||||
(* add to the term pool *)
|
||||
let old_size = TermPool.size tp in
|
||||
let _ = TermPool.add_pattern tp term in
|
||||
let new_stats =
|
||||
let (num, cnt, rate) = new_stats in
|
||||
if TermPool.size tp <> old_size then
|
||||
(num + 1, cnt + 1, rate)
|
||||
else
|
||||
(num + 1, cnt, rate)
|
||||
in
|
||||
loop new_stats (i + 1)
|
||||
in
|
||||
loop (1, 0, 1.0) 0;
|
||||
Format.eprintf
|
||||
"@.@[ generated %.3fMiB of test terms@]@."
|
||||
(float_of_int (Obj.reachable_words (Obj.repr tp))
|
||||
/. 128. /. 1024.);
|
||||
tp
|
||||
|
||||
let rec run_matcher stk m (ta, id as t) =
|
||||
let state id = ta.(id).state.id in
|
||||
match m.Action.node with
|
||||
| Action.Switch cases ->
|
||||
let m =
|
||||
try List.assoc (state id) cases
|
||||
with Not_found -> failwith "no switch case"
|
||||
in
|
||||
run_matcher stk m t
|
||||
| Action.Push (sym, m) ->
|
||||
let l, r =
|
||||
match ta.(id).data with
|
||||
| Leaf _ -> failwith "push on leaf"
|
||||
| Binop (_, l, r) -> (l, r)
|
||||
in
|
||||
if sym && state l > state r
|
||||
then run_matcher (l :: stk) m (ta, r)
|
||||
else run_matcher (r :: stk) m (ta, l)
|
||||
| Action.Pop m -> begin
|
||||
match stk with
|
||||
| id :: stk -> run_matcher stk m (ta, id)
|
||||
| [] -> failwith "pop on empty stack"
|
||||
end
|
||||
| Action.Set (v, m) ->
|
||||
(v, id) :: run_matcher stk m t
|
||||
| Action.Stop -> []
|
||||
|
||||
let rec term_match p (ta, id) =
|
||||
let (|>>) x f =
|
||||
match x with None -> None | Some x -> f x
|
||||
in
|
||||
let atom_match a =
|
||||
match ta.(id).data with
|
||||
| Leaf a' -> pattern_match (Atm a) (Atm a')
|
||||
| Binop _ -> pattern_match (Atm a) (Atm Tmp)
|
||||
in
|
||||
match p with
|
||||
| Var (v, a) when atom_match a ->
|
||||
Some [(v, id)]
|
||||
| Atm a when atom_match a -> Some []
|
||||
| (Atm _ | Var _) -> None
|
||||
| Bnr (op, pl, pr) -> begin
|
||||
match ta.(id).data with
|
||||
| Binop (op', idl, idr) when op' = op ->
|
||||
term_match pl (ta, idl) |>> fun l1 ->
|
||||
term_match pr (ta, idr) |>> fun l2 ->
|
||||
Some (l1 @ l2)
|
||||
| _ -> None
|
||||
end
|
||||
|
||||
let test_matchers tp numbr rules =
|
||||
let {statemap = sm; states = sa; _} = numbr in
|
||||
let total = ref 0 in
|
||||
let matchers =
|
||||
let htbl = Hashtbl.create (Array.length sa) in
|
||||
List.map (fun r -> (r.name, r.pattern)) rules |>
|
||||
group_by_fst |>
|
||||
List.iter (fun (r, ps) ->
|
||||
total := !total + List.length ps;
|
||||
let pm = (ps, lr_matcher sm sa rules r) in
|
||||
sa |> Array.iter (fun s ->
|
||||
if List.mem (Top r) s.point then
|
||||
Hashtbl.add htbl s.id pm));
|
||||
htbl
|
||||
in
|
||||
let seen = Hashtbl.create !total in
|
||||
for id = 0 to TermPool.size tp - 1 do
|
||||
if id land 1023 = 0 ||
|
||||
id = TermPool.size tp - 1 then begin
|
||||
progress
|
||||
" coverage=%.1f%%"
|
||||
(id %% TermPool.size tp)
|
||||
(Hashtbl.length seen %% !total)
|
||||
end;
|
||||
let t = TermPool.explode_term tp id in
|
||||
Hashtbl.find_all matchers
|
||||
(TermPool.term tp id).state.id |>
|
||||
List.iter (fun (ps, m) ->
|
||||
let norm = List.fast_sort compare in
|
||||
let ok =
|
||||
match norm (run_matcher [] m t) with
|
||||
| asn -> `Match (List.exists (fun p ->
|
||||
match term_match p t with
|
||||
| None -> false
|
||||
| Some asn' ->
|
||||
if asn = norm asn' then begin
|
||||
Hashtbl.replace seen p ();
|
||||
true
|
||||
end else false) ps)
|
||||
| exception e -> `RunFailure e
|
||||
in
|
||||
if ok <> `Match true then begin
|
||||
let open Format in
|
||||
let pp_asn fmt asn =
|
||||
fprintf fmt "@[<h>";
|
||||
pp_print_list
|
||||
~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
|
||||
(fun fmt (v, d) ->
|
||||
fprintf fmt "@[%s←%d@]" v d)
|
||||
fmt asn;
|
||||
fprintf fmt "@]"
|
||||
in
|
||||
eprintf "@.@[<v2>matcher error for";
|
||||
eprintf "@ @[%a@]" pp_term t;
|
||||
begin match ok with
|
||||
| `RunFailure e ->
|
||||
eprintf "@ @[exception: %s@]"
|
||||
(Printexc.to_string e)
|
||||
| `Match (* false *) _ ->
|
||||
let asn = run_matcher [] m t in
|
||||
eprintf "@ @[assignment: %a@]"
|
||||
pp_asn asn;
|
||||
eprintf "@ @[<v2>could not match";
|
||||
List.iter (fun p ->
|
||||
eprintf "@ + @[%s@]"
|
||||
(show_pattern p)) ps;
|
||||
eprintf "@]"
|
||||
end;
|
||||
eprintf "@]@.";
|
||||
raise FuzzError
|
||||
end)
|
||||
done;
|
||||
Format.eprintf "@."
|
||||
|
||||
|
214
tools/mgen/main.ml
Normal file
214
tools/mgen/main.ml
Normal file
|
@ -0,0 +1,214 @@
|
|||
open Cgen
|
||||
open Match
|
||||
|
||||
let mgen ~verbose ~fuzz path lofs input oc =
|
||||
let info ?(level = 1) fmt =
|
||||
if level <= verbose then
|
||||
Printf.eprintf fmt
|
||||
else
|
||||
Printf.ifprintf stdout fmt
|
||||
in
|
||||
|
||||
let rules =
|
||||
match Sexp.(run_parser ppats) input with
|
||||
| `Error (ps, err, loc) ->
|
||||
Printf.eprintf "%s:%d:%d %s\n"
|
||||
path (lofs + ps.Sexp.line) ps.Sexp.coln err;
|
||||
Printf.eprintf "%s" loc;
|
||||
exit 1
|
||||
| `Ok rules -> rules
|
||||
in
|
||||
|
||||
info "adding ac variants...%!";
|
||||
let nparsed =
|
||||
List.fold_left
|
||||
(fun npats (_, _, ps) ->
|
||||
npats + List.length ps)
|
||||
0 rules
|
||||
in
|
||||
let varsmap = Hashtbl.create 10 in
|
||||
let rules =
|
||||
List.concat_map (fun (name, vars, patterns) ->
|
||||
(try assert (Hashtbl.find varsmap name = vars)
|
||||
with Not_found -> ());
|
||||
Hashtbl.replace varsmap name vars;
|
||||
List.map
|
||||
(fun pattern -> {name; vars; pattern})
|
||||
(List.concat_map ac_equiv patterns)
|
||||
) rules
|
||||
in
|
||||
info " %d -> %d patterns\n"
|
||||
nparsed (List.length rules);
|
||||
|
||||
let rnames =
|
||||
setify (List.map (fun r -> r.name) rules) in
|
||||
|
||||
info "generating match tables...%!";
|
||||
let sa, am, sm = generate_table rules in
|
||||
let numbr = make_numberer sa am sm in
|
||||
info " %d states, %d rules\n"
|
||||
(Array.length sa) (StateMap.cardinal sm);
|
||||
if verbose >= 2 then begin
|
||||
info "-------------\nstates:\n";
|
||||
Array.iteri (fun i s ->
|
||||
info " state %d: %s\n"
|
||||
i (show_pattern s.seen)) sa;
|
||||
info "-------------\nstatemap:\n";
|
||||
Test.print_sm stderr sm;
|
||||
info "-------------\n";
|
||||
end;
|
||||
|
||||
info "generating matchers...\n";
|
||||
let matchers =
|
||||
List.map (fun rname ->
|
||||
info "+ %s...%!" rname;
|
||||
let m = lr_matcher sm sa rules rname in
|
||||
let vars = Hashtbl.find varsmap rname in
|
||||
info " %d nodes\n" (Action.size m);
|
||||
info ~level:2 " -------------\n";
|
||||
info ~level:2 " automaton:\n";
|
||||
info ~level:2 "%s\n"
|
||||
(Format.asprintf " @[%a@]" Action.pp m);
|
||||
info ~level:2 " ----------\n";
|
||||
(vars, rname, m)
|
||||
) rnames
|
||||
in
|
||||
|
||||
if fuzz then begin
|
||||
info ~level:0 "fuzzing statemap...\n";
|
||||
let tp = Fuzz.fuzz_numberer rules numbr in
|
||||
info ~level:0 "testing %d patterns...\n"
|
||||
(List.length rules);
|
||||
Fuzz.test_matchers tp numbr rules
|
||||
end;
|
||||
|
||||
info "emitting C...\n";
|
||||
flush stderr;
|
||||
|
||||
let cgopts =
|
||||
{ pfx = ""; static = true; oc = oc } in
|
||||
emit_c cgopts numbr;
|
||||
emit_matchers cgopts matchers;
|
||||
|
||||
()
|
||||
|
||||
let read_all ic =
|
||||
let bufsz = 4096 in
|
||||
let buf = Bytes.create bufsz in
|
||||
let data = Buffer.create bufsz in
|
||||
let read = ref 0 in
|
||||
while
|
||||
read := input ic buf 0 bufsz;
|
||||
!read <> 0
|
||||
do
|
||||
Buffer.add_subbytes data buf 0 !read
|
||||
done;
|
||||
Buffer.contents data
|
||||
|
||||
let split_c src =
|
||||
let begin_re, eoc_re, end_re =
|
||||
let re = Str.regexp in
|
||||
( re "mgen generated code"
|
||||
, re "\\*/"
|
||||
, re "end of generated code" )
|
||||
in
|
||||
let str_match regexp str =
|
||||
try
|
||||
let _: int =
|
||||
Str.search_forward regexp str 0
|
||||
in true
|
||||
with Not_found -> false
|
||||
in
|
||||
|
||||
let rec go st lofs pfx rules lines =
|
||||
let line, lines =
|
||||
match lines with
|
||||
| [] ->
|
||||
failwith (
|
||||
match st with
|
||||
| `Prefix -> "could not find mgen section"
|
||||
| `Rules -> "mgen rules not terminated"
|
||||
| `Skip -> "mgen section not terminated"
|
||||
)
|
||||
| l :: ls -> (l, ls)
|
||||
in
|
||||
match st with
|
||||
| `Prefix ->
|
||||
let pfx = line :: pfx in
|
||||
if str_match begin_re line
|
||||
then
|
||||
let lofs = List.length pfx in
|
||||
go `Rules lofs pfx rules lines
|
||||
else go `Prefix 0 pfx rules lines
|
||||
| `Rules ->
|
||||
let pfx = line :: pfx in
|
||||
if str_match eoc_re line
|
||||
then go `Skip lofs pfx rules lines
|
||||
else go `Rules lofs pfx (line :: rules) lines
|
||||
| `Skip ->
|
||||
if str_match end_re line then
|
||||
let join = String.concat "\n" in
|
||||
let pfx = join (List.rev pfx) ^ "\n\n"
|
||||
and rules = join (List.rev rules)
|
||||
and sfx = join (line :: lines)
|
||||
in (lofs, pfx, rules, sfx)
|
||||
else go `Skip lofs pfx rules lines
|
||||
in
|
||||
|
||||
let lines = String.split_on_char '\n' src in
|
||||
go `Prefix 0 [] [] lines
|
||||
|
||||
let () =
|
||||
let usage_msg =
|
||||
"mgen [--fuzz] [--verbose <N>] <file>" in
|
||||
|
||||
let fuzz_arg = ref false in
|
||||
let verbose_arg = ref 0 in
|
||||
let input_paths = ref [] in
|
||||
|
||||
let anon_fun filename =
|
||||
input_paths := filename :: !input_paths in
|
||||
|
||||
let speclist =
|
||||
[ ( "--fuzz", Arg.Set fuzz_arg
|
||||
, " Fuzz tables and matchers" )
|
||||
; ( "--verbose", Arg.Set_int verbose_arg
|
||||
, "<N> Set verbosity level" )
|
||||
; ( "--", Arg.Rest_all (List.iter anon_fun)
|
||||
, " Stop argument parsing" ) ]
|
||||
in
|
||||
Arg.parse speclist anon_fun usage_msg;
|
||||
|
||||
let input_paths = !input_paths in
|
||||
let verbose = !verbose_arg in
|
||||
let fuzz = !fuzz_arg in
|
||||
let input_path, input =
|
||||
match input_paths with
|
||||
| ["-"] -> ("-", read_all stdin)
|
||||
| [path] -> (path, read_all (open_in path))
|
||||
| _ ->
|
||||
Printf.eprintf
|
||||
"%s: single input file expected\n"
|
||||
Sys.argv.(0);
|
||||
Arg.usage speclist usage_msg; exit 1
|
||||
in
|
||||
let mgen = mgen ~verbose ~fuzz in
|
||||
|
||||
if Str.last_chars input_path 2 <> ".c"
|
||||
then mgen input_path 0 input stdout
|
||||
else
|
||||
let tmp_path = input_path ^ ".tmp" in
|
||||
Fun.protect
|
||||
~finally:(fun () ->
|
||||
try Sys.remove tmp_path with _ -> ())
|
||||
(fun () ->
|
||||
let lofs, pfx, rules, sfx = split_c input in
|
||||
let oc = open_out tmp_path in
|
||||
output_string oc pfx;
|
||||
mgen input_path lofs rules oc;
|
||||
output_string oc sfx;
|
||||
close_out oc;
|
||||
Sys.rename tmp_path input_path;
|
||||
());
|
||||
|
||||
()
|
651
tools/mgen/match.ml
Normal file
651
tools/mgen/match.ml
Normal file
|
@ -0,0 +1,651 @@
|
|||
type cls = Kw | Kl | Ks | Kd
|
||||
type op_base =
|
||||
| Oadd
|
||||
| Osub
|
||||
| Omul
|
||||
| Oor
|
||||
| Oshl
|
||||
| Oshr
|
||||
type op = cls * op_base
|
||||
|
||||
let op_bases =
|
||||
[Oadd; Osub; Omul; Oor; Oshl; Oshr]
|
||||
|
||||
let commutative = function
|
||||
| (_, (Oadd | Omul | Oor)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
let associative = function
|
||||
| (_, (Oadd | Omul | Oor)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
type atomic_pattern =
|
||||
| Tmp
|
||||
| AnyCon
|
||||
| Con of int64
|
||||
(* Tmp < AnyCon < Con k *)
|
||||
|
||||
type pattern =
|
||||
| Bnr of op * pattern * pattern
|
||||
| Atm of atomic_pattern
|
||||
| Var of string * atomic_pattern
|
||||
|
||||
let is_atomic = function
|
||||
| (Atm _ | Var _) -> true
|
||||
| _ -> false
|
||||
|
||||
let show_op_base o =
|
||||
match o with
|
||||
| Oadd -> "add"
|
||||
| Osub -> "sub"
|
||||
| Omul -> "mul"
|
||||
| Oor -> "or"
|
||||
| Oshl -> "shl"
|
||||
| Oshr -> "shr"
|
||||
|
||||
let show_op (k, o) =
|
||||
show_op_base o ^
|
||||
(match k with
|
||||
| Kw -> "w"
|
||||
| Kl -> "l"
|
||||
| Ks -> "s"
|
||||
| Kd -> "d")
|
||||
|
||||
let rec show_pattern p =
|
||||
match p with
|
||||
| Atm Tmp -> "%"
|
||||
| Atm AnyCon -> "$"
|
||||
| Atm (Con n) -> Int64.to_string n
|
||||
| Var (v, p) ->
|
||||
show_pattern (Atm p) ^ "'" ^ v
|
||||
| Bnr (o, pl, pr) ->
|
||||
"(" ^ show_op o ^
|
||||
" " ^ show_pattern pl ^
|
||||
" " ^ show_pattern pr ^ ")"
|
||||
|
||||
let get_atomic p =
|
||||
match p with
|
||||
| (Atm a | Var (_, a)) -> Some a
|
||||
| _ -> None
|
||||
|
||||
let rec pattern_match p w =
|
||||
match p with
|
||||
| Var (_, p) ->
|
||||
pattern_match (Atm p) w
|
||||
| Atm Tmp ->
|
||||
begin match get_atomic w with
|
||||
| Some (Con _ | AnyCon) -> false
|
||||
| _ -> true
|
||||
end
|
||||
| Atm (Con _) -> w = p
|
||||
| Atm (AnyCon) ->
|
||||
not (pattern_match (Atm Tmp) w)
|
||||
| Bnr (o, pl, pr) ->
|
||||
begin match w with
|
||||
| Bnr (o', wl, wr) ->
|
||||
o' = o &&
|
||||
pattern_match pl wl &&
|
||||
pattern_match pr wr
|
||||
| _ -> false
|
||||
end
|
||||
|
||||
type +'a cursor = (* a position inside a pattern *)
|
||||
| Bnrl of op * 'a cursor * pattern
|
||||
| Bnrr of op * pattern * 'a cursor
|
||||
| Top of 'a
|
||||
|
||||
let rec fold_cursor c p =
|
||||
match c with
|
||||
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
|
||||
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
|
||||
| Top _ -> p
|
||||
|
||||
let peel p x =
|
||||
let once out (p, c) =
|
||||
match p with
|
||||
| Var (_, p) -> (Atm p, c) :: out
|
||||
| Atm _ -> (p, c) :: out
|
||||
| Bnr (o, pl, pr) ->
|
||||
(pl, Bnrl (o, c, pr)) ::
|
||||
(pr, Bnrr (o, pl, c)) :: out
|
||||
in
|
||||
let rec go l =
|
||||
let l' = List.fold_left once [] l in
|
||||
if List.length l' = List.length l
|
||||
then l'
|
||||
else go l'
|
||||
in go [(p, Top x)]
|
||||
|
||||
let fold_pairs l1 l2 ini f =
|
||||
let rec go acc = function
|
||||
| [] -> acc
|
||||
| a :: l1' ->
|
||||
go (List.fold_left
|
||||
(fun acc b -> f (a, b) acc)
|
||||
acc l2) l1'
|
||||
in go ini l1
|
||||
|
||||
let iter_pairs l f =
|
||||
fold_pairs l l () (fun x () -> f x)
|
||||
|
||||
let inverse l =
|
||||
List.map (fun (a, b) -> (b, a)) l
|
||||
|
||||
type 'a state =
|
||||
{ id: int
|
||||
; seen: pattern
|
||||
; point: ('a cursor) list }
|
||||
|
||||
let rec binops side {point; _} =
|
||||
List.filter_map (fun c ->
|
||||
match c, side with
|
||||
| Bnrl (o, c, r), `L -> Some ((o, c), r)
|
||||
| Bnrr (o, l, c), `R -> Some ((o, c), l)
|
||||
| _ -> None)
|
||||
point
|
||||
|
||||
let group_by_fst l =
|
||||
List.fast_sort (fun (a, _) (b, _) ->
|
||||
compare a b) l |>
|
||||
List.fold_left (fun (oo, l, res) (o', c) ->
|
||||
match oo with
|
||||
| None -> (Some o', [c], [])
|
||||
| Some o when o = o' -> (oo, c :: l, res)
|
||||
| Some o -> (Some o', [c], (o, l) :: res))
|
||||
(None, [], []) |>
|
||||
(function
|
||||
| (None, _, _) -> []
|
||||
| (Some o, l, res) -> (o, l) :: res)
|
||||
|
||||
let sort_uniq cmp l =
|
||||
List.fast_sort cmp l |>
|
||||
List.fold_left (fun (eo, l) e' ->
|
||||
match eo with
|
||||
| None -> (Some e', l)
|
||||
| Some e when cmp e e' = 0 -> (eo, l)
|
||||
| Some e -> (Some e', e :: l))
|
||||
(None, []) |>
|
||||
(function
|
||||
| (None, _) -> []
|
||||
| (Some e, l) -> List.rev (e :: l))
|
||||
|
||||
let setify l =
|
||||
sort_uniq compare l
|
||||
|
||||
let normalize (point: ('a cursor) list) =
|
||||
setify point
|
||||
|
||||
let next_binary tmp s1 s2 =
|
||||
let pm w (_, p) = pattern_match p w in
|
||||
let o1 = binops `L s1 |>
|
||||
List.filter (pm s2.seen) |>
|
||||
List.map fst in
|
||||
let o2 = binops `R s2 |>
|
||||
List.filter (pm s1.seen) |>
|
||||
List.map fst in
|
||||
List.map (fun (o, l) ->
|
||||
o,
|
||||
{ id = -1
|
||||
; seen = Bnr (o, s1.seen, s2.seen)
|
||||
; point = normalize (l @ tmp) })
|
||||
(group_by_fst (o1 @ o2))
|
||||
|
||||
type p = string
|
||||
|
||||
module StateSet : sig
|
||||
type t
|
||||
val create: unit -> t
|
||||
val add: t -> p state ->
|
||||
[> `Added | `Found ] * p state
|
||||
val iter: t -> (p state -> unit) -> unit
|
||||
val elems: t -> (p state) list
|
||||
end = struct
|
||||
open Hashtbl.Make(struct
|
||||
type t = p state
|
||||
let equal s1 s2 = s1.point = s2.point
|
||||
let hash s = Hashtbl.hash s.point
|
||||
end)
|
||||
type nonrec t =
|
||||
{ h: int t
|
||||
; mutable next_id: int }
|
||||
let create () =
|
||||
{ h = create 500; next_id = 0 }
|
||||
let add set s =
|
||||
assert (s.point = normalize s.point);
|
||||
try
|
||||
let id = find set.h s in
|
||||
`Found, {s with id}
|
||||
with Not_found -> begin
|
||||
let id = set.next_id in
|
||||
set.next_id <- id + 1;
|
||||
add set.h s id;
|
||||
`Added, {s with id}
|
||||
end
|
||||
let iter set f =
|
||||
let f s id = f {s with id} in
|
||||
iter f set.h
|
||||
let elems set =
|
||||
let res = ref [] in
|
||||
iter set (fun s -> res := s :: !res);
|
||||
!res
|
||||
end
|
||||
|
||||
type table_key =
|
||||
| K of op * p state * p state
|
||||
|
||||
module StateMap = struct
|
||||
include Map.Make(struct
|
||||
type t = table_key
|
||||
let compare ka kb =
|
||||
match ka, kb with
|
||||
| K (o, sl, sr), K (o', sl', sr') ->
|
||||
compare (o, sl.id, sr.id)
|
||||
(o', sl'.id, sr'.id)
|
||||
end)
|
||||
let invert n sm =
|
||||
let rmap = Array.make n [] in
|
||||
iter (fun k {id; _} ->
|
||||
match k with
|
||||
| K (o, sl, sr) ->
|
||||
rmap.(id) <-
|
||||
(o, (sl.id, sr.id)) :: rmap.(id)
|
||||
) sm;
|
||||
Array.map group_by_fst rmap
|
||||
let by_ops sm =
|
||||
fold (fun tk s ops ->
|
||||
match tk with
|
||||
| K (op, l, r) ->
|
||||
(op, ((l.id, r.id), s.id)) :: ops)
|
||||
sm [] |> group_by_fst
|
||||
end
|
||||
|
||||
type rule =
|
||||
{ name: string
|
||||
; vars: string list
|
||||
; pattern: pattern }
|
||||
|
||||
let generate_table rl =
|
||||
let states = StateSet.create () in
|
||||
let rl =
|
||||
(* these atomic patterns must occur in
|
||||
* rules so that we are able to number
|
||||
* all possible refs *)
|
||||
[ { name = "$"; vars = []
|
||||
; pattern = Atm AnyCon }
|
||||
; { name = "%"; vars = []
|
||||
; pattern = Atm Tmp } ] @ rl
|
||||
in
|
||||
(* initialize states *)
|
||||
let ground =
|
||||
List.concat_map
|
||||
(fun r -> peel r.pattern r.name) rl |>
|
||||
group_by_fst
|
||||
in
|
||||
let tmp = List.assoc (Atm Tmp) ground in
|
||||
let con = List.assoc (Atm AnyCon) ground in
|
||||
let atoms = ref [] in
|
||||
let () =
|
||||
List.iter (fun (seen, l) ->
|
||||
let point =
|
||||
if pattern_match (Atm Tmp) seen
|
||||
then normalize (tmp @ l)
|
||||
else normalize (con @ l)
|
||||
in
|
||||
let s = {id = -1; seen; point} in
|
||||
let _, s = StateSet.add states s in
|
||||
match get_atomic seen with
|
||||
| Some atm -> atoms := (atm, s) :: !atoms
|
||||
| None -> ()
|
||||
) ground
|
||||
in
|
||||
(* setup loop state *)
|
||||
let map = ref StateMap.empty in
|
||||
let map_add k s' =
|
||||
map := StateMap.add k s' !map
|
||||
in
|
||||
let flag = ref `Added in
|
||||
let flagmerge = function
|
||||
| `Added -> flag := `Added
|
||||
| _ -> ()
|
||||
in
|
||||
(* iterate until fixpoint *)
|
||||
while !flag = `Added do
|
||||
flag := `Stop;
|
||||
let statel = StateSet.elems states in
|
||||
iter_pairs statel (fun (sl, sr) ->
|
||||
next_binary tmp sl sr |>
|
||||
List.iter (fun (o, s') ->
|
||||
let flag', s' =
|
||||
StateSet.add states s' in
|
||||
flagmerge flag';
|
||||
map_add (K (o, sl, sr)) s';
|
||||
));
|
||||
done;
|
||||
let states =
|
||||
StateSet.elems states |>
|
||||
List.sort (fun s s' -> compare s.id s'.id) |>
|
||||
Array.of_list
|
||||
in
|
||||
(states, !atoms, !map)
|
||||
|
||||
let intersperse x l =
|
||||
let rec go left right out =
|
||||
let out =
|
||||
(List.rev left @ [x] @ right) ::
|
||||
out in
|
||||
match right with
|
||||
| x :: right' ->
|
||||
go (x :: left) right' out
|
||||
| [] -> out
|
||||
in go [] l []
|
||||
|
||||
let rec permute = function
|
||||
| [] -> [[]]
|
||||
| x :: l ->
|
||||
List.concat (List.map
|
||||
(intersperse x) (permute l))
|
||||
|
||||
(* build all binary trees with ordered
|
||||
* leaves l *)
|
||||
let rec bins build l =
|
||||
let rec go l r out =
|
||||
match r with
|
||||
| [] -> out
|
||||
| x :: r' ->
|
||||
go (l @ [x]) r'
|
||||
(fold_pairs
|
||||
(bins build l)
|
||||
(bins build r)
|
||||
out (fun (l, r) out ->
|
||||
build l r :: out))
|
||||
in
|
||||
match l with
|
||||
| [] -> []
|
||||
| [x] -> [x]
|
||||
| x :: l -> go [x] l []
|
||||
|
||||
let products l ini f =
|
||||
let rec go acc la = function
|
||||
| [] -> f (List.rev la) acc
|
||||
| xs :: l ->
|
||||
List.fold_left (fun acc x ->
|
||||
go acc (x :: la) l)
|
||||
acc xs
|
||||
in go ini [] l
|
||||
|
||||
(* combinatorial nuke... *)
|
||||
let rec ac_equiv =
|
||||
let rec alevel o = function
|
||||
| Bnr (o', l, r) when o' = o ->
|
||||
alevel o l @ alevel o r
|
||||
| x -> [x]
|
||||
in function
|
||||
| Bnr (o, _, _) as p
|
||||
when associative o ->
|
||||
products
|
||||
(List.map ac_equiv (alevel o p)) []
|
||||
(fun choice out ->
|
||||
List.concat_map
|
||||
(bins (fun l r -> Bnr (o, l, r)))
|
||||
(if commutative o
|
||||
then permute choice
|
||||
else [choice]) @ out)
|
||||
| Bnr (o, l, r)
|
||||
when commutative o ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) ::
|
||||
Bnr (o, r, l) :: out)
|
||||
| Bnr (o, l, r) ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) :: out)
|
||||
| x -> [x]
|
||||
|
||||
module Action: sig
|
||||
type node =
|
||||
| Switch of (int * t) list
|
||||
| Push of bool * t
|
||||
| Pop of t
|
||||
| Set of string * t
|
||||
| Stop
|
||||
and t = private
|
||||
{ id: int; node: node }
|
||||
val equal: t -> t -> bool
|
||||
val size: t -> int
|
||||
val stop: t
|
||||
val mk_push: sym:bool -> t -> t
|
||||
val mk_pop: t -> t
|
||||
val mk_set: string -> t -> t
|
||||
val mk_switch: int list -> (int -> t) -> t
|
||||
val pp: Format.formatter -> t -> unit
|
||||
end = struct
|
||||
type node =
|
||||
| Switch of (int * t) list
|
||||
| Push of bool * t
|
||||
| Pop of t
|
||||
| Set of string * t
|
||||
| Stop
|
||||
and t =
|
||||
{ id: int; node: node }
|
||||
|
||||
let equal a a' = a.id = a'.id
|
||||
let size a =
|
||||
let seen = Hashtbl.create 10 in
|
||||
let rec node_size = function
|
||||
| Switch l ->
|
||||
List.fold_left
|
||||
(fun n (_, a) -> n + size a) 0 l
|
||||
| (Push (_, a) | Pop a | Set (_, a)) ->
|
||||
size a
|
||||
| Stop -> 0
|
||||
and size {id; node} =
|
||||
if Hashtbl.mem seen id
|
||||
then 0
|
||||
else begin
|
||||
Hashtbl.add seen id ();
|
||||
1 + node_size node
|
||||
end
|
||||
in
|
||||
size a
|
||||
|
||||
let mk =
|
||||
let hcons = Hashtbl.create 100 in
|
||||
let fresh = ref 0 in
|
||||
fun node ->
|
||||
let id =
|
||||
try Hashtbl.find hcons node
|
||||
with Not_found ->
|
||||
let id = !fresh in
|
||||
Hashtbl.add hcons node id;
|
||||
fresh := id + 1;
|
||||
id
|
||||
in
|
||||
{id; node}
|
||||
let stop = mk Stop
|
||||
let mk_push ~sym a = mk (Push (sym, a))
|
||||
let mk_pop a =
|
||||
match a.node with
|
||||
| Stop -> a
|
||||
| _ -> mk (Pop a)
|
||||
let mk_set v a = mk (Set (v, a))
|
||||
let mk_switch ids f =
|
||||
match List.map f ids with
|
||||
| [] -> failwith "empty switch";
|
||||
| c :: cs as cases ->
|
||||
if List.for_all (equal c) cs then c
|
||||
else
|
||||
let cases = List.combine ids cases in
|
||||
mk (Switch cases)
|
||||
|
||||
open Format
|
||||
let rec pp_node fmt = function
|
||||
| Switch l ->
|
||||
fprintf fmt "@[<v>@[<v2>switch{";
|
||||
let pp_case (c, a) =
|
||||
let pp_sep fmt () = fprintf fmt "," in
|
||||
fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
|
||||
(pp_print_list ~pp_sep pp_print_int)
|
||||
c pp a
|
||||
in
|
||||
inverse l |> group_by_fst |> inverse |>
|
||||
List.iter pp_case;
|
||||
fprintf fmt "@]@,}@]"
|
||||
| Push (true, a) -> fprintf fmt "pushsym@ %a" pp a
|
||||
| Push (false, a) -> fprintf fmt "push@ %a" pp a
|
||||
| Pop a -> fprintf fmt "pop@ %a" pp a
|
||||
| Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a
|
||||
| Stop -> fprintf fmt "•"
|
||||
and pp fmt a = pp_node fmt a.node
|
||||
end
|
||||
|
||||
(* a state is commutative if (a op b) enters
|
||||
* it iff (b op a) enters it as well *)
|
||||
let symmetric rmap id =
|
||||
List.for_all (fun (_, l) ->
|
||||
let l1, l2 =
|
||||
List.filter (fun (a, b) -> a <> b) l |>
|
||||
List.partition (fun (a, b) -> a < b)
|
||||
in
|
||||
setify l1 = setify (inverse l2))
|
||||
rmap.(id)
|
||||
|
||||
(* left-to-right matching of a set of patterns;
|
||||
* may raise if there is no lr matcher for the
|
||||
* input rule *)
|
||||
let lr_matcher statemap states rules name =
|
||||
let rmap =
|
||||
let nstates = Array.length states in
|
||||
StateMap.invert nstates statemap
|
||||
in
|
||||
let exception Stuck in
|
||||
(* the list of ids represents a class of terms
|
||||
* whose root ends up being labelled with one
|
||||
* such id; the gen function generates a matcher
|
||||
* that will, given any such term, assign values
|
||||
* for the Var nodes of one pattern in pats *)
|
||||
let rec gen
|
||||
: 'a. int list -> (pattern * 'a) list
|
||||
-> (int -> (pattern * 'a) list -> Action.t)
|
||||
-> Action.t
|
||||
= fun ids pats k ->
|
||||
Action.mk_switch (setify ids) @@ fun id_top ->
|
||||
let sym = symmetric rmap id_top in
|
||||
let id_ops =
|
||||
if sym then
|
||||
let ordered (a, b) = a <= b in
|
||||
List.map (fun (o, l) ->
|
||||
(o, List.filter ordered l))
|
||||
rmap.(id_top)
|
||||
else rmap.(id_top)
|
||||
in
|
||||
(* consider only the patterns that are
|
||||
* compatible with the current id *)
|
||||
let atm_pats, bin_pats =
|
||||
List.filter (function
|
||||
| Bnr (o, _, _), _ ->
|
||||
List.exists
|
||||
(fun (o', _) -> o' = o)
|
||||
id_ops
|
||||
| _ -> true) pats |>
|
||||
List.partition
|
||||
(fun (pat, _) -> is_atomic pat)
|
||||
in
|
||||
try
|
||||
if bin_pats = [] then raise Stuck;
|
||||
let pats_l =
|
||||
List.map (function
|
||||
| (Bnr (o, l, r), x) ->
|
||||
(l, (o, x, r))
|
||||
| _ -> assert false)
|
||||
bin_pats
|
||||
and pats_r =
|
||||
List.map (fun (l, (o, x, r)) ->
|
||||
(r, (o, l, x)))
|
||||
and patstop =
|
||||
List.map (fun (r, (o, l, x)) ->
|
||||
(Bnr (o, l, r), x))
|
||||
in
|
||||
let id_pairs = List.concat_map snd id_ops in
|
||||
let ids_l = List.map fst id_pairs
|
||||
and ids_r id_left =
|
||||
List.filter_map (fun (l, r) ->
|
||||
if l = id_left then Some r else None)
|
||||
id_pairs
|
||||
in
|
||||
(* match the left arm *)
|
||||
Action.mk_push ~sym
|
||||
(gen ids_l pats_l
|
||||
@@ fun lid pats ->
|
||||
(* then the right arm, considering
|
||||
* only the remaining possible
|
||||
* patterns and knowing that the
|
||||
* left arm was numbered 'lid' *)
|
||||
Action.mk_pop
|
||||
(gen (ids_r lid) (pats_r pats)
|
||||
@@ fun _rid pats ->
|
||||
(* continue with the parent *)
|
||||
k id_top (patstop pats)))
|
||||
with Stuck ->
|
||||
let atm_pats =
|
||||
let seen = states.(id_top).seen in
|
||||
List.filter (fun (pat, _) ->
|
||||
pattern_match pat seen) atm_pats
|
||||
in
|
||||
if atm_pats = [] then raise Stuck else
|
||||
let vars =
|
||||
List.filter_map (function
|
||||
| (Var (v, _), _) -> Some v
|
||||
| _ -> None) atm_pats |> setify
|
||||
in
|
||||
match vars with
|
||||
| [] -> k id_top atm_pats
|
||||
| [v] -> Action.mk_set v (k id_top atm_pats)
|
||||
| _ -> failwith "ambiguous var match"
|
||||
in
|
||||
(* generate a matcher for the rule *)
|
||||
let ids_top =
|
||||
Array.to_list states |>
|
||||
List.filter_map (fun {id; point = p; _} ->
|
||||
if List.exists ((=) (Top name)) p then
|
||||
Some id
|
||||
else None)
|
||||
in
|
||||
let rec filter_dups pats =
|
||||
match pats with
|
||||
| p :: pats ->
|
||||
if List.exists (pattern_match p) pats
|
||||
then filter_dups pats
|
||||
else p :: filter_dups pats
|
||||
| [] -> []
|
||||
in
|
||||
let pats_top =
|
||||
List.filter_map (fun r ->
|
||||
if r.name = name then
|
||||
Some r.pattern
|
||||
else None) rules |>
|
||||
filter_dups |>
|
||||
List.map (fun p -> (p, ()))
|
||||
in
|
||||
gen ids_top pats_top (fun _ pats ->
|
||||
assert (pats <> []);
|
||||
Action.stop)
|
||||
|
||||
type numberer =
|
||||
{ atoms: (atomic_pattern * p state) list
|
||||
; statemap: p state StateMap.t
|
||||
; states: p state array
|
||||
; mutable ops: op list
|
||||
(* memoizes the list of possible operations
|
||||
* according to the statemap *) }
|
||||
|
||||
let make_numberer sa am sm =
|
||||
{ atoms = am
|
||||
; states = sa
|
||||
; statemap = sm
|
||||
; ops = [] }
|
||||
|
||||
let atom_state n atm =
|
||||
List.assoc atm n.atoms
|
292
tools/mgen/sexp.ml
Normal file
292
tools/mgen/sexp.ml
Normal file
|
@ -0,0 +1,292 @@
|
|||
type pstate =
|
||||
{ data: string
|
||||
; line: int
|
||||
; coln: int
|
||||
; indx: int }
|
||||
|
||||
type perror =
|
||||
{ error: string
|
||||
; ps: pstate }
|
||||
|
||||
exception ParseError of perror
|
||||
|
||||
type 'a parser =
|
||||
{ fn: 'r. pstate -> ('a -> pstate -> 'r) -> 'r }
|
||||
|
||||
let update_pos ps beg fin =
|
||||
let l, c = (ref ps.line, ref ps.coln) in
|
||||
for i = beg to fin - 1 do
|
||||
if ps.data.[i] = '\n' then
|
||||
(incr l; c := 0)
|
||||
else
|
||||
incr c
|
||||
done;
|
||||
{ ps with line = !l; coln = !c }
|
||||
|
||||
let pret (type a) (x: a): a parser =
|
||||
let fn ps k = k x ps in { fn }
|
||||
|
||||
let pfail error: 'a parser =
|
||||
let fn ps _ = raise (ParseError {error; ps})
|
||||
in { fn }
|
||||
|
||||
let por: 'a parser -> 'a parser -> 'a parser =
|
||||
fun p1 p2 ->
|
||||
let fn ps k =
|
||||
try p1.fn ps k with ParseError e1 ->
|
||||
try p2.fn ps k with ParseError e2 ->
|
||||
if e1.ps.indx > e2.ps.indx then
|
||||
raise (ParseError e1)
|
||||
else
|
||||
raise (ParseError e2)
|
||||
in { fn }
|
||||
|
||||
let pbind: 'a parser -> ('a -> 'b parser) -> 'b parser =
|
||||
fun p1 p2 ->
|
||||
let fn ps k =
|
||||
p1.fn ps (fun x ps -> (p2 x).fn ps k)
|
||||
in { fn }
|
||||
|
||||
(* handy for recursive rules *)
|
||||
let papp p x = pbind (pret x) p
|
||||
|
||||
let psnd: 'a parser -> 'b parser -> 'b parser =
|
||||
fun p1 p2 -> pbind p1 (fun _x -> p2)
|
||||
|
||||
let pfst: 'a parser -> 'b parser -> 'a parser =
|
||||
fun p1 p2 -> pbind p1 (fun x -> psnd p2 (pret x))
|
||||
|
||||
module Infix = struct
|
||||
let ( let* ) = pbind
|
||||
let ( ||| ) = por
|
||||
let ( |<< ) = pfst
|
||||
let ( |>> ) = psnd
|
||||
end
|
||||
|
||||
open Infix
|
||||
|
||||
let pre: ?what:string -> string -> string parser =
|
||||
fun ?what re ->
|
||||
let what =
|
||||
match what with
|
||||
| None -> Printf.sprintf "%S" re
|
||||
| Some what -> what
|
||||
and re = Str.regexp re in
|
||||
let fn ps k =
|
||||
if not (Str.string_match re ps.data ps.indx) then
|
||||
(let error =
|
||||
Printf.sprintf "expected to match %s" what in
|
||||
raise (ParseError {error; ps}));
|
||||
let ps =
|
||||
let indx = Str.match_end () in
|
||||
{ (update_pos ps ps.indx indx) with indx }
|
||||
in
|
||||
k (Str.matched_string ps.data) ps
|
||||
in { fn }
|
||||
|
||||
let peoi: unit parser =
|
||||
let fn ps k =
|
||||
if ps.indx <> String.length ps.data then
|
||||
raise (ParseError
|
||||
{ error = "expected end of input"; ps });
|
||||
k () ps
|
||||
in { fn }
|
||||
|
||||
let pws = pre "[ \r\n\t*]*"
|
||||
let pws1 = pre "[ \r\n\t*]+"
|
||||
|
||||
let pthen p1 p2 =
|
||||
let* x1 = p1 in
|
||||
let* x2 = p2 in
|
||||
pret (x1, x2)
|
||||
|
||||
let rec plist_tail: 'a parser -> ('a list) parser =
|
||||
fun pitem ->
|
||||
(pws |>> pre ")" |>> pret []) |||
|
||||
(let* itm = pitem in
|
||||
let* itms = plist_tail pitem in
|
||||
pret (itm :: itms))
|
||||
|
||||
let plist pitem =
|
||||
pws |>> pre ~what:"a list" "("
|
||||
|>> plist_tail pitem
|
||||
|
||||
let plist1p p1 pitem =
|
||||
pws |>> pre ~what:"a list" "("
|
||||
|>> pthen p1 (plist_tail pitem)
|
||||
|
||||
let ppair p1 p2 =
|
||||
pws |>> pre ~what:"a pair" "("
|
||||
|>> pthen p1 p2 |<< pws |<< pre ")"
|
||||
|
||||
let run_parser p s =
|
||||
let ps =
|
||||
{data = s; line = 1; coln = 0; indx = 0} in
|
||||
try `Ok (p.fn ps (fun res _ps -> res))
|
||||
with ParseError e ->
|
||||
let rec bol i =
|
||||
if i = 0 then i else
|
||||
if i < String.length s && s.[i] = '\n'
|
||||
then i+1 (* XXX BUG *)
|
||||
else bol (i-1)
|
||||
in
|
||||
let rec eol i =
|
||||
if i = String.length s then i else
|
||||
if s.[i] = '\n' then i else
|
||||
eol (i+1)
|
||||
in
|
||||
let bol = bol e.ps.indx in
|
||||
let eol = eol e.ps.indx in
|
||||
(*
|
||||
Printf.eprintf "bol:%d eol:%d indx:%d len:%d\n"
|
||||
bol eol e.ps.indx (String.length s); (* XXX debug *)
|
||||
*)
|
||||
let lines =
|
||||
String.split_on_char '\n'
|
||||
(String.sub s bol (eol - bol))
|
||||
in
|
||||
let nl = List.length lines in
|
||||
let caret = ref (e.ps.indx - bol) in
|
||||
let msg = ref [] in
|
||||
let pfx = " > " in
|
||||
lines |> List.iteri (fun ln l ->
|
||||
if ln <> nl - 1 || l <> "" then begin
|
||||
let ll = String.length l + 1 in
|
||||
msg := (pfx ^ l ^ "\n") :: !msg;
|
||||
if !caret <= ll then begin
|
||||
let pad = String.make !caret ' ' in
|
||||
msg := (pfx ^ pad ^ "^\n") :: !msg;
|
||||
end;
|
||||
caret := !caret - ll;
|
||||
end;
|
||||
);
|
||||
`Error
|
||||
( e.ps, e.error
|
||||
, String.concat "" (List.rev !msg) )
|
||||
|
||||
(* ---------------------------------------- *)
|
||||
(* pattern parsing *)
|
||||
(* ---------------------------------------- *)
|
||||
(* Example syntax:
|
||||
|
||||
(with-vars (a b c d)
|
||||
(patterns
|
||||
(ob (add (tmp a) (con d)))
|
||||
(bsm (add (tmp b) (mul (tmp m) (con 2 4 8)))) ))
|
||||
*)
|
||||
open Match
|
||||
|
||||
let pint64 =
|
||||
let* s = pre "[-]?[0-9_]+" in
|
||||
pret (Int64.of_string s)
|
||||
|
||||
let pid =
|
||||
pre ~what:"an identifer"
|
||||
"[a-zA-Z][a-zA-Z0-9_]*"
|
||||
|
||||
let pop_base =
|
||||
let sob, obs = show_op_base, op_bases in
|
||||
let* s = pre ~what:"an operator"
|
||||
(String.concat "\\|" (List.map sob obs))
|
||||
in pret (List.find (fun o -> s = sob o) obs)
|
||||
|
||||
let pop = let* ob = pop_base in pret (Kl, ob)
|
||||
|
||||
let rec ppat vs =
|
||||
let pcons_tail =
|
||||
let* cs = plist_tail (pws1 |>> pint64) in
|
||||
match cs with
|
||||
| [] -> pret [AnyCon]
|
||||
| _ -> pret (List.map (fun c -> Con c) cs)
|
||||
in
|
||||
let pvar =
|
||||
let* id = pid in
|
||||
if not (List.mem id vs) then
|
||||
pfail ("unbound variable: " ^ id)
|
||||
else
|
||||
pret id
|
||||
in
|
||||
pws |>> (
|
||||
( let* c = pint64 in pret [Atm (Con c)] )
|
||||
|||
|
||||
( pre "(con)" |>> pret [Atm AnyCon] ) |||
|
||||
( let* cs = pre "(con" |>> pcons_tail in
|
||||
pret (List.map (fun c -> Atm c) cs) ) |||
|
||||
( let* v = pre "(con" |>> pws1 |>> pvar in
|
||||
let* cs = pcons_tail in
|
||||
pret (List.map (fun c -> Var (v, c)) cs) )
|
||||
|||
|
||||
( pre "(tmp)" |>> pret [Atm Tmp] ) |||
|
||||
( let* v = pre "(tmp" |>> pws1 |>> pvar in
|
||||
pws |>> pre ")" |>> pret [Var (v, Tmp)] )
|
||||
|||
|
||||
( let* (op, rands) =
|
||||
plist1p (pws |>> pop) (papp ppat vs) in
|
||||
let nrands = List.length rands in
|
||||
if nrands < 2 then
|
||||
pfail ( "binary op requires at least"
|
||||
^ " two arguments" )
|
||||
else
|
||||
let mk x y = Bnr (op, x, y) in
|
||||
pret
|
||||
(products rands []
|
||||
(fun rands pats ->
|
||||
(* construct a left-heavy tree *)
|
||||
let r0 = List.hd rands in
|
||||
let rs = List.tl rands in
|
||||
List.fold_left mk r0 rs :: pats)) )
|
||||
)
|
||||
|
||||
let pwith_vars ?(vs = []) p =
|
||||
( let* vs =
|
||||
pws |>> pre "(with-vars" |>> pws |>>
|
||||
plist (pws |>> pid)
|
||||
in pws |>> p vs |<< pws |<< pre ")" )
|
||||
||| p vs
|
||||
|
||||
let ppats =
|
||||
pwith_vars @@ fun vs ->
|
||||
pre "(patterns" |>> plist_tail
|
||||
(pwith_vars ~vs @@ fun vs ->
|
||||
let* n, ps = ppair pid (ppat vs) in
|
||||
pret (n, vs, ps))
|
||||
|
||||
(* ---------------------------------------- *)
|
||||
(* tests *)
|
||||
(* ---------------------------------------- *)
|
||||
|
||||
let () =
|
||||
if false then
|
||||
let show_patterns ps =
|
||||
"[" ^ String.concat "; "
|
||||
(List.map show_pattern ps) ^ "]"
|
||||
in
|
||||
let pat s =
|
||||
Printf.printf "parse %s = " s;
|
||||
let vars =
|
||||
[ "foobar"; "a"; "b"; "d"
|
||||
; "m"; "s"; "x" ]
|
||||
in
|
||||
match run_parser (ppat vars) s with
|
||||
| `Ok p ->
|
||||
Printf.printf "%s\n" (show_patterns p)
|
||||
| `Error (_, e, _) ->
|
||||
Printf.printf "ERROR: %s\n" e
|
||||
in
|
||||
pat "42";
|
||||
pat "(tmp)";
|
||||
pat "(tmp foobar)";
|
||||
pat "(con)";
|
||||
pat "(con 1 2 3)";
|
||||
pat "(con x 1 2 3)";
|
||||
pat "(add 1 2)";
|
||||
pat "(add 1 2 3 4)";
|
||||
pat "(sub 1 2)";
|
||||
pat "(sub 1 2 3)";
|
||||
pat "(tmp unbound_var)";
|
||||
pat "(add 0)";
|
||||
pat "(add 1 (add 2 3))";
|
||||
pat "(add (tmp a) (con d))";
|
||||
pat "(add (tmp b) (mul (tmp m) (con s 2 4 8)))";
|
||||
pat "(add (con 1 2) (con 3 4))";
|
||||
()
|
134
tools/mgen/test.ml
Normal file
134
tools/mgen/test.ml
Normal file
|
@ -0,0 +1,134 @@
|
|||
open Match
|
||||
open Fuzz
|
||||
open Cgen
|
||||
|
||||
(* unit tests *)
|
||||
|
||||
let test_pattern_match =
|
||||
let pm = pattern_match
|
||||
and nm = fun x y -> not (pattern_match x y) in
|
||||
begin
|
||||
assert (nm (Atm Tmp) (Atm (Con 42L)));
|
||||
assert (pm (Atm AnyCon) (Atm (Con 42L)));
|
||||
assert (nm (Atm (Con 42L)) (Atm AnyCon));
|
||||
assert (nm (Atm (Con 42L)) (Atm Tmp));
|
||||
end
|
||||
|
||||
let test_peel =
|
||||
let o = Kw, Oadd in
|
||||
let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
|
||||
Atm (Con 42L)) in
|
||||
let l = peel p () in
|
||||
let () = assert (List.length l = 3) in
|
||||
let atomic_p (p, _) =
|
||||
match p with Atm _ -> true | _ -> false in
|
||||
let () = assert (List.for_all atomic_p l) in
|
||||
let l = List.map (fun (p, c) -> fold_cursor c p) l in
|
||||
let () = assert (List.for_all ((=) p) l) in
|
||||
()
|
||||
|
||||
let test_fold_pairs =
|
||||
let l = [1; 2; 3; 4; 5] in
|
||||
let p = fold_pairs l l [] (fun a b -> a :: b) in
|
||||
let () = assert (List.length p = 25) in
|
||||
let p = sort_uniq compare p in
|
||||
let () = assert (List.length p = 25) in
|
||||
()
|
||||
|
||||
(* test pattern & state *)
|
||||
|
||||
let print_sm oc =
|
||||
StateMap.iter (fun k s' ->
|
||||
match k with
|
||||
| K (o, sl, sr) ->
|
||||
let top =
|
||||
List.fold_left (fun top c ->
|
||||
match c with
|
||||
| Top r -> top ^ " " ^ r
|
||||
| _ -> top) "" s'.point
|
||||
in
|
||||
Printf.fprintf oc
|
||||
" (%s %d %d) -> %d%s\n"
|
||||
(show_op o)
|
||||
sl.id sr.id s'.id top)
|
||||
|
||||
let rules =
|
||||
let oa = Kl, Oadd in
|
||||
let om = Kl, Omul in
|
||||
let va = Var ("a", Tmp)
|
||||
and vb = Var ("b", Tmp)
|
||||
and vc = Var ("c", Tmp)
|
||||
and vs = Var ("s", Tmp) in
|
||||
let vars = ["a"; "b"; "c"; "s"] in
|
||||
let rule name pattern =
|
||||
List.map
|
||||
(fun pattern -> {name; vars; pattern})
|
||||
(ac_equiv pattern)
|
||||
in
|
||||
match `X64Addr with
|
||||
(* ------------------------------- *)
|
||||
| `X64Addr ->
|
||||
(* o + b *)
|
||||
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
||||
@ (* b + s * m *)
|
||||
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 2L), vs)))
|
||||
@
|
||||
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 4L), vs)))
|
||||
@
|
||||
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 8L), vs)))
|
||||
@ (* b + s *)
|
||||
rule "bs1" (Bnr (oa, vb, vs))
|
||||
@ (* o + s * m *)
|
||||
(* rule "osm" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp))) *) []
|
||||
@ (* o + b + s *)
|
||||
rule "obs1" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb), vs))
|
||||
@ (* o + b + s * m *)
|
||||
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
|
||||
Bnr (om, Var ("m", Con 2L), vs)))
|
||||
@
|
||||
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
|
||||
Bnr (om, Var ("m", Con 4L), vs)))
|
||||
@
|
||||
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
|
||||
Bnr (om, Var ("m", Con 8L), vs)))
|
||||
(* ------------------------------- *)
|
||||
| `Add3 ->
|
||||
[ { name = "add"
|
||||
; vars = []
|
||||
; pattern = Bnr (oa, va, Bnr (oa, vb, vc)) } ] @
|
||||
[ { name = "add"
|
||||
; vars = []
|
||||
; pattern = Bnr (oa, Bnr (oa, va, vb), vc) } ]
|
||||
|
||||
(*
|
||||
|
||||
let sa, am, sm = generate_table rules
|
||||
let () =
|
||||
Array.iteri (fun i s ->
|
||||
Format.printf "@[state %d: %s@]@."
|
||||
i (show_pattern s.seen))
|
||||
sa
|
||||
let () = print_sm stdout sm; flush stdout
|
||||
|
||||
let matcher = lr_matcher sm sa rules "obsm" (* XXX *)
|
||||
let () = Format.printf "@[<v>%a@]@." Action.pp matcher
|
||||
let () = Format.printf "@[matcher size: %d@]@." (Action.size matcher)
|
||||
|
||||
let numbr = make_numberer sa am sm
|
||||
|
||||
let () =
|
||||
let opts = { pfx = ""
|
||||
; static = true
|
||||
; oc = stdout } in
|
||||
emit_c opts numbr;
|
||||
emit_matchers opts
|
||||
[ ( ["b"; "o"; "s"; "m"]
|
||||
, "obsm"
|
||||
, matcher ) ]
|
||||
|
||||
(*
|
||||
let tp = fuzz_numberer rules numbr
|
||||
let () = test_matchers tp numbr rules
|
||||
*)
|
||||
|
||||
*)
|
Loading…
Add table
Reference in a new issue