From d01b778a328371086373e04b71454bfd5c9d4a89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Tue, 23 Apr 2024 11:58:34 +0200
Subject: [PATCH] [nn_native] convert vector term to node with shape of
 arbitrary size

---
 src/interpretation/interpreter_theory.ml |  41 +++---
 src/transformations/native_nn_prover.ml  | 172 +++++++++++------------
 tests/acasxu_ci.t                        |   6 +-
 tests/interpretation_fail.t              |   2 +-
 4 files changed, 112 insertions(+), 109 deletions(-)

diff --git a/src/interpretation/interpreter_theory.ml b/src/interpretation/interpreter_theory.ml
index c235c43..6f079ff 100644
--- a/src/interpretation/interpreter_theory.ml
+++ b/src/interpretation/interpreter_theory.ml
@@ -385,25 +385,28 @@ module Model = struct
           ()
           (* Logging.user_error ?loc:t1.t_loc (fun m -> m "Unexpected neural
              network model application: %a" Why3.Pretty.print_term t2) *));
-        let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in
-        let get =
-          Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ]
-        in
-        let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in
-        let args =
-          List.init nb_outputs ~f:(fun i ->
-            ( Why3.Term.fs_app get
-                [
-                  t0;
-                  Why3.Term.t_const
-                    (Why3.Constant.int_const_of_int i)
-                    Why3.Ty.ty_int;
-                ]
-                ty_elt,
-              ty_elt ))
-        in
-        let op = ITypes.Vector (Language.create_vector env nb_outputs) in
-        IRE.value_term (ITypes.term_of_op ~args engine op ty)
+        if false
+        then
+          let th = Why3.Env.read_theory env [ "caisar"; "types" ] "Vector" in
+          let get =
+            Why3.Theory.ns_find_ls th.th_export [ Why3.Ident.op_get "" ]
+          in
+          let t0 = Why3.Term.t_app ls [ t1; t2 ] ty in
+          let args =
+            List.init nb_outputs ~f:(fun i ->
+              ( Why3.Term.fs_app get
+                  [
+                    t0;
+                    Why3.Term.t_const
+                      (Why3.Constant.int_const_of_int i)
+                      Why3.Ty.ty_int;
+                  ]
+                  ty_elt,
+                ty_elt ))
+          in
+          let op = ITypes.Vector (Language.create_vector env nb_outputs) in
+          IRE.value_term (ITypes.term_of_op ~args engine op ty)
+        else IRE.reconstruct_term ()
       | _ -> IRE.reconstruct_term ())
     | _ -> fail_on_unexpected_argument ls
 
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index e1c698d..f3c364b 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -72,10 +72,7 @@ let match_nn_app th_model term =
       assert (nn_nb_inputs = vector_length && vector_length = List.length tl);
       let c = Why3.Number.to_small_integer c in
       Some (nn, tl, c)
-    | _, _ ->
-      Logging.code_error ~src (fun m ->
-        m "Neural network application without fixed NN or arguments: %a"
-          Why3.Pretty.print_term term))
+    | _, _ -> None)
   | _ -> None
 
 let create_new_nn env input_vars outputs : string =
@@ -93,22 +90,14 @@ let create_new_nn env input_vars outputs : string =
       IR.Node.gather_int input i)
   in
   let cache = Why3.Term.Hterm.create 17 in
-  let nn_cache = Stdlib.Hashtbl.create 17 in
-  (* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
-     into nodes. *)
-  let rec convert_old_nn old_nn old_nn_args =
-    let converted_args = List.map ~f:convert_term old_nn_args in
-    let id =
-      ( old_nn.Language.nn_filename,
-        List.map converted_args ~f:(fun n -> n.Nir.Node.id) )
-    in
-    match Stdlib.Hashtbl.find_opt nn_cache id with
-    | None ->
-      let node_nn = convert_old_nn_aux old_nn converted_args in
-      Stdlib.Hashtbl.add nn_cache id node_nn;
-      node_nn
-    | Some node_nn -> node_nn
-  and convert_old_nn_aux old_nn converted_args =
+  (* Instantiate the input of [old_nn] with [converted_arg], the argument term
+     already converted into a node. *)
+  let convert_nn ?loc old_nn converted_arg =
+    if old_nn.Language.nn_nb_inputs
+       <> Nir.Shape.size converted_arg.Nir.Node.shape
+    then
+      Logging.user_error ?loc (fun m ->
+        m "Neural network applied with the wrong number of arguments");
     let old_nn_nir =
       match Onnx.Reader.from_file old_nn.Language.nn_filename with
       | Error s ->
