fuse ac rules in ins-tree matching

The initial plan was to have one
matcher per ac-variant, but that
leads to way too much generated
code. Instead, we can fuse ac
variants of the rules and have
a smarter matching algorithm to
recover bound variables.
This commit is contained in:
Quentin Carbonneaux 2022-02-11 08:42:28 +01:00
parent 8a5e1c3a23
commit 56e2263ca4
2 changed files with 86 additions and 72 deletions

View file

@ -23,11 +23,32 @@ type pattern =
| Atm of atomic_pattern
| Var of string * atomic_pattern
let show_op (k, o) =
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul") ^
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
let rec show_pattern p =
match p with
| Var _ -> failwith "variable not allowed"
| Atm Tmp -> "%"
| Atm AnyCon -> "$"
| Atm (Con n) -> Int64.to_string n
| Bnr (o, pl, pr) ->
"(" ^ show_op o ^
" " ^ show_pattern pl ^
" " ^ show_pattern pr ^ ")"
let rec pattern_match p w =
match p with
| Var _ ->
failwith "variable not allowed"
| Atm (Tmp) ->
| Var _ -> failwith "variable not allowed"
| Atm Tmp ->
begin match w with
| Atm (Con _ | AnyCon) -> false
| _ -> true
@ -89,12 +110,12 @@ type 'a state =
; point: ('a cursor) list }
let rec binops side {point; _} =
List.fold_left (fun res c ->
List.filter_map (fun c ->
match c, side with
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
| _ -> res)
[] point
| Bnrl (o, c, r), `L -> Some ((o, c), r)
| Bnrr (o, l, c), `R -> Some ((o, c), l)
| _ -> None)
point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
@ -114,11 +135,9 @@ let sort_uniq cmp l =
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
| Some e ->
if cmp e e' = 0
then (eo, l)
else (Some e', e :: l)
) (None, []) |>
| Some e when cmp e e' = 0 -> (eo, l)
| Some e -> (Some e', e :: l))
(None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
@ -126,15 +145,14 @@ let sort_uniq cmp l =
let normalize (point: ('a cursor) list) =
sort_uniq compare point
let nextbnr tmp s1 s2 =
let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
List.map fst
and o2 = binops `R s2 |>
List.map fst in
let o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
List.map fst
in
List.map fst in
List.map (fun (o, l) ->
o,
{ id = 0
@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
type p = string
module StateSet : sig
type set
val create: unit -> set
val add: set -> p state ->
type t
val create: unit -> t
val add: t -> p state ->
[> `Added | `Found ] * p state
val iter: set -> (p state -> unit) -> unit
val elems: set -> (p state) list
val iter: t -> (p state -> unit) -> unit
val elems: t -> (p state) list
end = struct
include Hashtbl.Make(struct
open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
type set =
type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 1 }
let add set s =
(* delete the check later *)
assert (s.point = normalize s.point);
try
let id = find set.h s in
@ -171,6 +188,8 @@ end = struct
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
Printf.printf "adding: %d [%s]\n"
id (show_pattern s.seen);
add set.h s id;
`Added, {s with id}
end
@ -198,17 +217,14 @@ end)
type rule =
{ name: string
; pattern: pattern
(* TODO access pattern *)
}
let generate_table rl =
let states = StateSet.create () in
(* initialize states *)
let ground =
List.fold_left
(fun ini r ->
peel r.pattern r.name @ ini)
[] rl |>
List.concat_map
(fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let find x d l =
@ -242,7 +258,7 @@ let generate_table rl =
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
nextbnr tmp sl sr |>
next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in

View file

@ -46,54 +46,52 @@ let ts =
}
let print_sm =
let op_str (k, o) =
Printf.sprintf "%s%s"
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul")
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
in
StateMap.iter (fun k s' ->
match k with
| K (o, sl, sr) ->
let top =
List.fold_left (fun top c ->
match c with
| Top r -> top ^ " " ^ r
| _ -> top) "" s'.point
in
Printf.printf
"(%s %d %d) -> %d\n"
(op_str o)
sl.id sr.id s'.id
)
"(%s %d %d) -> %d%s\n"
(show_op o)
sl.id sr.id s'.id top)
let address_rules =
let rules =
let oa = Kl, Oadd in
let om = Kl, Omul in
let rule name pattern =
List.mapi (fun i pattern ->
{ name = Printf.sprintf "%s%d" name (i+1)
; pattern; })
(ac_equiv pattern) in
match `X64Addr with
(* ------------------------------- *)
| `X64Addr ->
let rule name pattern =
List.mapi (fun i pattern ->
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
; pattern })
(ac_equiv pattern) in
(* o + b *)
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * i *)
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
@ (* o + s * i *)
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
@ (* b + o + s * i *)
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * i *)
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* o + s * i *)
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* b + o + s * i *)
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
(* ------------------------------- *)
| `Add3 ->
[ { name = "add"
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
[ { name = "add"
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
[ { name = "mul"
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
Atm Tmp),
Bnr (oa, Atm Tmp,
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
let sl, sm = generate_table address_rules
let sl, sm = generate_table rules
let s n = List.find (fun {id; _} -> id = n) sl
let () = print_sm sm
(*
let tp0 =
let o = Kw, Oadd in
Bnr (o, Atm Tmp, Atm (Con 0L))
let tp1 =
let o = Kw, Oadd in
Bnr (o, tp0, Atm (Con 1L))
*)