diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml index 17876102066388f89dc8dfe24edaba4655050783..9690498d3cab30fc2f5e5111b6d4721bad794499 100644 --- a/src/interpretation/interpreter_theory.ml +++ b/src/interpretation/interpreter_theory.ml @@ -35,13 +35,17 @@ let fail_on_unexpected_argument ls = module Vector = struct let (get : _ IRE.builtin) = fun engine ls vl ty -> + let th_model = Symbols.Model.create (IRE.user_env engine).ITypes.env in match vl with | [ Term - ({ t_node = Tapp (_ (* @@ *), [ { t_node = Tapp (ls, _); _ }; _ ]); _ } - as _t1); + ({ + t_node = Tapp (ls_atat (* @@ *), [ { t_node = Tapp (ls, _); _ }; _ ]); + _; + } as _t1); Term ({ t_node = Tconst (ConstInt i); _ } as t2); - ] -> ( + ] + when Why3.Term.ls_equal ls_atat th_model.atat -> ( let i = Why3.Number.to_small_integer i in if i < 0 then @@ -304,50 +308,60 @@ module NN = struct | _ -> fail_on_unexpected_argument ls let apply : _ IRE.builtin = - fun engine ls vl _ty -> + fun engine ls vl ty -> match vl with - | [ - Term ({ t_node = Tapp (ls1, []); _ } as t1); - Term ({ t_node = Tapp (ls2, tl2); _ } as t2); - ] -> ( - match (ITypes.op_of_ls engine ls1, ITypes.op_of_ls engine ls2) with - | Model (NN (nn, _)), Vector v -> - let nn = - match Language.lookup_nn nn with - | None -> - Logging.code_error ~src:Logging.src_interpret_goal (fun m -> - m "Cannot find neural network model from lsymbol %a" - Why3.Pretty.print_ls nn) - | Some nn -> nn - in - let length_v = - match Language.lookup_vector v with - | None -> - Logging.code_error ~src:Logging.src_interpret_goal (fun m -> - m "Cannot find vector from lsymbol %a" Why3.Pretty.print_ls v) - | Some n -> - if List.length tl2 <> n - then + | [ Term t1; Term t2 ] -> ( + match ITypes.op_of_term engine t1 with + | Some (Model (NN (nn, _)), []) -> + let { ITypes.env; _ } = IRE.user_env engine in + let nn = Option.value_exn (Language.lookup_nn nn) in + (match ITypes.op_of_term engine t2 with + | Some (Vector v, tl2) -> + let length_v = + match Language.lookup_vector v with + | None -> Logging.code_error ~src:Logging.src_interpret_goal (fun m -> - m - "Mismatch between (container) vector length and number of \ - (contained) input variables."); - n + m "Cannot find vector from lsymbol %a" Why3.Pretty.print_ls v) + | Some n -> + if List.length tl2 <> n + then + Logging.code_error ~src:Logging.src_interpret_goal (fun m -> + m + "Mismatch between (container) vector length and number of \ + (contained) input variables."); + n + in + if nn.nn_nb_inputs <> length_v + then + Logging.user_error ?loc:t2.t_loc (fun m -> + m + "Unexpected vector of length %d in input to neural network \ + model '%s',@ which expects input vectors of length %d" + length_v nn.nn_filename nn.nn_nb_inputs) + | _ -> + Logging.user_error ?loc:t1.t_loc (fun m -> + m "Unexpected neural network model application: %a" + Why3.Pretty.print_term t2)); + let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in + let get = + Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ] in - if nn.nn_nb_inputs <> length_v - then - Logging.user_error ?loc:t2.t_loc (fun m -> - m - "Unexpected vector of length %d in input to neural network model \ - '%s',@ which expects input vectors of length %d" - length_v nn.nn_filename nn.nn_nb_inputs) - else IRE.reconstruct_term () - | Model (SVM _), _ -> - (* Should be already catched by the Why3 typing. *) - assert false - | _, _ -> - Logging.user_error ?loc:t1.t_loc (fun m -> - m "Unexpected neural network model application")) + let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in + let args = + List.init nn.nn_nb_outputs ~f:(fun i -> + ( Why3.Term.fs_app get + [ + t0; + Why3.Term.t_const + (Why3.Constant.int_const_of_int i) + Why3.Ty.ty_int; + ] + nn.nn_ty_elt, + nn.nn_ty_elt )) + in + let op = ITypes.Vector (Language.create_vector env nn.nn_nb_outputs) in + IRE.value_term (ITypes.term_of_op ~args engine op ty) + | _ -> IRE.reconstruct_term ()) | _ -> fail_on_unexpected_argument ls let builtins : _ IRE.built_in_theories = diff --git a/src/interpretation/interpreter_types.ml b/src/interpretation/interpreter_types.ml index d538870d9396179fb93f2178445c09437041b2bc..0e015b975d69ac2353dcefdeeea0b23935c10745 100644 --- a/src/interpretation/interpreter_types.ml +++ b/src/interpretation/interpreter_types.ml @@ -103,6 +103,14 @@ let term_of_op ?(args = []) engine interpreter_op ty = let t_args, ty_args = List.unzip args in Why3.Term.t_app_infer (ls_of_op engine interpreter_op ty_args ty) t_args +let op_of_term engine t = + match t.Why3.Term.t_node with + | Tapp (ls, args) -> ( + match op_of_ls engine ls with + | exception Stdlib.Not_found -> None + | v -> Some (v, args)) + | _ -> None + let interpreter_env ~cwd env = { ls_of_op = Hashtbl.Poly.create (); diff --git a/src/interpretation/interpreter_types.mli b/src/interpretation/interpreter_types.mli index 8eb1ce80f2233dfbc504662229ca5712c459f90b..2481c482ef4377a5bbbc43b2f0367729020e3174 100644 --- a/src/interpretation/interpreter_types.mli +++ b/src/interpretation/interpreter_types.mli @@ -54,6 +54,8 @@ type interpreter_env = private { val op_of_ls : interpreter_env IRE.engine -> Why3.Term.lsymbol -> interpreter_op +val op_of_term : interpreter_env IRE.engine -> Why3.Term.term -> (interpreter_op * Why3.Term.term list) option + val term_of_op : ?args:(Why3.Term.term * Why3.Ty.ty) list -> interpreter_env IRE.engine -> diff --git a/tests/interpretation_fail.t b/tests/interpretation_fail.t index a389622e32687f4c40c1c8504faa887c96b42f54..7d68cd97bec033c17d5662d04bd9df544f58f365 100644 --- a/tests/interpretation_fail.t +++ b/tests/interpretation_fail.t @@ -146,7 +146,7 @@ Test interpret fail > EOF $ caisar verify --prover nnenum file.mlw - [ERROR] "file.mlw", line 12, characters 24-26: + [ERROR] "file.mlw", line 12, characters 14-23: Index constant 10 is out-of-bounds [0,4] $ cat > file.mlw <<EOF @@ -183,5 +183,4 @@ Test interpret fail > EOF $ caisar verify --prover SAVer file.mlw - [ERROR] "file.mlw", line 10, characters 35-36: - Index constant 4 is out-of-bounds [0,1] + [ERROR] Cannot find feature for input variable 'x'