fuse ac rules in ins-tree matching
The initial plan was to have one matcher per ac-variant, but that leads to way too much generated code. Instead, we can fuse ac variants of the rules and have a smarter matching algorithm to recover bound variables.
This commit is contained in:
parent
8a5e1c3a23
commit
56e2263ca4
2 changed files with 86 additions and 72 deletions
|
@ -23,11 +23,32 @@ type pattern =
|
||||||
| Atm of atomic_pattern
|
| Atm of atomic_pattern
|
||||||
| Var of string * atomic_pattern
|
| Var of string * atomic_pattern
|
||||||
|
|
||||||
|
let show_op (k, o) =
|
||||||
|
(match o with
|
||||||
|
| Oadd -> "add"
|
||||||
|
| Osub -> "sub"
|
||||||
|
| Omul -> "mul") ^
|
||||||
|
(match k with
|
||||||
|
| Kw -> "w"
|
||||||
|
| Kl -> "l"
|
||||||
|
| Ks -> "s"
|
||||||
|
| Kd -> "d")
|
||||||
|
|
||||||
|
let rec show_pattern p =
|
||||||
|
match p with
|
||||||
|
| Var _ -> failwith "variable not allowed"
|
||||||
|
| Atm Tmp -> "%"
|
||||||
|
| Atm AnyCon -> "$"
|
||||||
|
| Atm (Con n) -> Int64.to_string n
|
||||||
|
| Bnr (o, pl, pr) ->
|
||||||
|
"(" ^ show_op o ^
|
||||||
|
" " ^ show_pattern pl ^
|
||||||
|
" " ^ show_pattern pr ^ ")"
|
||||||
|
|
||||||
let rec pattern_match p w =
|
let rec pattern_match p w =
|
||||||
match p with
|
match p with
|
||||||
| Var _ ->
|
| Var _ -> failwith "variable not allowed"
|
||||||
failwith "variable not allowed"
|
| Atm Tmp ->
|
||||||
| Atm (Tmp) ->
|
|
||||||
begin match w with
|
begin match w with
|
||||||
| Atm (Con _ | AnyCon) -> false
|
| Atm (Con _ | AnyCon) -> false
|
||||||
| _ -> true
|
| _ -> true
|
||||||
|
@ -89,12 +110,12 @@ type 'a state =
|
||||||
; point: ('a cursor) list }
|
; point: ('a cursor) list }
|
||||||
|
|
||||||
let rec binops side {point; _} =
|
let rec binops side {point; _} =
|
||||||
List.fold_left (fun res c ->
|
List.filter_map (fun c ->
|
||||||
match c, side with
|
match c, side with
|
||||||
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
|
| Bnrl (o, c, r), `L -> Some ((o, c), r)
|
||||||
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
|
| Bnrr (o, l, c), `R -> Some ((o, c), l)
|
||||||
| _ -> res)
|
| _ -> None)
|
||||||
[] point
|
point
|
||||||
|
|
||||||
let group_by_fst l =
|
let group_by_fst l =
|
||||||
List.fast_sort (fun (a, _) (b, _) ->
|
List.fast_sort (fun (a, _) (b, _) ->
|
||||||
|
@ -114,11 +135,9 @@ let sort_uniq cmp l =
|
||||||
List.fold_left (fun (eo, l) e' ->
|
List.fold_left (fun (eo, l) e' ->
|
||||||
match eo with
|
match eo with
|
||||||
| None -> (Some e', l)
|
| None -> (Some e', l)
|
||||||
| Some e ->
|
| Some e when cmp e e' = 0 -> (eo, l)
|
||||||
if cmp e e' = 0
|
| Some e -> (Some e', e :: l))
|
||||||
then (eo, l)
|
(None, []) |>
|
||||||
else (Some e', e :: l)
|
|
||||||
) (None, []) |>
|
|
||||||
(function
|
(function
|
||||||
| (None, _) -> []
|
| (None, _) -> []
|
||||||
| (Some e, l) -> List.rev (e :: l))
|
| (Some e, l) -> List.rev (e :: l))
|
||||||
|
@ -126,15 +145,14 @@ let sort_uniq cmp l =
|
||||||
let normalize (point: ('a cursor) list) =
|
let normalize (point: ('a cursor) list) =
|
||||||
sort_uniq compare point
|
sort_uniq compare point
|
||||||
|
|
||||||
let nextbnr tmp s1 s2 =
|
let next_binary tmp s1 s2 =
|
||||||
let pm w (_, p) = pattern_match p w in
|
let pm w (_, p) = pattern_match p w in
|
||||||
let o1 = binops `L s1 |>
|
let o1 = binops `L s1 |>
|
||||||
List.filter (pm s2.seen) |>
|
List.filter (pm s2.seen) |>
|
||||||
List.map fst
|
List.map fst in
|
||||||
and o2 = binops `R s2 |>
|
let o2 = binops `R s2 |>
|
||||||
List.filter (pm s1.seen) |>
|
List.filter (pm s1.seen) |>
|
||||||
List.map fst
|
List.map fst in
|
||||||
in
|
|
||||||
List.map (fun (o, l) ->
|
List.map (fun (o, l) ->
|
||||||
o,
|
o,
|
||||||
{ id = 0
|
{ id = 0
|
||||||
|
@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
|
||||||
type p = string
|
type p = string
|
||||||
|
|
||||||
module StateSet : sig
|
module StateSet : sig
|
||||||
type set
|
type t
|
||||||
val create: unit -> set
|
val create: unit -> t
|
||||||
val add: set -> p state ->
|
val add: t -> p state ->
|
||||||
[> `Added | `Found ] * p state
|
[> `Added | `Found ] * p state
|
||||||
val iter: set -> (p state -> unit) -> unit
|
val iter: t -> (p state -> unit) -> unit
|
||||||
val elems: set -> (p state) list
|
val elems: t -> (p state) list
|
||||||
end = struct
|
end = struct
|
||||||
include Hashtbl.Make(struct
|
open Hashtbl.Make(struct
|
||||||
type t = p state
|
type t = p state
|
||||||
let equal s1 s2 = s1.point = s2.point
|
let equal s1 s2 = s1.point = s2.point
|
||||||
let hash s = Hashtbl.hash s.point
|
let hash s = Hashtbl.hash s.point
|
||||||
end)
|
end)
|
||||||
type set =
|
type nonrec t =
|
||||||
{ h: int t
|
{ h: int t
|
||||||
; mutable next_id: int }
|
; mutable next_id: int }
|
||||||
let create () =
|
let create () =
|
||||||
{ h = create 500; next_id = 1 }
|
{ h = create 500; next_id = 1 }
|
||||||
let add set s =
|
let add set s =
|
||||||
(* delete the check later *)
|
|
||||||
assert (s.point = normalize s.point);
|
assert (s.point = normalize s.point);
|
||||||
try
|
try
|
||||||
let id = find set.h s in
|
let id = find set.h s in
|
||||||
|
@ -171,6 +188,8 @@ end = struct
|
||||||
with Not_found -> begin
|
with Not_found -> begin
|
||||||
let id = set.next_id in
|
let id = set.next_id in
|
||||||
set.next_id <- id + 1;
|
set.next_id <- id + 1;
|
||||||
|
Printf.printf "adding: %d [%s]\n"
|
||||||
|
id (show_pattern s.seen);
|
||||||
add set.h s id;
|
add set.h s id;
|
||||||
`Added, {s with id}
|
`Added, {s with id}
|
||||||
end
|
end
|
||||||
|
@ -198,17 +217,14 @@ end)
|
||||||
type rule =
|
type rule =
|
||||||
{ name: string
|
{ name: string
|
||||||
; pattern: pattern
|
; pattern: pattern
|
||||||
(* TODO access pattern *)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let generate_table rl =
|
let generate_table rl =
|
||||||
let states = StateSet.create () in
|
let states = StateSet.create () in
|
||||||
(* initialize states *)
|
(* initialize states *)
|
||||||
let ground =
|
let ground =
|
||||||
List.fold_left
|
List.concat_map
|
||||||
(fun ini r ->
|
(fun r -> peel r.pattern r.name) rl |>
|
||||||
peel r.pattern r.name @ ini)
|
|
||||||
[] rl |>
|
|
||||||
group_by_fst
|
group_by_fst
|
||||||
in
|
in
|
||||||
let find x d l =
|
let find x d l =
|
||||||
|
@ -242,7 +258,7 @@ let generate_table rl =
|
||||||
flag := `Stop;
|
flag := `Stop;
|
||||||
let statel = StateSet.elems states in
|
let statel = StateSet.elems states in
|
||||||
iter_pairs statel (fun (sl, sr) ->
|
iter_pairs statel (fun (sl, sr) ->
|
||||||
nextbnr tmp sl sr |>
|
next_binary tmp sl sr |>
|
||||||
List.iter (fun (o, s') ->
|
List.iter (fun (o, s') ->
|
||||||
let flag', s' =
|
let flag', s' =
|
||||||
StateSet.add states s' in
|
StateSet.add states s' in
|
||||||
|
|
|
@ -46,54 +46,52 @@ let ts =
|
||||||
}
|
}
|
||||||
|
|
||||||
let print_sm =
|
let print_sm =
|
||||||
let op_str (k, o) =
|
|
||||||
Printf.sprintf "%s%s"
|
|
||||||
(match o with
|
|
||||||
| Oadd -> "add"
|
|
||||||
| Osub -> "sub"
|
|
||||||
| Omul -> "mul")
|
|
||||||
(match k with
|
|
||||||
| Kw -> "w"
|
|
||||||
| Kl -> "l"
|
|
||||||
| Ks -> "s"
|
|
||||||
| Kd -> "d")
|
|
||||||
in
|
|
||||||
StateMap.iter (fun k s' ->
|
StateMap.iter (fun k s' ->
|
||||||
match k with
|
match k with
|
||||||
| K (o, sl, sr) ->
|
| K (o, sl, sr) ->
|
||||||
|
let top =
|
||||||
|
List.fold_left (fun top c ->
|
||||||
|
match c with
|
||||||
|
| Top r -> top ^ " " ^ r
|
||||||
|
| _ -> top) "" s'.point
|
||||||
|
in
|
||||||
Printf.printf
|
Printf.printf
|
||||||
"(%s %d %d) -> %d\n"
|
"(%s %d %d) -> %d%s\n"
|
||||||
(op_str o)
|
(show_op o)
|
||||||
sl.id sr.id s'.id
|
sl.id sr.id s'.id top)
|
||||||
)
|
|
||||||
|
|
||||||
let address_rules =
|
let rules =
|
||||||
let oa = Kl, Oadd in
|
let oa = Kl, Oadd in
|
||||||
let om = Kl, Omul in
|
let om = Kl, Omul in
|
||||||
|
match `X64Addr with
|
||||||
|
(* ------------------------------- *)
|
||||||
|
| `X64Addr ->
|
||||||
let rule name pattern =
|
let rule name pattern =
|
||||||
List.mapi (fun i pattern ->
|
List.mapi (fun i pattern ->
|
||||||
{ name = Printf.sprintf "%s%d" name (i+1)
|
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
|
||||||
; pattern; })
|
; pattern })
|
||||||
(ac_equiv pattern) in
|
(ac_equiv pattern) in
|
||||||
|
|
||||||
(* o + b *)
|
(* o + b *)
|
||||||
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
||||||
@ (* b + s * i *)
|
@ (* b + s * i *)
|
||||||
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
|
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||||
@ (* o + s * i *)
|
@ (* o + s * i *)
|
||||||
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
|
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||||
@ (* b + o + s * i *)
|
@ (* b + o + s * i *)
|
||||||
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
|
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||||
|
(* ------------------------------- *)
|
||||||
|
| `Add3 ->
|
||||||
|
[ { name = "add"
|
||||||
|
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
|
||||||
|
[ { name = "add"
|
||||||
|
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
|
||||||
|
[ { name = "mul"
|
||||||
|
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
|
||||||
|
Atm Tmp),
|
||||||
|
Bnr (oa, Atm Tmp,
|
||||||
|
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
|
||||||
|
|
||||||
let sl, sm = generate_table address_rules
|
|
||||||
|
let sl, sm = generate_table rules
|
||||||
let s n = List.find (fun {id; _} -> id = n) sl
|
let s n = List.find (fun {id; _} -> id = n) sl
|
||||||
let () = print_sm sm
|
let () = print_sm sm
|
||||||
|
|
||||||
(*
|
|
||||||
let tp0 =
|
|
||||||
let o = Kw, Oadd in
|
|
||||||
Bnr (o, Atm Tmp, Atm (Con 0L))
|
|
||||||
let tp1 =
|
|
||||||
let o = Kw, Oadd in
|
|
||||||
Bnr (o, tp0, Atm (Con 1L))
|
|
||||||
*)
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue