mgen: match automatons and C generation

The algorithm to generate matchers
took a long time to be discovered
and refined to its present version.
The rest of mgen is mostly boring
engineering.

Extensive fuzzing ensures that the
two core components of mgen (tables
and matchers generation) are correct
on specific problem instances.
This commit is contained in:
Quentin Carbonneaux 2023-05-12 16:44:04 +02:00
parent 56e2263ca4
commit a609527752
11 changed files with 2144 additions and 444 deletions

View file

@ -1,347 +0,0 @@
type cls = Kw | Kl | Ks | Kd
type op_base =
| Oadd
| Osub
| Omul
type op = cls * op_base
let commutative = function
| (_, (Oadd | Omul)) -> true
| (_, _) -> false
let associative = function
| (_, (Oadd | Omul)) -> true
| (_, _) -> false
type atomic_pattern =
| Tmp
| AnyCon
| Con of int64
type pattern =
| Bnr of op * pattern * pattern
| Atm of atomic_pattern
| Var of string * atomic_pattern
let show_op (k, o) =
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul") ^
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
let rec show_pattern p =
match p with
| Var _ -> failwith "variable not allowed"
| Atm Tmp -> "%"
| Atm AnyCon -> "$"
| Atm (Con n) -> Int64.to_string n
| Bnr (o, pl, pr) ->
"(" ^ show_op o ^
" " ^ show_pattern pl ^
" " ^ show_pattern pr ^ ")"
let rec pattern_match p w =
match p with
| Var _ -> failwith "variable not allowed"
| Atm Tmp ->
begin match w with
| Atm (Con _ | AnyCon) -> false
| _ -> true
end
| Atm (Con _) -> w = p
| Atm (AnyCon) ->
not (pattern_match (Atm Tmp) w)
| Bnr (o, pl, pr) ->
begin match w with
| Bnr (o', wl, wr) ->
o' = o &&
pattern_match pl wl &&
pattern_match pr wr
| _ -> false
end
type 'a cursor = (* a position inside a pattern *)
| Bnrl of op * 'a cursor * pattern
| Bnrr of op * pattern * 'a cursor
| Top of 'a
let rec fold_cursor c p =
match c with
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
| Top _ -> p
let peel p x =
let once out (p, c) =
match p with
| Var _ -> failwith "variable not allowed"
| Atm _ -> (p, c) :: out
| Bnr (o, pl, pr) ->
(pl, Bnrl (o, c, pr)) ::
(pr, Bnrr (o, pl, c)) :: out
in
let rec go l =
let l' = List.fold_left once [] l in
if List.length l' = List.length l
then l
else go l'
in go [(p, Top x)]
let fold_pairs l1 l2 ini f =
let rec go acc = function
| [] -> acc
| a :: l1' ->
go (List.fold_left
(fun acc b -> f (a, b) acc)
acc l2) l1'
in go ini l1
let iter_pairs l f =
fold_pairs l l () (fun x () -> f x)
type 'a state =
{ id: int
; seen: pattern
; point: ('a cursor) list }
let rec binops side {point; _} =
List.filter_map (fun c ->
match c, side with
| Bnrl (o, c, r), `L -> Some ((o, c), r)
| Bnrr (o, l, c), `R -> Some ((o, c), l)
| _ -> None)
point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
compare a b) l |>
List.fold_left (fun (oo, l, res) (o', c) ->
match oo with
| None -> (Some o', [c], [])
| Some o when o = o' -> (oo, c :: l, res)
| Some o -> (Some o', [c], (o, l) :: res))
(None, [], []) |>
(function
| (None, _, _) -> []
| (Some o, l, res) -> (o, l) :: res)
let sort_uniq cmp l =
List.fast_sort cmp l |>
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
| Some e when cmp e e' = 0 -> (eo, l)
| Some e -> (Some e', e :: l))
(None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
let normalize (point: ('a cursor) list) =
sort_uniq compare point
let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
List.map fst in
let o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
List.map fst in
List.map (fun (o, l) ->
o,
{ id = 0
; seen = Bnr (o, s1.seen, s2.seen)
; point = normalize (l @ tmp)
}) (group_by_fst (o1 @ o2))
type p = string
module StateSet : sig
type t
val create: unit -> t
val add: t -> p state ->
[> `Added | `Found ] * p state
val iter: t -> (p state -> unit) -> unit
val elems: t -> (p state) list
end = struct
open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 1 }
let add set s =
assert (s.point = normalize s.point);
try
let id = find set.h s in
`Found, {s with id}
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
Printf.printf "adding: %d [%s]\n"
id (show_pattern s.seen);
add set.h s id;
`Added, {s with id}
end
let iter set f =
let f s id = f {s with id} in
iter f set.h
let elems set =
let res = ref [] in
iter set (fun s -> res := s :: !res);
!res
end
type table_key =
| K of op * p state * p state
module StateMap = Map.Make(struct
type t = table_key
let compare ka kb =
match ka, kb with
| K (o, sl, sr), K (o', sl', sr') ->
compare (o, sl.id, sr.id)
(o', sl'.id, sr'.id)
end)
type rule =
{ name: string
; pattern: pattern
}
let generate_table rl =
let states = StateSet.create () in
(* initialize states *)
let ground =
List.concat_map
(fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let find x d l =
try List.assoc x l with Not_found -> d in
let tmp = find (Atm Tmp) [] ground in
let con = find (Atm AnyCon) [] ground in
let () =
List.iter (fun (seen, l) ->
let point =
if pattern_match (Atm Tmp) seen
then normalize (tmp @ l)
else normalize (con @ l)
in
let s = {id = 0; seen; point} in
let flag, _ = StateSet.add states s in
assert (flag = `Added)
) ground
in
(* setup loop state *)
let map = ref StateMap.empty in
let map_add k s' =
map := StateMap.add k s' !map
in
let flag = ref `Added in
let flagmerge = function
| `Added -> flag := `Added
| _ -> ()
in
(* iterate until fixpoint *)
while !flag = `Added do
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in
flagmerge flag';
map_add (K (o, sl, sr)) s';
));
done;
(StateSet.elems states, !map)
let intersperse x l =
let rec go left right out =
let out =
(List.rev left @ [x] @ right) ::
out in
match right with
| x :: right' ->
go (x :: left) right' out
| [] -> out
in go [] l []
let rec permute = function
| [] -> [[]]
| x :: l ->
List.concat (List.map
(intersperse x) (permute l))
(* build all binary trees with ordered
* leaves l *)
let rec bins build l =
let rec go l r out =
match r with
| [] -> out
| x :: r' ->
go (l @ [x]) r'
(fold_pairs
(bins build l)
(bins build r)
out (fun (l, r) out ->
build l r :: out))
in
match l with
| [] -> []
| [x] -> [x]
| x :: l -> go [x] l []
let products l ini f =
let rec go acc la = function
| [] -> f (List.rev la) acc
| xs :: l ->
List.fold_left (fun acc x ->
go acc (x :: la) l)
acc xs
in go ini [] l
(* combinatorial nuke... *)
let rec ac_equiv =
let rec alevel o = function
| Bnr (o', l, r) when o' = o ->
alevel o l @ alevel o r
| x -> [x]
in function
| Bnr (o, _, _) as p
when associative o ->
products
(List.map ac_equiv (alevel o p)) []
(fun choice out ->
List.map
(bins (fun l r -> Bnr (o, l, r)))
(if commutative o
then permute choice
else [choice]) |>
List.concat |>
(fun l -> List.rev_append l out))
| Bnr (o, l, r)
when commutative o ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) ::
Bnr (o, r, l) :: out)
| Bnr (o, l, r) ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) :: out)
| x -> [x]

View file

@ -1,97 +0,0 @@
#use "match.ml"
let test_pattern_match =
let pm = pattern_match
and nm = fun x y -> not (pattern_match x y) in
begin
assert (nm (Atm Tmp) (Atm (Con 42L)));
assert (pm (Atm AnyCon) (Atm (Con 42L)));
assert (nm (Atm (Con 42L)) (Atm AnyCon));
assert (nm (Atm (Con 42L)) (Atm Tmp));
end
let test_peel =
let o = Kw, Oadd in
let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
Atm (Con 42L)) in
let l = peel p () in
let () = assert (List.length l = 3) in
let atomic_p (p, _) =
match p with Atm _ -> true | _ -> false in
let () = assert (List.for_all atomic_p l) in
let l = List.map (fun (p, c) -> fold_cursor c p) l in
let () = assert (List.for_all ((=) p) l) in
()
let test_fold_pairs =
let l = [1; 2; 3; 4; 5] in
let p = fold_pairs l l [] (fun a b -> a :: b) in
let () = assert (List.length p = 25) in
let p = sort_uniq compare p in
let () = assert (List.length p = 25) in
()
(* test pattern & state *)
let tp =
let o = Kw, Oadd in
Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
Atm (Con 0L))
let ts =
{ id = 0
; seen = Atm Tmp
; point =
List.map snd
(List.filter (fun (p, _) -> p = Atm Tmp)
(peel tp ()))
}
let print_sm =
StateMap.iter (fun k s' ->
match k with
| K (o, sl, sr) ->
let top =
List.fold_left (fun top c ->
match c with
| Top r -> top ^ " " ^ r
| _ -> top) "" s'.point
in
Printf.printf
"(%s %d %d) -> %d%s\n"
(show_op o)
sl.id sr.id s'.id top)
let rules =
let oa = Kl, Oadd in
let om = Kl, Omul in
match `X64Addr with
(* ------------------------------- *)
| `X64Addr ->
let rule name pattern =
List.mapi (fun i pattern ->
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
; pattern })
(ac_equiv pattern) in
(* o + b *)
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * i *)
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* o + s * i *)
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* b + o + s * i *)
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
(* ------------------------------- *)
| `Add3 ->
[ { name = "add"
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
[ { name = "add"
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
[ { name = "mul"
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
Atm Tmp),
Bnr (oa, Atm Tmp,
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
let sl, sm = generate_table rules
let s n = List.find (fun {id; _} -> id = n) sl
let () = print_sm sm

3
tools/mgen/.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
*.cm[iox]
*.o
mgen

1
tools/mgen/.ocp-indent Normal file
View file

@ -0,0 +1 @@
match_clause=4

16
tools/mgen/Makefile Normal file
View file

@ -0,0 +1,16 @@
BIN = mgen
SRC = \
match.ml \
fuzz.ml \
cgen.ml \
sexp.ml \
test.ml \
main.ml
$(BIN): $(SRC)
ocamlopt -o $(BIN) -g str.cmxa $(SRC)
clean:
rm -f *.cm? *.o $(BIN)
.PHONY: clean

420
tools/mgen/cgen.ml Normal file
View file

@ -0,0 +1,420 @@
open Match
type options =
{ pfx: string
; static: bool
; oc: out_channel }
type side = L | R
type id_pred =
| InBitSet of Int64.t
| Ge of int
| Eq of int
and id_test =
| Pred of (side * id_pred)
| And of id_test * id_test
type case_code =
| Table of ((int * int) * int) list
| IfThen of
{ test: id_test
; cif: case_code
; cthen: case_code option }
| Return of int
type case =
{ swap: bool
; code: case_code }
let cgen_case tmp nstates map =
let cgen_test ids =
match ids with
| [id] -> Eq id
| _ ->
let min_id =
List.fold_left min max_int ids in
if List.length ids = nstates - min_id
then Ge min_id
else begin
assert (nstates <= 64);
InBitSet
(List.fold_left (fun bs id ->
Int64.logor bs
(Int64.shift_left 1L id))
0L ids)
end
in
let symmetric =
let inverse ((l, r), x) = ((r, l), x) in
setify map = setify (List.map inverse map) in
let map =
let ordered ((l, r), _) = r <= l in
if symmetric then
List.filter ordered map
else map
in
let exception BailToTable in
try
let st =
match setify (List.map snd map) with
| [st] -> st
| _ -> raise BailToTable
in
(* the operation considered can only
* generate a single state *)
let pairs = List.map fst map in
let ls, rs = List.split pairs in
let ls = setify ls and rs = setify rs in
if List.length ls > 1 && List.length rs > 1 then
raise BailToTable;
{ swap = symmetric
; code =
let pl = Pred (L, cgen_test ls)
and pr = Pred (R, cgen_test rs) in
IfThen
{ test = And (pl, pr)
; cif = Return st
; cthen = Some (Return tmp) } }
with BailToTable ->
{ swap = symmetric
; code = Table map }
let show_op (_cls, op) =
"O" ^ show_op_base op
let indent oc i =
Printf.fprintf oc "%s" (String.sub "\t\t\t\t\t" 0 i)
let emit_swap oc i =
let pf m = Printf.fprintf oc m in
let pfi n m = indent oc n; pf m in
pfi i "if (l < r)\n";
pfi (i+1) "t = l, l = r, r = t;\n"
let gen_tables oc tmp pfx nstates (op, c) =
let i = 1 in
let pf m = Printf.fprintf oc m in
let pfi n m = indent oc n; pf m in
let ntables = ref 0 in
(* we must follow the order in which
* we visit code in emit_case, or
* else ntables goes out of sync *)
let base = pfx ^ show_op op in
let swap = c.swap in
let rec gen c =
match c with
| Table map ->
let name =
if !ntables = 0 then base else
base ^ string_of_int !ntables
in
assert (nstates <= 256);
if swap then
let n = nstates * (nstates + 1) / 2 in
pfi i "static uchar %stbl[%d] = {\n" name n
else
pfi i "static uchar %stbl[%d][%d] = {\n"
name nstates nstates;
for l = 0 to nstates - 1 do
pfi (i+1) "";
for r = 0 to nstates - 1 do
if not swap || r <= l then
begin
pf "%d"
(try List.assoc (l,r) map
with Not_found -> tmp);
pf ",";
end
done;
pf "\n";
done;
pfi i "};\n"
| IfThen {cif; cthen} ->
gen cif;
Option.iter gen cthen
| Return _ -> ()
in
gen c.code
let emit_case oc pfx no_swap (op, c) =
let fpf = Printf.fprintf in
let pf m = fpf oc m in
let pfi n m = indent oc n; pf m in
let rec side oc = function
| L -> fpf oc "l"
| R -> fpf oc "r"
in
let pred oc (s, pred) =
match pred with
| InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
| Eq id -> fpf oc "%a == %d" side s id
| Ge id -> fpf oc "%d <= %a" id side s
in
let base = pfx ^ show_op op in
let swap = c.swap in
let ntables = ref 0 in
let rec code i c =
match c with
| Return id -> pfi i "return %d;\n" id
| Table map ->
let name =
if !ntables = 0 then base else
base ^ string_of_int !ntables
in
incr ntables;
if swap then
pfi i "return %stbl[(l + l*l)/2 + r];\n" name
else pfi i "return %stbl[l][r];\n" name
| IfThen ({test = And (And (t1, t2), t3)} as r) ->
code i @@ IfThen
{r with test = And (t1, And (t2, t3))}
| IfThen {test = And (Pred p, t); cif; cthen} ->
pfi i "if (%a)\n" pred p;
code i (IfThen {test = t; cif; cthen})
| IfThen {test = Pred p; cif; cthen} ->
pfi i "if (%a) {\n" pred p;
code (i+1) cif;
pfi i "}\n";
Option.iter (code i) cthen
in
pfi 1 "case %s:\n" (show_op op);
if not no_swap && c.swap then
emit_swap oc 2;
code 2 c.code
let emit_list
?(limit=60) ?(cut_before_sep=false)
~col ~indent:i ~sep ~f oc l =
let sl = String.length sep in
let rstripped_sep, rssl =
if sep.[sl - 1] = ' ' then
String.sub sep 0 (sl - 1), sl - 1
else sep, sl
in
let lstripped_sep, lssl =
if sep.[0] = ' ' then
String.sub sep 1 (sl - 1), sl - 1
else sep, sl
in
let rec line col acc = function
| [] -> (List.rev acc, [])
| s :: l ->
let col = col + sl + String.length s in
let no_space =
if cut_before_sep || l = [] then
col > limit
else
col + rssl > limit
in
if no_space then
(List.rev acc, s :: l)
else
line col (s :: acc) l
in
let rec go col l =
if l = [] then () else
let ll, l = line col [] l in
Printf.fprintf oc "%s" (String.concat sep ll);
if l <> [] && cut_before_sep then begin
Printf.fprintf oc "\n";
indent oc i;
Printf.fprintf oc "%s" lstripped_sep;
go (8*i + lssl) l
end else if l <> [] then begin
Printf.fprintf oc "%s\n" rstripped_sep;
indent oc i;
go (8*i) l
end else ()
in
go col (List.map f l)
let emit_numberer opts n =
let pf m = Printf.fprintf opts.oc m in
let tmp = (atom_state n Tmp).id in
let con = (atom_state n AnyCon).id in
let nst = Array.length n.states in
let cases =
StateMap.by_ops n.statemap |>
List.map (fun (op, map) ->
(op, cgen_case tmp nst map))
in
let all_swap =
List.for_all (fun (_, c) -> c.swap) cases in
(* opn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%sopn(int op, int l, int r)\n" opts.pfx;
pf "{\n";
cases |> List.iter
(gen_tables opts.oc tmp opts.pfx nst);
if List.exists (fun (_, c) -> c.swap) cases then
pf "\tint t;\n\n";
if all_swap then emit_swap opts.oc 1;
pf "\tswitch (op) {\n";
cases |> List.iter
(emit_case opts.oc opts.pfx all_swap);
pf "\tdefault:\n";
pf "\t\treturn %d;\n" tmp;
pf "\t}\n";
pf "}\n\n";
(* refn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
pf "{\n";
let cons =
List.filter_map (function
| (Con c, s) -> Some (c, s.id)
| _ -> None)
n.atoms
in
if cons <> [] then
pf "\tint64_t n;\n\n";
pf "\tswitch (rtype(r)) {\n";
pf "\tcase RTmp:\n";
if tmp <> 0 then begin
assert
(List.exists (fun (_, s) ->
s.id = 0
) n.atoms &&
(* no temp should ever get state 0 *)
List.for_all (fun (a, s) ->
s.id <> 0 ||
match a with
| AnyCon | Con _ -> true
| _ -> false
) n.atoms);
pf "\t\tif (!tn[r.val].n)\n";
pf "\t\t\ttn[r.val].n = %d;\n" tmp;
end;
pf "\t\treturn tn[r.val].n;\n";
pf "\tcase RCon:\n";
if cons <> [] then begin
pf "\t\tif (con[r.val].type != CBits)\n";
pf "\t\t\treturn %d;\n" con;
pf "\t\tn = con[r.val].bits.i;\n";
cons |> inverse |> group_by_fst
|> List.iter (fun (id, cs) ->
pf "\t\tif (";
emit_list ~cut_before_sep:true
~col:20 ~indent:2 ~sep:" || "
~f:(fun c -> "n == " ^ Int64.to_string c)
opts.oc cs;
pf ")\n";
pf "\t\t\treturn %d;\n" id
);
end;
pf "\t\treturn %d;\n" con;
pf "\tdefault:\n";
pf "\t\treturn INT_MIN;\n";
pf "\t}\n";
pf "}\n\n";
(* match[]: patterns per state *)
if opts.static then pf "static ";
pf "bits %smatch[%d] = {\n" opts.pfx nst;
n.states |> Array.iteri (fun sn s ->
let tops =
List.filter_map (function
| Top ("$" | "%") -> None
| Top r -> Some ("BIT(P" ^ r ^ ")")
| _ -> None) s.point |> setify
in
if tops <> [] then
pf "\t[%d] = %s,\n"
sn (String.concat " | " tops);
);
pf "};\n\n"
let var_id vars f =
List.mapi (fun i x -> (x, i)) vars |>
List.assoc f
let compile_action vars act =
let pcs = Hashtbl.create 100 in
let rec gen pc (act: Action.t) =
try
[10 + Hashtbl.find pcs act.id]
with Not_found ->
let code =
match act.node with
| Action.Stop ->
[0]
| Action.Push (sym, k) ->
let c = if sym then 1 else 2 in
[c] @ gen (pc + 1) k
| Action.Set (v, {node = Action.Pop k; _})
| Action.Set (v, ({node = Action.Stop; _} as k)) ->
let v = var_id vars v in
[3; v] @ gen (pc + 2) k
| Action.Set _ ->
(* for now, only atomic patterns can be
* tied to a variable, so Set must be
* followed by either Pop or Stop *)
assert false
| Action.Pop k ->
[4] @ gen (pc + 1) k
| Action.Switch cases ->
let cases =
inverse cases |> group_by_fst |>
List.sort (fun (_, cs1) (_, cs2) ->
let n1 = List.length cs1
and n2 = List.length cs2 in
compare n2 n1)
in
(* the last case is the one with
* the max number of entries *)
let cases = List.rev (List.tl cases)
and last = fst (List.hd cases) in
let ncases =
List.fold_left (fun n (_, cs) ->
List.length cs + n)
0 cases
in
let body_off = 2 + 2 * ncases + 1 in
let pc, tbl, body =
List.fold_left
(fun (pc, tbl, body) (a, cs) ->
let ofs = body_off + List.length body in
let case = gen pc a in
let pc = pc + List.length case in
let body = body @ case in
let tbl =
List.fold_left (fun tbl c ->
tbl @ [c; ofs]
) tbl cs
in
(pc, tbl, body))
(pc + body_off, [], [])
cases
in
let ofs = body_off + List.length body in
let tbl = tbl @ [ofs] in
assert (2 + List.length tbl = body_off);
[5; ncases] @ tbl @ body @ gen pc last
in
if act.node <> Action.Stop then
Hashtbl.replace pcs act.id pc;
code
in
gen 0 act
let emit_matchers opts ms =
let pf m = Printf.fprintf opts.oc m in
if opts.static then pf "static ";
pf "uchar *%smatcher[] = {\n" opts.pfx;
List.iter (fun (vars, pname, m) ->
pf "\t[P%s] = (uchar[]){\n" pname;
pf "\t\t";
let bytes = compile_action vars m in
emit_list
~col:16 ~indent:2 ~sep:","
~f:string_of_int opts.oc bytes;
pf "\n";
pf "\t},\n")
ms;
pf "};\n\n"
let emit_c opts n =
emit_numberer opts n

413
tools/mgen/fuzz.ml Normal file
View file

@ -0,0 +1,413 @@
(* fuzz the tables and matchers generated *)
open Match
module Buffer: sig
type 'a t
val create: ?capacity:int -> unit -> 'a t
val reset: 'a t -> unit
val size: 'a t -> int
val get: 'a t -> int -> 'a
val set: 'a t -> int -> 'a -> unit
val push: 'a t -> 'a -> unit
end = struct
type 'a t =
{ mutable size: int
; mutable data: 'a array }
let mk_array n = Array.make n (Obj.magic 0)
let create ?(capacity = 10) () =
if capacity < 0 then invalid_arg "Buffer.make";
{size = 0; data = mk_array capacity}
let reset b = b.size <- 0
let size b = b.size
let get b n =
if n >= size b then invalid_arg "Buffer.get";
b.data.(n)
let set b n x =
if n >= size b then invalid_arg "Buffer.set";
b.data.(n) <- x
let push b x =
let cap = Array.length b.data in
if size b = cap then begin
let data = mk_array (2 * cap + 1) in
Array.blit b.data 0 data 0 cap;
b.data <- data
end;
let sz = size b in
b.size <- sz + 1;
set b sz x
end
let binop_state n op s1 s2 =
let key = K (op, s1, s2) in
try StateMap.find key n.statemap
with Not_found -> atom_state n Tmp
type id = int
type term_data =
| Binop of op * id * id
| Leaf of atomic_pattern
type term =
{ id: id
; data: term_data
; state: p state }
let pp_term fmt (ta, id) =
let fpf x = Format.fprintf fmt x in
let rec pp _fmt id =
match ta.(id).data with
| Leaf (Con c) -> fpf "%Ld" c
| Leaf AnyCon -> fpf "$%d" id
| Leaf Tmp -> fpf "%%%d" id
| Binop (op, id1, id2) ->
fpf "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
(show_op op) id ta.(id).state.id
pp id1 pp id2
in pp fmt id
(* A term pool is a deduplicated set of term
* that maintains nodes numbering using the
* statemap passed at creation time *)
module TermPool = struct
type t =
{ terms: term Buffer.t
; hcons: (term_data, id) Hashtbl.t
; numbr: numberer }
let create numbr =
{ terms = Buffer.create ()
; hcons = Hashtbl.create 100
; numbr }
let reset tp =
Buffer.reset tp.terms;
Hashtbl.clear tp.hcons
let size tp = Buffer.size tp.terms
let term tp id = Buffer.get tp.terms id
let mk_leaf tp atm =
let data = Leaf atm in
match Hashtbl.find tp.hcons data with
| id -> term tp id
| exception Not_found ->
let id = Buffer.size tp.terms in
let state = atom_state tp.numbr atm in
Buffer.push tp.terms {id; data; state};
Hashtbl.add tp.hcons data id;
term tp id
let mk_binop tp op t1 t2 =
let data = Binop (op, t1.id, t2.id) in
match Hashtbl.find tp.hcons data with
| id -> term tp id
| exception Not_found ->
let id = Buffer.size tp.terms in
let state =
binop_state tp.numbr op t1.state t2.state
in
Buffer.push tp.terms {id; data; state};
Hashtbl.add tp.hcons data id;
term tp id
let rec add_pattern tp = function
| Bnr (op, p1, p2) ->
let t1 = add_pattern tp p1 in
let t2 = add_pattern tp p2 in
mk_binop tp op t1 t2
| Atm atm -> mk_leaf tp atm
| Var (_, atm) -> add_pattern tp (Atm atm)
let explode_term tp id =
let rec aux tms n id =
let t = term tp id in
match t.data with
| Leaf _ -> (n, {t with id = n} :: tms)
| Binop (op, id1, id2) ->
let n1, tms = aux tms n id1 in
let n = n1 + 1 in
let n2, tms = aux tms n id2 in
let n = n2 + 1 in
(n, { t with data = Binop (op, n1, n2)
; id = n } :: tms)
in
let n, tms = aux [] 0 id in
Array.of_list (List.rev tms), n
end
module R = Random
(* uniform pick in a list *)
let list_pick l =
let rec aux n l x =
match l with
| [] -> x
| y :: l ->
if R.int (n + 1) = 0 then
aux (n + 1) l y
else
aux (n + 1) l x
in
match l with
| [] -> invalid_arg "list_pick"
| x :: l -> aux 1 l x
let term_pick ~numbr =
let ops =
if numbr.ops = [] then
numbr.ops <-
(StateMap.fold (fun k _ ops ->
match k with
| K (op, _, _) -> op :: ops)
numbr.statemap [] |> setify);
numbr.ops
in
let rec gen depth =
(* exponential probability for leaves to
* avoid skewing towards shallow terms *)
let atm_prob = 0.75 ** float_of_int depth in
if R.float 1.0 <= atm_prob || ops = [] then
let atom, st = list_pick numbr.atoms in
(st, Atm atom)
else
let op = list_pick ops in
let s1, t1 = gen (depth - 1) in
let s2, t2 = gen (depth - 1) in
( binop_state numbr op s1 s2
, Bnr (op, t1, t2) )
in fun ~depth -> gen depth
exception FuzzError
let rec pattern_depth = function
| Bnr (_, p1, p2) ->
1 + max (pattern_depth p1) (pattern_depth p2)
| Atm _ -> 0
| Var (_, atm) -> pattern_depth (Atm atm)
let ( %% ) a b =
1e2 *. float_of_int a /. float_of_int b
let progress ?(width = 50) msg pct =
Format.eprintf "\x1b[2K\r%!";
let progress_bar fmt =
let n =
let fwidth = float_of_int width in
1 + int_of_float (pct *. fwidth /. 1e2)
in
Format.fprintf fmt " %s%s %.0f%%@?"
(String.concat "" (List.init n (fun _ -> "")))
(String.make (max 0 (width - n)) '-')
pct
in
Format.kfprintf progress_bar
Format.err_formatter msg
let fuzz_numberer rules numbr =
(* pick twice the max pattern depth so we
* have a chance to find non-trivial numbers
* for the atomic patterns in the rules *)
let depth =
List.fold_left (fun depth r ->
max depth (pattern_depth r.pattern))
0 rules * 2
in
(* fuzz until the term pool we are constructing
* is no longer growing fast enough; or we just
* went through sufficiently many iterations *)
let max_iter = 1_000_000 in
let low_insert_rate = 1e-2 in
let tp = TermPool.create numbr in
let rec loop new_stats i =
let (_, _, insert_rate) = new_stats in
if insert_rate <= low_insert_rate then () else
if i >= max_iter then () else
(* periodically update stats *)
let new_stats =
let (num, cnt, rate) = new_stats in
if num land 1023 = 0 then
let rate =
0.5 *. (rate +. float_of_int cnt /. 1023.)
in
progress " insert_rate=%.1f%%"
(i %% max_iter) (rate *. 1e2);
(num + 1, 0, rate)
else new_stats
in
(* create a term and check that its number is
* accurate wrt the rules *)
let st, term = term_pick ~numbr ~depth in
let state_matched =
List.filter_map (fun cu ->
match cu with
| Top ("$" | "%") -> None
| Top name -> Some name
| _ -> None)
st.point |> setify
in
let rule_matched =
List.filter_map (fun r ->
if pattern_match r.pattern term then
Some r.name
else None)
rules |> setify
in
if state_matched <> rule_matched then begin
let open Format in
let pp_str_list =
let pp_sep fmt () = fprintf fmt ",@ " in
pp_print_list ~pp_sep pp_print_string
in
eprintf "@.@[<v2>fuzz error for %s"
(show_pattern term);
eprintf "@ @[state matched: %a@]"
pp_str_list state_matched;
eprintf "@ @[rule matched: %a@]"
pp_str_list rule_matched;
eprintf "@]@.";
raise FuzzError;
end;
if state_matched = [] then
loop new_stats (i + 1)
else
(* add to the term pool *)
let old_size = TermPool.size tp in
let _ = TermPool.add_pattern tp term in
let new_stats =
let (num, cnt, rate) = new_stats in
if TermPool.size tp <> old_size then
(num + 1, cnt + 1, rate)
else
(num + 1, cnt, rate)
in
loop new_stats (i + 1)
in
loop (1, 0, 1.0) 0;
Format.eprintf
"@.@[ generated %.3fMiB of test terms@]@."
(float_of_int (Obj.reachable_words (Obj.repr tp))
/. 128. /. 1024.);
tp
let rec run_matcher stk m (ta, id as t) =
let state id = ta.(id).state.id in
match m.Action.node with
| Action.Switch cases ->
let m =
try List.assoc (state id) cases
with Not_found -> failwith "no switch case"
in
run_matcher stk m t
| Action.Push (sym, m) ->
let l, r =
match ta.(id).data with
| Leaf _ -> failwith "push on leaf"
| Binop (_, l, r) -> (l, r)
in
if sym && state l > state r
then run_matcher (l :: stk) m (ta, r)
else run_matcher (r :: stk) m (ta, l)
| Action.Pop m -> begin
match stk with
| id :: stk -> run_matcher stk m (ta, id)
| [] -> failwith "pop on empty stack"
end
| Action.Set (v, m) ->
(v, id) :: run_matcher stk m t
| Action.Stop -> []
let rec term_match p (ta, id) =
let (|>>) x f =
match x with None -> None | Some x -> f x
in
let atom_match a =
match ta.(id).data with
| Leaf a' -> pattern_match (Atm a) (Atm a')
| Binop _ -> pattern_match (Atm a) (Atm Tmp)
in
match p with
| Var (v, a) when atom_match a ->
Some [(v, id)]
| Atm a when atom_match a -> Some []
| (Atm _ | Var _) -> None
| Bnr (op, pl, pr) -> begin
match ta.(id).data with
| Binop (op', idl, idr) when op' = op ->
term_match pl (ta, idl) |>> fun l1 ->
term_match pr (ta, idr) |>> fun l2 ->
Some (l1 @ l2)
| _ -> None
end
let test_matchers tp numbr rules =
let {statemap = sm; states = sa; _} = numbr in
let total = ref 0 in
let matchers =
let htbl = Hashtbl.create (Array.length sa) in
List.map (fun r -> (r.name, r.pattern)) rules |>
group_by_fst |>
List.iter (fun (r, ps) ->
total := !total + List.length ps;
let pm = (ps, lr_matcher sm sa rules r) in
sa |> Array.iter (fun s ->
if List.mem (Top r) s.point then
Hashtbl.add htbl s.id pm));
htbl
in
let seen = Hashtbl.create !total in
for id = 0 to TermPool.size tp - 1 do
if id land 1023 = 0 ||
id = TermPool.size tp - 1 then begin
progress
" coverage=%.1f%%"
(id %% TermPool.size tp)
(Hashtbl.length seen %% !total)
end;
let t = TermPool.explode_term tp id in
Hashtbl.find_all matchers
(TermPool.term tp id).state.id |>
List.iter (fun (ps, m) ->
let norm = List.fast_sort compare in
let ok =
match norm (run_matcher [] m t) with
| asn -> `Match (List.exists (fun p ->
match term_match p t with
| None -> false
| Some asn' ->
if asn = norm asn' then begin
Hashtbl.replace seen p ();
true
end else false) ps)
| exception e -> `RunFailure e
in
if ok <> `Match true then begin
let open Format in
let pp_asn fmt asn =
fprintf fmt "@[<h>";
pp_print_list
~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
(fun fmt (v, d) ->
fprintf fmt "@[%s←%d@]" v d)
fmt asn;
fprintf fmt "@]"
in
eprintf "@.@[<v2>matcher error for";
eprintf "@ @[%a@]" pp_term t;
begin match ok with
| `RunFailure e ->
eprintf "@ @[exception: %s@]"
(Printexc.to_string e)
| `Match (* false *) _ ->
let asn = run_matcher [] m t in
eprintf "@ @[assignment: %a@]"
pp_asn asn;
eprintf "@ @[<v2>could not match";
List.iter (fun p ->
eprintf "@ + @[%s@]"
(show_pattern p)) ps;
eprintf "@]"
end;
eprintf "@]@.";
raise FuzzError
end)
done;
Format.eprintf "@."

214
tools/mgen/main.ml Normal file
View file

@ -0,0 +1,214 @@
open Cgen
open Match
let mgen ~verbose ~fuzz path lofs input oc =
let info ?(level = 1) fmt =
if level <= verbose then
Printf.eprintf fmt
else
Printf.ifprintf stdout fmt
in
let rules =
match Sexp.(run_parser ppats) input with
| `Error (ps, err, loc) ->
Printf.eprintf "%s:%d:%d %s\n"
path (lofs + ps.Sexp.line) ps.Sexp.coln err;
Printf.eprintf "%s" loc;
exit 1
| `Ok rules -> rules
in
info "adding ac variants...%!";
let nparsed =
List.fold_left
(fun npats (_, _, ps) ->
npats + List.length ps)
0 rules
in
let varsmap = Hashtbl.create 10 in
let rules =
List.concat_map (fun (name, vars, patterns) ->
(try assert (Hashtbl.find varsmap name = vars)
with Not_found -> ());
Hashtbl.replace varsmap name vars;
List.map
(fun pattern -> {name; vars; pattern})
(List.concat_map ac_equiv patterns)
) rules
in
info " %d -> %d patterns\n"
nparsed (List.length rules);
let rnames =
setify (List.map (fun r -> r.name) rules) in
info "generating match tables...%!";
let sa, am, sm = generate_table rules in
let numbr = make_numberer sa am sm in
info " %d states, %d rules\n"
(Array.length sa) (StateMap.cardinal sm);
if verbose >= 2 then begin
info "-------------\nstates:\n";
Array.iteri (fun i s ->
info " state %d: %s\n"
i (show_pattern s.seen)) sa;
info "-------------\nstatemap:\n";
Test.print_sm stderr sm;
info "-------------\n";
end;
info "generating matchers...\n";
let matchers =
List.map (fun rname ->
info "+ %s...%!" rname;
let m = lr_matcher sm sa rules rname in
let vars = Hashtbl.find varsmap rname in
info " %d nodes\n" (Action.size m);
info ~level:2 " -------------\n";
info ~level:2 " automaton:\n";
info ~level:2 "%s\n"
(Format.asprintf " @[%a@]" Action.pp m);
info ~level:2 " ----------\n";
(vars, rname, m)
) rnames
in
if fuzz then begin
info ~level:0 "fuzzing statemap...\n";
let tp = Fuzz.fuzz_numberer rules numbr in
info ~level:0 "testing %d patterns...\n"
(List.length rules);
Fuzz.test_matchers tp numbr rules
end;
info "emitting C...\n";
flush stderr;
let cgopts =
{ pfx = ""; static = true; oc = oc } in
emit_c cgopts numbr;
emit_matchers cgopts matchers;
()
let read_all ic =
let bufsz = 4096 in
let buf = Bytes.create bufsz in
let data = Buffer.create bufsz in
let read = ref 0 in
while
read := input ic buf 0 bufsz;
!read <> 0
do
Buffer.add_subbytes data buf 0 !read
done;
Buffer.contents data
let split_c src =
let begin_re, eoc_re, end_re =
let re = Str.regexp in
( re "mgen generated code"
, re "\\*/"
, re "end of generated code" )
in
let str_match regexp str =
try
let _: int =
Str.search_forward regexp str 0
in true
with Not_found -> false
in
let rec go st lofs pfx rules lines =
let line, lines =
match lines with
| [] ->
failwith (
match st with
| `Prefix -> "could not find mgen section"
| `Rules -> "mgen rules not terminated"
| `Skip -> "mgen section not terminated"
)
| l :: ls -> (l, ls)
in
match st with
| `Prefix ->
let pfx = line :: pfx in
if str_match begin_re line
then
let lofs = List.length pfx in
go `Rules lofs pfx rules lines
else go `Prefix 0 pfx rules lines
| `Rules ->
let pfx = line :: pfx in
if str_match eoc_re line
then go `Skip lofs pfx rules lines
else go `Rules lofs pfx (line :: rules) lines
| `Skip ->
if str_match end_re line then
let join = String.concat "\n" in
let pfx = join (List.rev pfx) ^ "\n\n"
and rules = join (List.rev rules)
and sfx = join (line :: lines)
in (lofs, pfx, rules, sfx)
else go `Skip lofs pfx rules lines
in
let lines = String.split_on_char '\n' src in
go `Prefix 0 [] [] lines
let () =
let usage_msg =
"mgen [--fuzz] [--verbose <N>] <file>" in
let fuzz_arg = ref false in
let verbose_arg = ref 0 in
let input_paths = ref [] in
let anon_fun filename =
input_paths := filename :: !input_paths in
let speclist =
[ ( "--fuzz", Arg.Set fuzz_arg
, " Fuzz tables and matchers" )
; ( "--verbose", Arg.Set_int verbose_arg
, "<N> Set verbosity level" )
; ( "--", Arg.Rest_all (List.iter anon_fun)
, " Stop argument parsing" ) ]
in
Arg.parse speclist anon_fun usage_msg;
let input_paths = !input_paths in
let verbose = !verbose_arg in
let fuzz = !fuzz_arg in
let input_path, input =
match input_paths with
| ["-"] -> ("-", read_all stdin)
| [path] -> (path, read_all (open_in path))
| _ ->
Printf.eprintf
"%s: single input file expected\n"
Sys.argv.(0);
Arg.usage speclist usage_msg; exit 1
in
let mgen = mgen ~verbose ~fuzz in
if Str.last_chars input_path 2 <> ".c"
then mgen input_path 0 input stdout
else
let tmp_path = input_path ^ ".tmp" in
Fun.protect
~finally:(fun () ->
try Sys.remove tmp_path with _ -> ())
(fun () ->
let lofs, pfx, rules, sfx = split_c input in
let oc = open_out tmp_path in
output_string oc pfx;
mgen input_path lofs rules oc;
output_string oc sfx;
close_out oc;
Sys.rename tmp_path input_path;
());
()

651
tools/mgen/match.ml Normal file
View file

@ -0,0 +1,651 @@
type cls = Kw | Kl | Ks | Kd
type op_base =
| Oadd
| Osub
| Omul
| Oor
| Oshl
| Oshr
type op = cls * op_base
let op_bases =
[Oadd; Osub; Omul; Oor; Oshl; Oshr]
let commutative = function
| (_, (Oadd | Omul | Oor)) -> true
| (_, _) -> false
let associative = function
| (_, (Oadd | Omul | Oor)) -> true
| (_, _) -> false
type atomic_pattern =
| Tmp
| AnyCon
| Con of int64
(* Tmp < AnyCon < Con k *)
type pattern =
| Bnr of op * pattern * pattern
| Atm of atomic_pattern
| Var of string * atomic_pattern
let is_atomic = function
| (Atm _ | Var _) -> true
| _ -> false
let show_op_base o =
match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul"
| Oor -> "or"
| Oshl -> "shl"
| Oshr -> "shr"
let show_op (k, o) =
show_op_base o ^
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
let rec show_pattern p =
match p with
| Atm Tmp -> "%"
| Atm AnyCon -> "$"
| Atm (Con n) -> Int64.to_string n
| Var (v, p) ->
show_pattern (Atm p) ^ "'" ^ v
| Bnr (o, pl, pr) ->
"(" ^ show_op o ^
" " ^ show_pattern pl ^
" " ^ show_pattern pr ^ ")"
let get_atomic p =
match p with
| (Atm a | Var (_, a)) -> Some a
| _ -> None
let rec pattern_match p w =
match p with
| Var (_, p) ->
pattern_match (Atm p) w
| Atm Tmp ->
begin match get_atomic w with
| Some (Con _ | AnyCon) -> false
| _ -> true
end
| Atm (Con _) -> w = p
| Atm (AnyCon) ->
not (pattern_match (Atm Tmp) w)
| Bnr (o, pl, pr) ->
begin match w with
| Bnr (o', wl, wr) ->
o' = o &&
pattern_match pl wl &&
pattern_match pr wr
| _ -> false
end
type +'a cursor = (* a position inside a pattern *)
| Bnrl of op * 'a cursor * pattern
| Bnrr of op * pattern * 'a cursor
| Top of 'a
let rec fold_cursor c p =
match c with
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
| Top _ -> p
let peel p x =
let once out (p, c) =
match p with
| Var (_, p) -> (Atm p, c) :: out
| Atm _ -> (p, c) :: out
| Bnr (o, pl, pr) ->
(pl, Bnrl (o, c, pr)) ::
(pr, Bnrr (o, pl, c)) :: out
in
let rec go l =
let l' = List.fold_left once [] l in
if List.length l' = List.length l
then l'
else go l'
in go [(p, Top x)]
let fold_pairs l1 l2 ini f =
let rec go acc = function
| [] -> acc
| a :: l1' ->
go (List.fold_left
(fun acc b -> f (a, b) acc)
acc l2) l1'
in go ini l1
let iter_pairs l f =
fold_pairs l l () (fun x () -> f x)
let inverse l =
List.map (fun (a, b) -> (b, a)) l
type 'a state =
{ id: int
; seen: pattern
; point: ('a cursor) list }
let rec binops side {point; _} =
List.filter_map (fun c ->
match c, side with
| Bnrl (o, c, r), `L -> Some ((o, c), r)
| Bnrr (o, l, c), `R -> Some ((o, c), l)
| _ -> None)
point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
compare a b) l |>
List.fold_left (fun (oo, l, res) (o', c) ->
match oo with
| None -> (Some o', [c], [])
| Some o when o = o' -> (oo, c :: l, res)
| Some o -> (Some o', [c], (o, l) :: res))
(None, [], []) |>
(function
| (None, _, _) -> []
| (Some o, l, res) -> (o, l) :: res)
let sort_uniq cmp l =
List.fast_sort cmp l |>
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
| Some e when cmp e e' = 0 -> (eo, l)
| Some e -> (Some e', e :: l))
(None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
let setify l =
sort_uniq compare l
let normalize (point: ('a cursor) list) =
setify point
let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
List.map fst in
let o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
List.map fst in
List.map (fun (o, l) ->
o,
{ id = -1
; seen = Bnr (o, s1.seen, s2.seen)
; point = normalize (l @ tmp) })
(group_by_fst (o1 @ o2))
type p = string
module StateSet : sig
type t
val create: unit -> t
val add: t -> p state ->
[> `Added | `Found ] * p state
val iter: t -> (p state -> unit) -> unit
val elems: t -> (p state) list
end = struct
open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 0 }
let add set s =
assert (s.point = normalize s.point);
try
let id = find set.h s in
`Found, {s with id}
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
add set.h s id;
`Added, {s with id}
end
let iter set f =
let f s id = f {s with id} in
iter f set.h
let elems set =
let res = ref [] in
iter set (fun s -> res := s :: !res);
!res
end
type table_key =
| K of op * p state * p state
module StateMap = struct
include Map.Make(struct
type t = table_key
let compare ka kb =
match ka, kb with
| K (o, sl, sr), K (o', sl', sr') ->
compare (o, sl.id, sr.id)
(o', sl'.id, sr'.id)
end)
let invert n sm =
let rmap = Array.make n [] in
iter (fun k {id; _} ->
match k with
| K (o, sl, sr) ->
rmap.(id) <-
(o, (sl.id, sr.id)) :: rmap.(id)
) sm;
Array.map group_by_fst rmap
let by_ops sm =
fold (fun tk s ops ->
match tk with
| K (op, l, r) ->
(op, ((l.id, r.id), s.id)) :: ops)
sm [] |> group_by_fst
end
type rule =
{ name: string
; vars: string list
; pattern: pattern }
let generate_table rl =
let states = StateSet.create () in
let rl =
(* these atomic patterns must occur in
* rules so that we are able to number
* all possible refs *)
[ { name = "$"; vars = []
; pattern = Atm AnyCon }
; { name = "%"; vars = []
; pattern = Atm Tmp } ] @ rl
in
(* initialize states *)
let ground =
List.concat_map
(fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let tmp = List.assoc (Atm Tmp) ground in
let con = List.assoc (Atm AnyCon) ground in
let atoms = ref [] in
let () =
List.iter (fun (seen, l) ->
let point =
if pattern_match (Atm Tmp) seen
then normalize (tmp @ l)
else normalize (con @ l)
in
let s = {id = -1; seen; point} in
let _, s = StateSet.add states s in
match get_atomic seen with
| Some atm -> atoms := (atm, s) :: !atoms
| None -> ()
) ground
in
(* setup loop state *)
let map = ref StateMap.empty in
let map_add k s' =
map := StateMap.add k s' !map
in
let flag = ref `Added in
let flagmerge = function
| `Added -> flag := `Added
| _ -> ()
in
(* iterate until fixpoint *)
while !flag = `Added do
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in
flagmerge flag';
map_add (K (o, sl, sr)) s';
));
done;
let states =
StateSet.elems states |>
List.sort (fun s s' -> compare s.id s'.id) |>
Array.of_list
in
(states, !atoms, !map)
let intersperse x l =
let rec go left right out =
let out =
(List.rev left @ [x] @ right) ::
out in
match right with
| x :: right' ->
go (x :: left) right' out
| [] -> out
in go [] l []
let rec permute = function
| [] -> [[]]
| x :: l ->
List.concat (List.map
(intersperse x) (permute l))
(* build all binary trees with ordered
* leaves l *)
let rec bins build l =
let rec go l r out =
match r with
| [] -> out
| x :: r' ->
go (l @ [x]) r'
(fold_pairs
(bins build l)
(bins build r)
out (fun (l, r) out ->
build l r :: out))
in
match l with
| [] -> []
| [x] -> [x]
| x :: l -> go [x] l []
let products l ini f =
let rec go acc la = function
| [] -> f (List.rev la) acc
| xs :: l ->
List.fold_left (fun acc x ->
go acc (x :: la) l)
acc xs
in go ini [] l
(* combinatorial nuke... *)
let rec ac_equiv =
let rec alevel o = function
| Bnr (o', l, r) when o' = o ->
alevel o l @ alevel o r
| x -> [x]
in function
| Bnr (o, _, _) as p
when associative o ->
products
(List.map ac_equiv (alevel o p)) []
(fun choice out ->
List.concat_map
(bins (fun l r -> Bnr (o, l, r)))
(if commutative o
then permute choice
else [choice]) @ out)
| Bnr (o, l, r)
when commutative o ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) ::
Bnr (o, r, l) :: out)
| Bnr (o, l, r) ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) :: out)
| x -> [x]
module Action: sig
type node =
| Switch of (int * t) list
| Push of bool * t
| Pop of t
| Set of string * t
| Stop
and t = private
{ id: int; node: node }
val equal: t -> t -> bool
val size: t -> int
val stop: t
val mk_push: sym:bool -> t -> t
val mk_pop: t -> t
val mk_set: string -> t -> t
val mk_switch: int list -> (int -> t) -> t
val pp: Format.formatter -> t -> unit
end = struct
type node =
| Switch of (int * t) list
| Push of bool * t
| Pop of t
| Set of string * t
| Stop
and t =
{ id: int; node: node }
let equal a a' = a.id = a'.id
let size a =
let seen = Hashtbl.create 10 in
let rec node_size = function
| Switch l ->
List.fold_left
(fun n (_, a) -> n + size a) 0 l
| (Push (_, a) | Pop a | Set (_, a)) ->
size a
| Stop -> 0
and size {id; node} =
if Hashtbl.mem seen id
then 0
else begin
Hashtbl.add seen id ();
1 + node_size node
end
in
size a
let mk =
let hcons = Hashtbl.create 100 in
let fresh = ref 0 in
fun node ->
let id =
try Hashtbl.find hcons node
with Not_found ->
let id = !fresh in
Hashtbl.add hcons node id;
fresh := id + 1;
id
in
{id; node}
let stop = mk Stop
let mk_push ~sym a = mk (Push (sym, a))
let mk_pop a =
match a.node with
| Stop -> a
| _ -> mk (Pop a)
let mk_set v a = mk (Set (v, a))
let mk_switch ids f =
match List.map f ids with
| [] -> failwith "empty switch";
| c :: cs as cases ->
if List.for_all (equal c) cs then c
else
let cases = List.combine ids cases in
mk (Switch cases)
open Format
let rec pp_node fmt = function
| Switch l ->
fprintf fmt "@[<v>@[<v2>switch{";
let pp_case (c, a) =
let pp_sep fmt () = fprintf fmt "," in
fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
(pp_print_list ~pp_sep pp_print_int)
c pp a
in
inverse l |> group_by_fst |> inverse |>
List.iter pp_case;
fprintf fmt "@]@,}@]"
| Push (true, a) -> fprintf fmt "pushsym@ %a" pp a
| Push (false, a) -> fprintf fmt "push@ %a" pp a
| Pop a -> fprintf fmt "pop@ %a" pp a
| Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a
| Stop -> fprintf fmt ""
and pp fmt a = pp_node fmt a.node
end
(* a state is commutative if (a op b) enters
* it iff (b op a) enters it as well *)
let symmetric rmap id =
List.for_all (fun (_, l) ->
let l1, l2 =
List.filter (fun (a, b) -> a <> b) l |>
List.partition (fun (a, b) -> a < b)
in
setify l1 = setify (inverse l2))
rmap.(id)
(* left-to-right matching of a set of patterns;
* may raise if there is no lr matcher for the
* input rule *)
let lr_matcher statemap states rules name =
let rmap =
let nstates = Array.length states in
StateMap.invert nstates statemap
in
let exception Stuck in
(* the list of ids represents a class of terms
* whose root ends up being labelled with one
* such id; the gen function generates a matcher
* that will, given any such term, assign values
* for the Var nodes of one pattern in pats *)
let rec gen
: 'a. int list -> (pattern * 'a) list
-> (int -> (pattern * 'a) list -> Action.t)
-> Action.t
= fun ids pats k ->
Action.mk_switch (setify ids) @@ fun id_top ->
let sym = symmetric rmap id_top in
let id_ops =
if sym then
let ordered (a, b) = a <= b in
List.map (fun (o, l) ->
(o, List.filter ordered l))
rmap.(id_top)
else rmap.(id_top)
in
(* consider only the patterns that are
* compatible with the current id *)
let atm_pats, bin_pats =
List.filter (function
| Bnr (o, _, _), _ ->
List.exists
(fun (o', _) -> o' = o)
id_ops
| _ -> true) pats |>
List.partition
(fun (pat, _) -> is_atomic pat)
in
try
if bin_pats = [] then raise Stuck;
let pats_l =
List.map (function
| (Bnr (o, l, r), x) ->
(l, (o, x, r))
| _ -> assert false)
bin_pats
and pats_r =
List.map (fun (l, (o, x, r)) ->
(r, (o, l, x)))
and patstop =
List.map (fun (r, (o, l, x)) ->
(Bnr (o, l, r), x))
in
let id_pairs = List.concat_map snd id_ops in
let ids_l = List.map fst id_pairs
and ids_r id_left =
List.filter_map (fun (l, r) ->
if l = id_left then Some r else None)
id_pairs
in
(* match the left arm *)
Action.mk_push ~sym
(gen ids_l pats_l
@@ fun lid pats ->
(* then the right arm, considering
* only the remaining possible
* patterns and knowing that the
* left arm was numbered 'lid' *)
Action.mk_pop
(gen (ids_r lid) (pats_r pats)
@@ fun _rid pats ->
(* continue with the parent *)
k id_top (patstop pats)))
with Stuck ->
let atm_pats =
let seen = states.(id_top).seen in
List.filter (fun (pat, _) ->
pattern_match pat seen) atm_pats
in
if atm_pats = [] then raise Stuck else
let vars =
List.filter_map (function
| (Var (v, _), _) -> Some v
| _ -> None) atm_pats |> setify
in
match vars with
| [] -> k id_top atm_pats
| [v] -> Action.mk_set v (k id_top atm_pats)
| _ -> failwith "ambiguous var match"
in
(* generate a matcher for the rule *)
let ids_top =
Array.to_list states |>
List.filter_map (fun {id; point = p; _} ->
if List.exists ((=) (Top name)) p then
Some id
else None)
in
let rec filter_dups pats =
match pats with
| p :: pats ->
if List.exists (pattern_match p) pats
then filter_dups pats
else p :: filter_dups pats
| [] -> []
in
let pats_top =
List.filter_map (fun r ->
if r.name = name then
Some r.pattern
else None) rules |>
filter_dups |>
List.map (fun p -> (p, ()))
in
gen ids_top pats_top (fun _ pats ->
assert (pats <> []);
Action.stop)
type numberer =
{ atoms: (atomic_pattern * p state) list
; statemap: p state StateMap.t
; states: p state array
; mutable ops: op list
(* memoizes the list of possible operations
* according to the statemap *) }
let make_numberer sa am sm =
{ atoms = am
; states = sa
; statemap = sm
; ops = [] }
let atom_state n atm =
List.assoc atm n.atoms

292
tools/mgen/sexp.ml Normal file
View file

@ -0,0 +1,292 @@
type pstate =
{ data: string
; line: int
; coln: int
; indx: int }
type perror =
{ error: string
; ps: pstate }
exception ParseError of perror
type 'a parser =
{ fn: 'r. pstate -> ('a -> pstate -> 'r) -> 'r }
let update_pos ps beg fin =
let l, c = (ref ps.line, ref ps.coln) in
for i = beg to fin - 1 do
if ps.data.[i] = '\n' then
(incr l; c := 0)
else
incr c
done;
{ ps with line = !l; coln = !c }
let pret (type a) (x: a): a parser =
let fn ps k = k x ps in { fn }
let pfail error: 'a parser =
let fn ps _ = raise (ParseError {error; ps})
in { fn }
let por: 'a parser -> 'a parser -> 'a parser =
fun p1 p2 ->
let fn ps k =
try p1.fn ps k with ParseError e1 ->
try p2.fn ps k with ParseError e2 ->
if e1.ps.indx > e2.ps.indx then
raise (ParseError e1)
else
raise (ParseError e2)
in { fn }
let pbind: 'a parser -> ('a -> 'b parser) -> 'b parser =
fun p1 p2 ->
let fn ps k =
p1.fn ps (fun x ps -> (p2 x).fn ps k)
in { fn }
(* handy for recursive rules *)
let papp p x = pbind (pret x) p
let psnd: 'a parser -> 'b parser -> 'b parser =
fun p1 p2 -> pbind p1 (fun _x -> p2)
let pfst: 'a parser -> 'b parser -> 'a parser =
fun p1 p2 -> pbind p1 (fun x -> psnd p2 (pret x))
module Infix = struct
let ( let* ) = pbind
let ( ||| ) = por
let ( |<< ) = pfst
let ( |>> ) = psnd
end
open Infix
let pre: ?what:string -> string -> string parser =
fun ?what re ->
let what =
match what with
| None -> Printf.sprintf "%S" re
| Some what -> what
and re = Str.regexp re in
let fn ps k =
if not (Str.string_match re ps.data ps.indx) then
(let error =
Printf.sprintf "expected to match %s" what in
raise (ParseError {error; ps}));
let ps =
let indx = Str.match_end () in
{ (update_pos ps ps.indx indx) with indx }
in
k (Str.matched_string ps.data) ps
in { fn }
let peoi: unit parser =
let fn ps k =
if ps.indx <> String.length ps.data then
raise (ParseError
{ error = "expected end of input"; ps });
k () ps
in { fn }
let pws = pre "[ \r\n\t*]*"
let pws1 = pre "[ \r\n\t*]+"
let pthen p1 p2 =
let* x1 = p1 in
let* x2 = p2 in
pret (x1, x2)
let rec plist_tail: 'a parser -> ('a list) parser =
fun pitem ->
(pws |>> pre ")" |>> pret []) |||
(let* itm = pitem in
let* itms = plist_tail pitem in
pret (itm :: itms))
let plist pitem =
pws |>> pre ~what:"a list" "("
|>> plist_tail pitem
let plist1p p1 pitem =
pws |>> pre ~what:"a list" "("
|>> pthen p1 (plist_tail pitem)
let ppair p1 p2 =
pws |>> pre ~what:"a pair" "("
|>> pthen p1 p2 |<< pws |<< pre ")"
let run_parser p s =
let ps =
{data = s; line = 1; coln = 0; indx = 0} in
try `Ok (p.fn ps (fun res _ps -> res))
with ParseError e ->
let rec bol i =
if i = 0 then i else
if i < String.length s && s.[i] = '\n'
then i+1 (* XXX BUG *)
else bol (i-1)
in
let rec eol i =
if i = String.length s then i else
if s.[i] = '\n' then i else
eol (i+1)
in
let bol = bol e.ps.indx in
let eol = eol e.ps.indx in
(*
Printf.eprintf "bol:%d eol:%d indx:%d len:%d\n"
bol eol e.ps.indx (String.length s); (* XXX debug *)
*)
let lines =
String.split_on_char '\n'
(String.sub s bol (eol - bol))
in
let nl = List.length lines in
let caret = ref (e.ps.indx - bol) in
let msg = ref [] in
let pfx = " > " in
lines |> List.iteri (fun ln l ->
if ln <> nl - 1 || l <> "" then begin
let ll = String.length l + 1 in
msg := (pfx ^ l ^ "\n") :: !msg;
if !caret <= ll then begin
let pad = String.make !caret ' ' in
msg := (pfx ^ pad ^ "^\n") :: !msg;
end;
caret := !caret - ll;
end;
);
`Error
( e.ps, e.error
, String.concat "" (List.rev !msg) )
(* ---------------------------------------- *)
(* pattern parsing *)
(* ---------------------------------------- *)
(* Example syntax:
(with-vars (a b c d)
(patterns
(ob (add (tmp a) (con d)))
(bsm (add (tmp b) (mul (tmp m) (con 2 4 8)))) ))
*)
open Match
let pint64 =
let* s = pre "[-]?[0-9_]+" in
pret (Int64.of_string s)
let pid =
pre ~what:"an identifer"
"[a-zA-Z][a-zA-Z0-9_]*"
let pop_base =
let sob, obs = show_op_base, op_bases in
let* s = pre ~what:"an operator"
(String.concat "\\|" (List.map sob obs))
in pret (List.find (fun o -> s = sob o) obs)
let pop = let* ob = pop_base in pret (Kl, ob)
let rec ppat vs =
let pcons_tail =
let* cs = plist_tail (pws1 |>> pint64) in
match cs with
| [] -> pret [AnyCon]
| _ -> pret (List.map (fun c -> Con c) cs)
in
let pvar =
let* id = pid in
if not (List.mem id vs) then
pfail ("unbound variable: " ^ id)
else
pret id
in
pws |>> (
( let* c = pint64 in pret [Atm (Con c)] )
|||
( pre "(con)" |>> pret [Atm AnyCon] ) |||
( let* cs = pre "(con" |>> pcons_tail in
pret (List.map (fun c -> Atm c) cs) ) |||
( let* v = pre "(con" |>> pws1 |>> pvar in
let* cs = pcons_tail in
pret (List.map (fun c -> Var (v, c)) cs) )
|||
( pre "(tmp)" |>> pret [Atm Tmp] ) |||
( let* v = pre "(tmp" |>> pws1 |>> pvar in
pws |>> pre ")" |>> pret [Var (v, Tmp)] )
|||
( let* (op, rands) =
plist1p (pws |>> pop) (papp ppat vs) in
let nrands = List.length rands in
if nrands < 2 then
pfail ( "binary op requires at least"
^ " two arguments" )
else
let mk x y = Bnr (op, x, y) in
pret
(products rands []
(fun rands pats ->
(* construct a left-heavy tree *)
let r0 = List.hd rands in
let rs = List.tl rands in
List.fold_left mk r0 rs :: pats)) )
)
let pwith_vars ?(vs = []) p =
( let* vs =
pws |>> pre "(with-vars" |>> pws |>>
plist (pws |>> pid)
in pws |>> p vs |<< pws |<< pre ")" )
||| p vs
let ppats =
pwith_vars @@ fun vs ->
pre "(patterns" |>> plist_tail
(pwith_vars ~vs @@ fun vs ->
let* n, ps = ppair pid (ppat vs) in
pret (n, vs, ps))
(* ---------------------------------------- *)
(* tests *)
(* ---------------------------------------- *)
let () =
if false then
let show_patterns ps =
"[" ^ String.concat "; "
(List.map show_pattern ps) ^ "]"
in
let pat s =
Printf.printf "parse %s = " s;
let vars =
[ "foobar"; "a"; "b"; "d"
; "m"; "s"; "x" ]
in
match run_parser (ppat vars) s with
| `Ok p ->
Printf.printf "%s\n" (show_patterns p)
| `Error (_, e, _) ->
Printf.printf "ERROR: %s\n" e
in
pat "42";
pat "(tmp)";
pat "(tmp foobar)";
pat "(con)";
pat "(con 1 2 3)";
pat "(con x 1 2 3)";
pat "(add 1 2)";
pat "(add 1 2 3 4)";
pat "(sub 1 2)";
pat "(sub 1 2 3)";
pat "(tmp unbound_var)";
pat "(add 0)";
pat "(add 1 (add 2 3))";
pat "(add (tmp a) (con d))";
pat "(add (tmp b) (mul (tmp m) (con s 2 4 8)))";
pat "(add (con 1 2) (con 3 4))";
()

134
tools/mgen/test.ml Normal file
View file

@ -0,0 +1,134 @@
open Match
open Fuzz
open Cgen
(* unit tests *)
let test_pattern_match =
let pm = pattern_match
and nm = fun x y -> not (pattern_match x y) in
begin
assert (nm (Atm Tmp) (Atm (Con 42L)));
assert (pm (Atm AnyCon) (Atm (Con 42L)));
assert (nm (Atm (Con 42L)) (Atm AnyCon));
assert (nm (Atm (Con 42L)) (Atm Tmp));
end
let test_peel =
let o = Kw, Oadd in
let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
Atm (Con 42L)) in
let l = peel p () in
let () = assert (List.length l = 3) in
let atomic_p (p, _) =
match p with Atm _ -> true | _ -> false in
let () = assert (List.for_all atomic_p l) in
let l = List.map (fun (p, c) -> fold_cursor c p) l in
let () = assert (List.for_all ((=) p) l) in
()
let test_fold_pairs =
let l = [1; 2; 3; 4; 5] in
let p = fold_pairs l l [] (fun a b -> a :: b) in
let () = assert (List.length p = 25) in
let p = sort_uniq compare p in
let () = assert (List.length p = 25) in
()
(* test pattern & state *)
let print_sm oc =
StateMap.iter (fun k s' ->
match k with
| K (o, sl, sr) ->
let top =
List.fold_left (fun top c ->
match c with
| Top r -> top ^ " " ^ r
| _ -> top) "" s'.point
in
Printf.fprintf oc
" (%s %d %d) -> %d%s\n"
(show_op o)
sl.id sr.id s'.id top)
let rules =
let oa = Kl, Oadd in
let om = Kl, Omul in
let va = Var ("a", Tmp)
and vb = Var ("b", Tmp)
and vc = Var ("c", Tmp)
and vs = Var ("s", Tmp) in
let vars = ["a"; "b"; "c"; "s"] in
let rule name pattern =
List.map
(fun pattern -> {name; vars; pattern})
(ac_equiv pattern)
in
match `X64Addr with
(* ------------------------------- *)
| `X64Addr ->
(* o + b *)
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * m *)
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 2L), vs)))
@
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 4L), vs)))
@
rule "bsm" (Bnr (oa, vb, Bnr (om, Var ("m", Con 8L), vs)))
@ (* b + s *)
rule "bs1" (Bnr (oa, vb, vs))
@ (* o + s * m *)
(* rule "osm" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp))) *) []
@ (* o + b + s *)
rule "obs1" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb), vs))
@ (* o + b + s * m *)
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
Bnr (om, Var ("m", Con 2L), vs)))
@
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
Bnr (om, Var ("m", Con 4L), vs)))
@
rule "obsm" (Bnr (oa, Bnr (oa, Var ("o", AnyCon), vb),
Bnr (om, Var ("m", Con 8L), vs)))
(* ------------------------------- *)
| `Add3 ->
[ { name = "add"
; vars = []
; pattern = Bnr (oa, va, Bnr (oa, vb, vc)) } ] @
[ { name = "add"
; vars = []
; pattern = Bnr (oa, Bnr (oa, va, vb), vc) } ]
(*
let sa, am, sm = generate_table rules
let () =
Array.iteri (fun i s ->
Format.printf "@[state %d: %s@]@."
i (show_pattern s.seen))
sa
let () = print_sm stdout sm; flush stdout
let matcher = lr_matcher sm sa rules "obsm" (* XXX *)
let () = Format.printf "@[<v>%a@]@." Action.pp matcher
let () = Format.printf "@[matcher size: %d@]@." (Action.size matcher)
let numbr = make_numberer sa am sm
let () =
let opts = { pfx = ""
; static = true
; oc = stdout } in
emit_c opts numbr;
emit_matchers opts
[ ( ["b"; "o"; "s"; "m"]
, "obsm"
, matcher ) ]
(*
let tp = fuzz_numberer rules numbr
let () = test_matchers tp numbr rules
*)
*)