From 56e2263ca46166ffffb814ae225faf08fd52248c Mon Sep 17 00:00:00 2001 From: Quentin Carbonneaux Date: Fri, 11 Feb 2022 08:42:28 +0100 Subject: [PATCH] fuse ac rules in ins-tree matching The initial plan was to have one matcher per ac-variant, but that leads to way too much generated code. Instead, we can fuse ac variants of the rules and have a smarter matching algorithm to recover bound variables. --- tools/match.ml | 80 +++++++++++++++++++++++++++------------------ tools/match_test.ml | 78 +++++++++++++++++++++---------------------- 2 files changed, 86 insertions(+), 72 deletions(-) diff --git a/tools/match.ml b/tools/match.ml index 4aeeae0..5de356b 100644 --- a/tools/match.ml +++ b/tools/match.ml @@ -23,11 +23,32 @@ type pattern = | Atm of atomic_pattern | Var of string * atomic_pattern +let show_op (k, o) = + (match o with + | Oadd -> "add" + | Osub -> "sub" + | Omul -> "mul") ^ + (match k with + | Kw -> "w" + | Kl -> "l" + | Ks -> "s" + | Kd -> "d") + +let rec show_pattern p = + match p with + | Var _ -> failwith "variable not allowed" + | Atm Tmp -> "%" + | Atm AnyCon -> "$" + | Atm (Con n) -> Int64.to_string n + | Bnr (o, pl, pr) -> + "(" ^ show_op o ^ + " " ^ show_pattern pl ^ + " " ^ show_pattern pr ^ ")" + let rec pattern_match p w = match p with - | Var _ -> - failwith "variable not allowed" - | Atm (Tmp) -> + | Var _ -> failwith "variable not allowed" + | Atm Tmp -> begin match w with | Atm (Con _ | AnyCon) -> false | _ -> true @@ -89,12 +110,12 @@ type 'a state = ; point: ('a cursor) list } let rec binops side {point; _} = - List.fold_left (fun res c -> + List.filter_map (fun c -> match c, side with - | Bnrl (o, c, r), `L -> ((o, c), r) :: res - | Bnrr (o, l, c), `R -> ((o, c), l) :: res - | _ -> res) - [] point + | Bnrl (o, c, r), `L -> Some ((o, c), r) + | Bnrr (o, l, c), `R -> Some ((o, c), l) + | _ -> None) + point let group_by_fst l = List.fast_sort (fun (a, _) (b, _) -> @@ -114,11 +135,9 @@ let sort_uniq cmp l = List.fold_left (fun (eo, l) e' -> match eo with | None -> (Some e', l) - | Some e -> - if cmp e e' = 0 - then (eo, l) - else (Some e', e :: l) - ) (None, []) |> + | Some e when cmp e e' = 0 -> (eo, l) + | Some e -> (Some e', e :: l)) + (None, []) |> (function | (None, _) -> [] | (Some e, l) -> List.rev (e :: l)) @@ -126,15 +145,14 @@ let sort_uniq cmp l = let normalize (point: ('a cursor) list) = sort_uniq compare point -let nextbnr tmp s1 s2 = +let next_binary tmp s1 s2 = let pm w (_, p) = pattern_match p w in let o1 = binops `L s1 |> List.filter (pm s2.seen) |> - List.map fst - and o2 = binops `R s2 |> + List.map fst in + let o2 = binops `R s2 |> List.filter (pm s1.seen) |> - List.map fst - in + List.map fst in List.map (fun (o, l) -> o, { id = 0 @@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 = type p = string module StateSet : sig - type set - val create: unit -> set - val add: set -> p state -> + type t + val create: unit -> t + val add: t -> p state -> [> `Added | `Found ] * p state - val iter: set -> (p state -> unit) -> unit - val elems: set -> (p state) list + val iter: t -> (p state -> unit) -> unit + val elems: t -> (p state) list end = struct - include Hashtbl.Make(struct + open Hashtbl.Make(struct type t = p state let equal s1 s2 = s1.point = s2.point let hash s = Hashtbl.hash s.point end) - type set = + type nonrec t = { h: int t ; mutable next_id: int } let create () = { h = create 500; next_id = 1 } let add set s = - (* delete the check later *) assert (s.point = normalize s.point); try let id = find set.h s in @@ -171,6 +188,8 @@ end = struct with Not_found -> begin let id = set.next_id in set.next_id <- id + 1; + Printf.printf "adding: %d [%s]\n" + id (show_pattern s.seen); add set.h s id; `Added, {s with id} end @@ -198,17 +217,14 @@ end) type rule = { name: string ; pattern: pattern - (* TODO access pattern *) } let generate_table rl = let states = StateSet.create () in (* initialize states *) let ground = - List.fold_left - (fun ini r -> - peel r.pattern r.name @ ini) - [] rl |> + List.concat_map + (fun r -> peel r.pattern r.name) rl |> group_by_fst in let find x d l = @@ -242,7 +258,7 @@ let generate_table rl = flag := `Stop; let statel = StateSet.elems states in iter_pairs statel (fun (sl, sr) -> - nextbnr tmp sl sr |> + next_binary tmp sl sr |> List.iter (fun (o, s') -> let flag', s' = StateSet.add states s' in diff --git a/tools/match_test.ml b/tools/match_test.ml index fe740c5..da63666 100644 --- a/tools/match_test.ml +++ b/tools/match_test.ml @@ -46,54 +46,52 @@ let ts = } let print_sm = - let op_str (k, o) = - Printf.sprintf "%s%s" - (match o with - | Oadd -> "add" - | Osub -> "sub" - | Omul -> "mul") - (match k with - | Kw -> "w" - | Kl -> "l" - | Ks -> "s" - | Kd -> "d") - in StateMap.iter (fun k s' -> match k with | K (o, sl, sr) -> + let top = + List.fold_left (fun top c -> + match c with + | Top r -> top ^ " " ^ r + | _ -> top) "" s'.point + in Printf.printf - "(%s %d %d) -> %d\n" - (op_str o) - sl.id sr.id s'.id - ) + "(%s %d %d) -> %d%s\n" + (show_op o) + sl.id sr.id s'.id top) -let address_rules = +let rules = let oa = Kl, Oadd in let om = Kl, Omul in - let rule name pattern = - List.mapi (fun i pattern -> - { name = Printf.sprintf "%s%d" name (i+1) - ; pattern; }) - (ac_equiv pattern) in - + match `X64Addr with + (* ------------------------------- *) + | `X64Addr -> + let rule name pattern = + List.mapi (fun i pattern -> + { name (* = Printf.sprintf "%s%d" name (i+1) *) + ; pattern }) + (ac_equiv pattern) in (* o + b *) - rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon)) - @ (* b + s * i *) - rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp))) - @ (* o + s * i *) - rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp))) - @ (* b + o + s * i *) - rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp))) + rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon)) + @ (* b + s * i *) + rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp))) + @ (* o + s * i *) + rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp))) + @ (* b + o + s * i *) + rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp))) + (* ------------------------------- *) + | `Add3 -> + [ { name = "add" + ; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @ + [ { name = "add" + ; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @ + [ { name = "mul" + ; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), + Atm Tmp), + Bnr (oa, Atm Tmp, + Bnr (oa, Atm Tmp, Atm Tmp))) } ] -let sl, sm = generate_table address_rules + +let sl, sm = generate_table rules let s n = List.find (fun {id; _} -> id = n) sl let () = print_sm sm - -(* -let tp0 = - let o = Kw, Oadd in - Bnr (o, Atm Tmp, Atm (Con 0L)) -let tp1 = - let o = Kw, Oadd in - Bnr (o, tp0, Atm (Con 1L)) -*)