fuse ac rules in ins-tree matching

The initial plan was to have one
matcher per ac-variant, but that
leads to way too much generated
code. Instead, we can fuse ac
variants of the rules and have
a smarter matching algorithm to
recover bound variables.
This commit is contained in:
Quentin Carbonneaux 2022-02-11 08:42:28 +01:00
parent 8a5e1c3a23
commit 56e2263ca4
2 changed files with 86 additions and 72 deletions

View file

@ -23,11 +23,32 @@ type pattern =
| Atm of atomic_pattern | Atm of atomic_pattern
| Var of string * atomic_pattern | Var of string * atomic_pattern
let show_op (k, o) =
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul") ^
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
let rec show_pattern p =
match p with
| Var _ -> failwith "variable not allowed"
| Atm Tmp -> "%"
| Atm AnyCon -> "$"
| Atm (Con n) -> Int64.to_string n
| Bnr (o, pl, pr) ->
"(" ^ show_op o ^
" " ^ show_pattern pl ^
" " ^ show_pattern pr ^ ")"
let rec pattern_match p w = let rec pattern_match p w =
match p with match p with
| Var _ -> | Var _ -> failwith "variable not allowed"
failwith "variable not allowed" | Atm Tmp ->
| Atm (Tmp) ->
begin match w with begin match w with
| Atm (Con _ | AnyCon) -> false | Atm (Con _ | AnyCon) -> false
| _ -> true | _ -> true
@ -89,12 +110,12 @@ type 'a state =
; point: ('a cursor) list } ; point: ('a cursor) list }
let rec binops side {point; _} = let rec binops side {point; _} =
List.fold_left (fun res c -> List.filter_map (fun c ->
match c, side with match c, side with
| Bnrl (o, c, r), `L -> ((o, c), r) :: res | Bnrl (o, c, r), `L -> Some ((o, c), r)
| Bnrr (o, l, c), `R -> ((o, c), l) :: res | Bnrr (o, l, c), `R -> Some ((o, c), l)
| _ -> res) | _ -> None)
[] point point
let group_by_fst l = let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) -> List.fast_sort (fun (a, _) (b, _) ->
@ -114,11 +135,9 @@ let sort_uniq cmp l =
List.fold_left (fun (eo, l) e' -> List.fold_left (fun (eo, l) e' ->
match eo with match eo with
| None -> (Some e', l) | None -> (Some e', l)
| Some e -> | Some e when cmp e e' = 0 -> (eo, l)
if cmp e e' = 0 | Some e -> (Some e', e :: l))
then (eo, l) (None, []) |>
else (Some e', e :: l)
) (None, []) |>
(function (function
| (None, _) -> [] | (None, _) -> []
| (Some e, l) -> List.rev (e :: l)) | (Some e, l) -> List.rev (e :: l))
@ -126,15 +145,14 @@ let sort_uniq cmp l =
let normalize (point: ('a cursor) list) = let normalize (point: ('a cursor) list) =
sort_uniq compare point sort_uniq compare point
let nextbnr tmp s1 s2 = let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |> let o1 = binops `L s1 |>
List.filter (pm s2.seen) |> List.filter (pm s2.seen) |>
List.map fst List.map fst in
and o2 = binops `R s2 |> let o2 = binops `R s2 |>
List.filter (pm s1.seen) |> List.filter (pm s1.seen) |>
List.map fst List.map fst in
in
List.map (fun (o, l) -> List.map (fun (o, l) ->
o, o,
{ id = 0 { id = 0
@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
type p = string type p = string
module StateSet : sig module StateSet : sig
type set type t
val create: unit -> set val create: unit -> t
val add: set -> p state -> val add: t -> p state ->
[> `Added | `Found ] * p state [> `Added | `Found ] * p state
val iter: set -> (p state -> unit) -> unit val iter: t -> (p state -> unit) -> unit
val elems: set -> (p state) list val elems: t -> (p state) list
end = struct end = struct
include Hashtbl.Make(struct open Hashtbl.Make(struct
type t = p state type t = p state
let equal s1 s2 = s1.point = s2.point let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point let hash s = Hashtbl.hash s.point
end) end)
type set = type nonrec t =
{ h: int t { h: int t
; mutable next_id: int } ; mutable next_id: int }
let create () = let create () =
{ h = create 500; next_id = 1 } { h = create 500; next_id = 1 }
let add set s = let add set s =
(* delete the check later *)
assert (s.point = normalize s.point); assert (s.point = normalize s.point);
try try
let id = find set.h s in let id = find set.h s in
@ -171,6 +188,8 @@ end = struct
with Not_found -> begin with Not_found -> begin
let id = set.next_id in let id = set.next_id in
set.next_id <- id + 1; set.next_id <- id + 1;
Printf.printf "adding: %d [%s]\n"
id (show_pattern s.seen);
add set.h s id; add set.h s id;
`Added, {s with id} `Added, {s with id}
end end
@ -198,17 +217,14 @@ end)
type rule = type rule =
{ name: string { name: string
; pattern: pattern ; pattern: pattern
(* TODO access pattern *)
} }
let generate_table rl = let generate_table rl =
let states = StateSet.create () in let states = StateSet.create () in
(* initialize states *) (* initialize states *)
let ground = let ground =
List.fold_left List.concat_map
(fun ini r -> (fun r -> peel r.pattern r.name) rl |>
peel r.pattern r.name @ ini)
[] rl |>
group_by_fst group_by_fst
in in
let find x d l = let find x d l =
@ -242,7 +258,7 @@ let generate_table rl =
flag := `Stop; flag := `Stop;
let statel = StateSet.elems states in let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) -> iter_pairs statel (fun (sl, sr) ->
nextbnr tmp sl sr |> next_binary tmp sl sr |>
List.iter (fun (o, s') -> List.iter (fun (o, s') ->
let flag', s' = let flag', s' =
StateSet.add states s' in StateSet.add states s' in

View file

@ -46,54 +46,52 @@ let ts =
} }
let print_sm = let print_sm =
let op_str (k, o) =
Printf.sprintf "%s%s"
(match o with
| Oadd -> "add"
| Osub -> "sub"
| Omul -> "mul")
(match k with
| Kw -> "w"
| Kl -> "l"
| Ks -> "s"
| Kd -> "d")
in
StateMap.iter (fun k s' -> StateMap.iter (fun k s' ->
match k with match k with
| K (o, sl, sr) -> | K (o, sl, sr) ->
let top =
List.fold_left (fun top c ->
match c with
| Top r -> top ^ " " ^ r
| _ -> top) "" s'.point
in
Printf.printf Printf.printf
"(%s %d %d) -> %d\n" "(%s %d %d) -> %d%s\n"
(op_str o) (show_op o)
sl.id sr.id s'.id sl.id sr.id s'.id top)
)
let address_rules = let rules =
let oa = Kl, Oadd in let oa = Kl, Oadd in
let om = Kl, Omul in let om = Kl, Omul in
let rule name pattern = match `X64Addr with
List.mapi (fun i pattern -> (* ------------------------------- *)
{ name = Printf.sprintf "%s%d" name (i+1) | `X64Addr ->
; pattern; }) let rule name pattern =
(ac_equiv pattern) in List.mapi (fun i pattern ->
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
; pattern })
(ac_equiv pattern) in
(* o + b *) (* o + b *)
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon)) rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
@ (* b + s * i *) @ (* b + s * i *)
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp))) rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* o + s * i *) @ (* o + s * i *)
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp))) rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
@ (* b + o + s * i *) @ (* b + o + s * i *)
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp))) rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
(* ------------------------------- *)
| `Add3 ->
[ { name = "add"
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
[ { name = "add"
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
[ { name = "mul"
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
Atm Tmp),
Bnr (oa, Atm Tmp,
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
let sl, sm = generate_table address_rules
let sl, sm = generate_table rules
let s n = List.find (fun {id; _} -> id = n) sl let s n = List.find (fun {id; _} -> id = n) sl
let () = print_sm sm let () = print_sm sm
(*
let tp0 =
let o = Kw, Oadd in
Bnr (o, Atm Tmp, Atm (Con 0L))
let tp1 =
let o = Kw, Oadd in
Bnr (o, tp0, Atm (Con 1L))
*)