@@ -121,76 +110,88 @@ let create_new_nn env input_vars outputs : string =
     in
     (* Create the graph to replace the old input of the nn *)
     let input () =
-      (* Regroup the terms into one node *)
-      let node =
-        IR.Node.create (Concat { inputs = converted_args; axis = 0 })
-      in
-      IR.Node.reshape (IR.Ngraph.input_shape old_nn_nir) node
+      let node = converted_arg in
+      Nir.Node.reshape (Nir.Ngraph.input_shape old_nn_nir) node
     in
     let out =
-      IR.Node.replace_input input (IR.Ngraph.output old_nn_nir)
+      IR.Node.replace_input input (Nir.Ngraph.output old_nn_nir)
       |> IR.Node.reshape (Nir.Shape.of_array [| old_nn.nn_nb_outputs |])
     in
     out
-  and convert_old_nn_at_old_index old_nn old_index old_nn_args =
-    let out = convert_old_nn old_nn old_nn_args in
-    Nir.Node.gather_int out old_index
-  and convert_term term =
+  in
+  let rec convert_term term =
     match Why3.Term.Hterm.find_opt cache term with
     | None ->
       let n = convert_term_aux term in
       Why3.Term.Hterm.add cache term n;
       n
     | Some n -> n
-  and convert_term_aux term : IR.Node.t =
-    if not (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
+  and convert_term_aux term : Nir.Node.t =
+    let loc = term.Why3.Term.t_loc in
+    if false
+       && not
+            (Why3.Ty.ty_equal (Option.value_exn term.Why3.Term.t_ty) th_f64.ty)
     then
       Logging.user_error ?loc:term.t_loc (fun m ->
         m "Cannot convert non Float64 term %a" Why3.Pretty.print_term term);
-    match match_nn_app th_model term with
-    | Some (old_nn, tl, old_index) ->
-      convert_old_nn_at_old_index old_nn old_index tl
-    | None -> (
-      match term.Why3.Term.t_node with
+    match term.Why3.Term.t_node with
+    | Tapp (ls_get (* [ ] *), [ t1; { t_node = Tconst (ConstInt c); _ } ])
+      when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") ->
+      let c = Why3.Number.to_small_integer c in
+      let t1 = convert_term t1 in
+      if c < 0 || Nir.Shape.size t1.Nir.Node.shape <= c
+      then
+        Logging.user_error ?loc (fun m ->
+          m "Vector accessed outside its bounds");
+      Nir.Node.gather_int t1 c
+    | Tapp (ls_atat (* @@ *), [ { t_node = Tapp (ls_nn (* nn *), _); _ }; t2 ])
+      when Why3.Term.ls_equal ls_atat th_model.Symbols.Model.atat -> (
+      match Language.lookup_nn ls_nn with
+      | Some nn -> convert_nn ?loc nn (convert_term t2)
+      | _ ->
+        Logging.code_error ~src (fun m ->
+          m "Neural network application without fixed NN: %a"
+            Why3.Pretty.print_term term))
+    | Tapp (ls (* input vector *), tl (* input vars *))
+      when Language.mem_vector ls ->
+      let inputs = List.map tl ~f:convert_term in
+      IR.Node.create (Concat { inputs; axis = 0 })
+    | Tconst (Why3.Constant.ConstReal r) ->
+      IR.Node.create
+        (Constant
+           {
+             data =
+               Nir.Gentensor.create_1_float (Utils.float_of_real_constant r);
+           })
+    | Tapp (ls, []) -> get_input ls
+    | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add ->
+      IR.Node.create (Add { input1 = convert_term a; input2 = convert_term b })
+    | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub ->
+      IR.Node.create (Sub { input1 = convert_term a; input2 = convert_term b })
+    | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul ->
+      IR.Node.create (Mul { input1 = convert_term a; input2 = convert_term b })
+    | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> (
+      match b.t_node with
       | Tconst (Why3.Constant.ConstReal r) ->
+        let f = Utils.float_of_real_constant r in
+        Nir.Node.div_float (convert_term a) f
+      | _ ->
         IR.Node.create
-          (Constant
-             {
-               data =
-                 IR.Gentensor.create_1_float (Utils.float_of_real_constant r);
-             })
-      | Tapp (ls, []) -> get_input ls
-      | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.add ->
-        IR.Node.create
-          (Add { input1 = convert_term a; input2 = convert_term b })
-      | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.sub ->
-        IR.Node.create
-          (Sub { input1 = convert_term a; input2 = convert_term b })
-      | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.mul ->
-        IR.Node.create
-          (Mul { input1 = convert_term a; input2 = convert_term b })
-      | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> (
-        match b.t_node with
-        | Tconst (Why3.Constant.ConstReal r) ->
-          let f = Utils.float_of_real_constant r in
-          Nir.Node.div_float (convert_term a) f
-        | _ ->
-          IR.Node.create
-            (Div { input1 = convert_term a; input2 = convert_term b }))
-      | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg ->
-        Nir.Node.mul_float (convert_term a) (-1.)
-      | Tconst _
-      | Tapp (_, _)
-      | Tif (_, _, _)
-      | Tlet (_, _)
-      | Tbinop (_, _, _)
-      | Tcase (_, _)
-      | Tnot _ | Ttrue | Tfalse ->
-        Logging.not_implemented_yet (fun m ->
-          m "Why3 term to IR: %a" Why3.Pretty.print_term term)
-      | Tvar _ | Teps _ | Tquant (_, _) ->
-        Logging.not_implemented_yet (fun m ->
-          m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term))
+          (Div { input1 = convert_term a; input2 = convert_term b }))
+    | Tapp (ls, [ a ]) when Why3.Term.ls_equal ls th_f64.neg ->
+      Nir.Node.mul_float (convert_term a) (-1.)
+    | Tconst _
+    | Tapp (_, _)
+    | Tif (_, _, _)
+    | Tlet (_, _)
+    | Tbinop (_, _, _)
+    | Tcase (_, _)
+    | Tnot _ | Ttrue | Tfalse ->
+      Logging.not_implemented_yet (fun m ->
+        m "Why3 term to IR: %a" Why3.Pretty.print_term term)
+    | Tvar _ | Teps _ | Tquant (_, _) ->
+      Logging.not_implemented_yet (fun m ->
+        m "Why3 term to IR (impossible?): %a" Why3.Pretty.print_term term)
   in
   (* Start actual term conversion. *)
   let outputs =
@@ -259,19 +260,18 @@ let check_if_new_nn_needed env input_vars outputs =
 
 (* Choose the term pattern for starting the conversion to ONNX. *)
 let has_start_pattern env term =
-  let th_model = Symbols.Model.create env in
   let th_f = Symbols.Float64.create env in
-  match match_nn_app th_model term with
-  | Some _ -> true
-  | None -> (
-    match term.Why3.Term.t_node with
-    | Tapp ({ ls_value = Some ty; _ }, []) ->
-      (* input symbol *) Why3.Ty.ty_equal ty th_f.ty
-    | Tapp (ls_app, _) ->
-      List.mem
-        [ th_f.add; th_f.sub; th_f.mul; th_f.div ]
-        ~equal:Why3.Term.ls_equal ls_app
-    | _ -> false)
+  match term.Why3.Term.t_node with
+  | Tapp (ls_get (* [ ] *), _)
+    when String.equal ls_get.ls_name.id_string (Why3.Ident.op_get "") ->
+    true
+  | Tapp ({ ls_value = Some ty; _ }, []) ->
+    (* input symbol *) Why3.Ty.ty_equal ty th_f.ty
+  | Tapp (ls_app, _) ->
+    List.mem
+      [ th_f.add; th_f.sub; th_f.mul; th_f.div ]
+      ~equal:Why3.Term.ls_equal ls_app
+  | _ -> false
 
 (* Abstract terms that contains neural network applications. *)
 let abstract_nn_term env =
diff --git a/tests/acasxu_ci.t b/tests/acasxu_ci.t
index 72468ad..f274c7e 100644
--- a/tests/acasxu_ci.t
+++ b/tests/acasxu_ci.t
@@ -936,9 +936,9 @@ Test verify on acasxu
   out/caisar_1.onnx has 1 input nodes
   {'name': '221', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   out/caisar_2.onnx has 1 input nodes
-  {'name': '440', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
+  {'name': '439', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   out/caisar_3.onnx has 1 input nodes
-  {'name': '588', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
+  {'name': '587', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '6'}]}}}}
   out/caisar_4.onnx has 1 input nodes
-  {'name': '815', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
+  {'name': '814', 'type': {'tensorType': {'elemType': 1, 'shape': {'dim': [{'dimValue': '5'}]}}}}
   5 files checked
diff --git a/tests/interpretation_fail.t b/tests/interpretation_fail.t
index 72ca477..bd06a54 100644
--- a/tests/interpretation_fail.t
+++ b/tests/interpretation_fail.t
@@ -146,7 +146,7 @@ Test interpret fail
   > EOF
 
   $ caisar verify --prover nnenum file.mlw
-  [ERROR] "file.mlw", line 12, characters 14-23:
+  [ERROR] "file.mlw", line 12, characters 24-26:
           Index constant 10 is out-of-bounds [0,4]
 
   $ cat > file.mlw <<EOF
-- 
GitLab