libqbe/tools/mgen/cgen.ml
Quentin Carbonneaux a609527752 mgen: match automatons and C generation
The algorithm to generate matchers
took a long time to be discovered
and refined to its present version.
The rest of mgen is mostly boring
engineering.

Extensive fuzzing ensures that the
two core components of mgen (tables
and matchers generation) are correct
on specific problem instances.
2024-04-09 21:45:42 +02:00

420 lines
11 KiB
OCaml

open Match
type options =
{ pfx: string
; static: bool
; oc: out_channel }
type side = L | R
type id_pred =
| InBitSet of Int64.t
| Ge of int
| Eq of int
and id_test =
| Pred of (side * id_pred)
| And of id_test * id_test
type case_code =
| Table of ((int * int) * int) list
| IfThen of
{ test: id_test
; cif: case_code
; cthen: case_code option }
| Return of int
type case =
{ swap: bool
; code: case_code }
let cgen_case tmp nstates map =
let cgen_test ids =
match ids with
| [id] -> Eq id
| _ ->
let min_id =
List.fold_left min max_int ids in
if List.length ids = nstates - min_id
then Ge min_id
else begin
assert (nstates <= 64);
InBitSet
(List.fold_left (fun bs id ->
Int64.logor bs
(Int64.shift_left 1L id))
0L ids)
end
in
let symmetric =
let inverse ((l, r), x) = ((r, l), x) in
setify map = setify (List.map inverse map) in
let map =
let ordered ((l, r), _) = r <= l in
if symmetric then
List.filter ordered map
else map
in
let exception BailToTable in
try
let st =
match setify (List.map snd map) with
| [st] -> st
| _ -> raise BailToTable
in
(* the operation considered can only
* generate a single state *)
let pairs = List.map fst map in
let ls, rs = List.split pairs in
let ls = setify ls and rs = setify rs in
if List.length ls > 1 && List.length rs > 1 then
raise BailToTable;
{ swap = symmetric
; code =
let pl = Pred (L, cgen_test ls)
and pr = Pred (R, cgen_test rs) in
IfThen
{ test = And (pl, pr)
; cif = Return st
; cthen = Some (Return tmp) } }
with BailToTable ->
{ swap = symmetric
; code = Table map }
let show_op (_cls, op) =
"O" ^ show_op_base op
let indent oc i =
Printf.fprintf oc "%s" (String.sub "\t\t\t\t\t" 0 i)
let emit_swap oc i =
let pf m = Printf.fprintf oc m in
let pfi n m = indent oc n; pf m in
pfi i "if (l < r)\n";
pfi (i+1) "t = l, l = r, r = t;\n"
let gen_tables oc tmp pfx nstates (op, c) =
let i = 1 in
let pf m = Printf.fprintf oc m in
let pfi n m = indent oc n; pf m in
let ntables = ref 0 in
(* we must follow the order in which
* we visit code in emit_case, or
* else ntables goes out of sync *)
let base = pfx ^ show_op op in
let swap = c.swap in
let rec gen c =
match c with
| Table map ->
let name =
if !ntables = 0 then base else
base ^ string_of_int !ntables
in
assert (nstates <= 256);
if swap then
let n = nstates * (nstates + 1) / 2 in
pfi i "static uchar %stbl[%d] = {\n" name n
else
pfi i "static uchar %stbl[%d][%d] = {\n"
name nstates nstates;
for l = 0 to nstates - 1 do
pfi (i+1) "";
for r = 0 to nstates - 1 do
if not swap || r <= l then
begin
pf "%d"
(try List.assoc (l,r) map
with Not_found -> tmp);
pf ",";
end
done;
pf "\n";
done;
pfi i "};\n"
| IfThen {cif; cthen} ->
gen cif;
Option.iter gen cthen
| Return _ -> ()
in
gen c.code
let emit_case oc pfx no_swap (op, c) =
let fpf = Printf.fprintf in
let pf m = fpf oc m in
let pfi n m = indent oc n; pf m in
let rec side oc = function
| L -> fpf oc "l"
| R -> fpf oc "r"
in
let pred oc (s, pred) =
match pred with
| InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
| Eq id -> fpf oc "%a == %d" side s id
| Ge id -> fpf oc "%d <= %a" id side s
in
let base = pfx ^ show_op op in
let swap = c.swap in
let ntables = ref 0 in
let rec code i c =
match c with
| Return id -> pfi i "return %d;\n" id
| Table map ->
let name =
if !ntables = 0 then base else
base ^ string_of_int !ntables
in
incr ntables;
if swap then
pfi i "return %stbl[(l + l*l)/2 + r];\n" name
else pfi i "return %stbl[l][r];\n" name
| IfThen ({test = And (And (t1, t2), t3)} as r) ->
code i @@ IfThen
{r with test = And (t1, And (t2, t3))}
| IfThen {test = And (Pred p, t); cif; cthen} ->
pfi i "if (%a)\n" pred p;
code i (IfThen {test = t; cif; cthen})
| IfThen {test = Pred p; cif; cthen} ->
pfi i "if (%a) {\n" pred p;
code (i+1) cif;
pfi i "}\n";
Option.iter (code i) cthen
in
pfi 1 "case %s:\n" (show_op op);
if not no_swap && c.swap then
emit_swap oc 2;
code 2 c.code
let emit_list
?(limit=60) ?(cut_before_sep=false)
~col ~indent:i ~sep ~f oc l =
let sl = String.length sep in
let rstripped_sep, rssl =
if sep.[sl - 1] = ' ' then
String.sub sep 0 (sl - 1), sl - 1
else sep, sl
in
let lstripped_sep, lssl =
if sep.[0] = ' ' then
String.sub sep 1 (sl - 1), sl - 1
else sep, sl
in
let rec line col acc = function
| [] -> (List.rev acc, [])
| s :: l ->
let col = col + sl + String.length s in
let no_space =
if cut_before_sep || l = [] then
col > limit
else
col + rssl > limit
in
if no_space then
(List.rev acc, s :: l)
else
line col (s :: acc) l
in
let rec go col l =
if l = [] then () else
let ll, l = line col [] l in
Printf.fprintf oc "%s" (String.concat sep ll);
if l <> [] && cut_before_sep then begin
Printf.fprintf oc "\n";
indent oc i;
Printf.fprintf oc "%s" lstripped_sep;
go (8*i + lssl) l
end else if l <> [] then begin
Printf.fprintf oc "%s\n" rstripped_sep;
indent oc i;
go (8*i) l
end else ()
in
go col (List.map f l)
let emit_numberer opts n =
let pf m = Printf.fprintf opts.oc m in
let tmp = (atom_state n Tmp).id in
let con = (atom_state n AnyCon).id in
let nst = Array.length n.states in
let cases =
StateMap.by_ops n.statemap |>
List.map (fun (op, map) ->
(op, cgen_case tmp nst map))
in
let all_swap =
List.for_all (fun (_, c) -> c.swap) cases in
(* opn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%sopn(int op, int l, int r)\n" opts.pfx;
pf "{\n";
cases |> List.iter
(gen_tables opts.oc tmp opts.pfx nst);
if List.exists (fun (_, c) -> c.swap) cases then
pf "\tint t;\n\n";
if all_swap then emit_swap opts.oc 1;
pf "\tswitch (op) {\n";
cases |> List.iter
(emit_case opts.oc opts.pfx all_swap);
pf "\tdefault:\n";
pf "\t\treturn %d;\n" tmp;
pf "\t}\n";
pf "}\n\n";
(* refn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
pf "{\n";
let cons =
List.filter_map (function
| (Con c, s) -> Some (c, s.id)
| _ -> None)
n.atoms
in
if cons <> [] then
pf "\tint64_t n;\n\n";
pf "\tswitch (rtype(r)) {\n";
pf "\tcase RTmp:\n";
if tmp <> 0 then begin
assert
(List.exists (fun (_, s) ->
s.id = 0
) n.atoms &&
(* no temp should ever get state 0 *)
List.for_all (fun (a, s) ->
s.id <> 0 ||
match a with
| AnyCon | Con _ -> true
| _ -> false
) n.atoms);
pf "\t\tif (!tn[r.val].n)\n";
pf "\t\t\ttn[r.val].n = %d;\n" tmp;
end;
pf "\t\treturn tn[r.val].n;\n";
pf "\tcase RCon:\n";
if cons <> [] then begin
pf "\t\tif (con[r.val].type != CBits)\n";
pf "\t\t\treturn %d;\n" con;
pf "\t\tn = con[r.val].bits.i;\n";
cons |> inverse |> group_by_fst
|> List.iter (fun (id, cs) ->
pf "\t\tif (";
emit_list ~cut_before_sep:true
~col:20 ~indent:2 ~sep:" || "
~f:(fun c -> "n == " ^ Int64.to_string c)
opts.oc cs;
pf ")\n";
pf "\t\t\treturn %d;\n" id
);
end;
pf "\t\treturn %d;\n" con;
pf "\tdefault:\n";
pf "\t\treturn INT_MIN;\n";
pf "\t}\n";
pf "}\n\n";
(* match[]: patterns per state *)
if opts.static then pf "static ";
pf "bits %smatch[%d] = {\n" opts.pfx nst;
n.states |> Array.iteri (fun sn s ->
let tops =
List.filter_map (function
| Top ("$" | "%") -> None
| Top r -> Some ("BIT(P" ^ r ^ ")")
| _ -> None) s.point |> setify
in
if tops <> [] then
pf "\t[%d] = %s,\n"
sn (String.concat " | " tops);
);
pf "};\n\n"
let var_id vars f =
List.mapi (fun i x -> (x, i)) vars |>
List.assoc f
let compile_action vars act =
let pcs = Hashtbl.create 100 in
let rec gen pc (act: Action.t) =
try
[10 + Hashtbl.find pcs act.id]
with Not_found ->
let code =
match act.node with
| Action.Stop ->
[0]
| Action.Push (sym, k) ->
let c = if sym then 1 else 2 in
[c] @ gen (pc + 1) k
| Action.Set (v, {node = Action.Pop k; _})
| Action.Set (v, ({node = Action.Stop; _} as k)) ->
let v = var_id vars v in
[3; v] @ gen (pc + 2) k
| Action.Set _ ->
(* for now, only atomic patterns can be
* tied to a variable, so Set must be
* followed by either Pop or Stop *)
assert false
| Action.Pop k ->
[4] @ gen (pc + 1) k
| Action.Switch cases ->
let cases =
inverse cases |> group_by_fst |>
List.sort (fun (_, cs1) (_, cs2) ->
let n1 = List.length cs1
and n2 = List.length cs2 in
compare n2 n1)
in
(* the last case is the one with
* the max number of entries *)
let cases = List.rev (List.tl cases)
and last = fst (List.hd cases) in
let ncases =
List.fold_left (fun n (_, cs) ->
List.length cs + n)
0 cases
in
let body_off = 2 + 2 * ncases + 1 in
let pc, tbl, body =
List.fold_left
(fun (pc, tbl, body) (a, cs) ->
let ofs = body_off + List.length body in
let case = gen pc a in
let pc = pc + List.length case in
let body = body @ case in
let tbl =
List.fold_left (fun tbl c ->
tbl @ [c; ofs]
) tbl cs
in
(pc, tbl, body))
(pc + body_off, [], [])
cases
in
let ofs = body_off + List.length body in
let tbl = tbl @ [ofs] in
assert (2 + List.length tbl = body_off);
[5; ncases] @ tbl @ body @ gen pc last
in
if act.node <> Action.Stop then
Hashtbl.replace pcs act.id pc;
code
in
gen 0 act
let emit_matchers opts ms =
let pf m = Printf.fprintf opts.oc m in
if opts.static then pf "static ";
pf "uchar *%smatcher[] = {\n" opts.pfx;
List.iter (fun (vars, pname, m) ->
pf "\t[P%s] = (uchar[]){\n" pname;
pf "\t\t";
let bytes = compile_action vars m in
emit_list
~col:16 ~indent:2 ~sep:","
~f:string_of_int opts.oc bytes;
pf "\n";
pf "\t},\n")
ms;
pf "};\n\n"
let emit_c opts n =
emit_numberer opts n