fuse ac rules in ins-tree matching

The initial plan was to have one
matcher per ac-variant, but that
leads to way too much generated
code. Instead, we can fuse ac
variants of the rules and have
a smarter matching algorithm to
recover bound variables.
This commit is contained in:
Quentin Carbonneaux 2022-02-11 08:42:28 +01:00
parent 8a5e1c3a23
commit 56e2263ca4
2 changed files with 86 additions and 72 deletions

View file

@ -23,11 +23,32 @@ type pattern =
| Atm of atomic_pattern
| Var of string * atomic_pattern
let show_op (k, o) =
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul") ^
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
let rec show_pattern p =
match p with
| Var _ -> failwith "variable not allowed"
| Atm Tmp -> "%"
| Atm AnyCon -> "$"
| Atm (Con n) -> Int64.to_string n
| Bnr (o, pl, pr) ->
"(" ^ show_op o ^
" " ^ show_pattern pl ^
" " ^ show_pattern pr ^ ")"
let rec pattern_match p w =
match p with
| Var _ ->
failwith "variable not allowed"
| Atm (Tmp) ->
| Var _ -> failwith "variable not allowed"
| Atm Tmp ->
begin match w with
| Atm (Con _ | AnyCon) -> false
| _ -> true
@ -89,12 +110,12 @@ type 'a state =
; point: ('a cursor) list }
let rec binops side {point; _} =
List.fold_left (fun res c ->
List.filter_map (fun c ->
match c, side with
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
| _ -> res)
[] point
| Bnrl (o, c, r), `L -> Some ((o, c), r)
| Bnrr (o, l, c), `R -> Some ((o, c), l)
| _ -> None)
point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
@ -114,11 +135,9 @@ let sort_uniq cmp l =
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
| Some e ->
if cmp e e' = 0
then (eo, l)
else (Some e', e :: l)
) (None, []) |>
| Some e when cmp e e' = 0 -> (eo, l)
| Some e -> (Some e', e :: l))
(None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
@ -126,15 +145,14 @@ let sort_uniq cmp l =
let normalize (point: ('a cursor) list) =
sort_uniq compare point
let nextbnr tmp s1 s2 =
let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
List.map fst
and o2 = binops `R s2 |>
List.map fst in
let o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
List.map fst
in
List.map fst in
List.map (fun (o, l) ->
o,
{ id = 0
@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
type p = string
module StateSet : sig
type set
val create: unit -> set
val add: set -> p state ->
type t
val create: unit -> t
val add: t -> p state ->
[> `Added | `Found ] * p state
val iter: set -> (p state -> unit) -> unit
val elems: set -> (p state) list
val iter: t -> (p state -> unit) -> unit
val elems: t -> (p state) list
end = struct
include Hashtbl.Make(struct
open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
type set =
type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 1 }
let add set s =
(* delete the check later *)
assert (s.point = normalize s.point);
try
let id = find set.h s in
@ -171,6 +188,8 @@ end = struct
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
Printf.printf "adding: %d [%s]\n"
id (show_pattern s.seen);
add set.h s id;
`Added, {s with id}
end
@ -198,17 +217,14 @@ end)
type rule =
{ name: string
; pattern: pattern
(* TODO access pattern *)
}
let generate_table rl =
let states = StateSet.create () in
(* initialize states *)
let ground =
List.fold_left
(fun ini r ->
peel r.pattern r.name @ ini)
[] rl |>
List.concat_map
(fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let find x d l =
@ -242,7 +258,7 @@ let generate_table rl =
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
nextbnr tmp sl sr |>
next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in

View file

@ -46,54 +46,52 @@ let ts =
}
let print_sm =
let op_str (k, o) =
Printf.sprintf "%s%s"
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul")
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
in
StateMap.iter (fun k s' ->
match k with
| K (o, sl, sr) ->
let top =
List.fold_left (fun top c ->
match c with
| Top r -> top ^ " " ^ r
| _ -> top) "" s'.point
in
Printf.printf
"(%s %d %d) -> %d\n"
(op_str o)
sl.id sr.id s'.id
)
"(%s %d %d) -> %d%s\n"
(show_op o)
sl.id sr.id s'.id top)
let address_rules =
let rules =
let oa = Kl, Oadd in
let om = Kl, Omul in
match `X64Addr with
(* ------------------------------- *)
| `X64Addr ->
let rule name pattern =
List.mapi (fun i pattern ->
{ name = Printf.sprintf "%s%d" name (i+1)
; pattern; })
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
; pattern })
(ac_equiv pattern) in
(* o + b *)
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * i *)
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* o + s * i *)
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* b + o + s * i *)
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
(* ------------------------------- *)
| `Add3 ->
[ { name = "add"
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
[ { name = "add"
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
[ { name = "mul"
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
Atm Tmp),
Bnr (oa, Atm Tmp,
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
let sl, sm = generate_table address_rules
let sl, sm = generate_table rules
let s n = List.find (fun {id; _} -> id = n) sl
let () = print_sm sm
(*
let tp0 =
let o = Kw, Oadd in
Bnr (o, Atm Tmp, Atm (Con 0L))
let tp1 =
let o = Kw, Oadd in
Bnr (o, tp0, Atm (Con 1L))
*)