From 4e558db3d07ffc6d2aec8b57463ec7145aac123d Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 5 Apr 2023 16:30:02 +0200
Subject: [PATCH] [interpretation] Add notion of classifier in language and use
 it in interpretation.

---
 src/interpretation.ml          | 42 +++++++++++----
 src/language.ml                | 95 +++++++++++++++++++++++++++++++---
 src/language.mli               | 21 +++++++-
 tests/interpretation_acasxu.t  | 16 +++---
 tests/interpretation_dataset.t | 14 +++--
 5 files changed, 160 insertions(+), 28 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index 238359d..d812197 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -24,16 +24,32 @@ module CRE = Reduction_engine (* Caisar Reduction Engine *)
 open Why3
 open Base
 
+type classifier =
+  | NNet of Language.nn_classifier
+      [@printer
+        fun fmt nn ->
+          Fmt.pf fmt "NNet: %a"
+            Fmt.(option Language.pp_nn)
+            (Language.lookup_nn_classifier nn)]
+  | ONNX of Language.nn_classifier
+      [@printer
+        fun fmt nn ->
+          Fmt.pf fmt "ONNX: %a"
+            Fmt.(option Language.pp_nn)
+            (Language.lookup_nn_classifier nn)]
+[@@deriving show]
+
 type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
 [@@deriving show]
 
-type classifier = string [@@deriving show]
 type data = D_csv of string list [@@deriving show]
 type index = I_csv of int [@@deriving show]
 
 type vector =
   (Language.vector
-  [@printer fun fmt v -> Fmt.pf fmt "%d" (Language.lookup_vector v)])
+  [@printer
+    fun fmt v ->
+      Fmt.pf fmt "%a" Fmt.(option ~none:nop int) (Language.lookup_vector v)])
 [@@deriving show]
 
 type caisar_op =
@@ -139,7 +155,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
         in
         term (Term.t_tuple [ t_features; t_label ])
       | Vector v ->
-        let n = Language.lookup_vector v in
+        let n = Option.value_exn (Language.lookup_vector v) in
         assert (List.length tl1 = n && i <= n);
         term (List.nth_exn tl1 i)
       | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
@@ -157,13 +173,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
     | [ Term { t_node = Tapp (ls, []); _ } ] -> (
       match caisar_op_of_ls engine ls with
       | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv))
-      | Vector v -> int (BigInt.of_int (Language.lookup_vector v))
+      | Vector v ->
+        int (BigInt.of_int (Option.value_exn (Language.lookup_vector v)))
       | Data (D_csv data) -> int (BigInt.of_int (List.length data))
       | Classifier _ | Tensor _ | Index _ -> assert false)
     | [ Term { t_node = Tapp (ls, tl); _ } ] -> (
       match caisar_op_of_ls engine ls with
       | Vector v ->
-        let n = Language.lookup_vector v in
+        let n = Option.value_exn (Language.lookup_vector v) in
         assert (List.length tl = n);
         int (BigInt.of_int n)
       | Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
@@ -185,7 +202,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
       match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with
       | Vector v, Data (D_csv data) ->
-        let n = Language.lookup_vector v in
+        let n = Option.value_exn (Language.lookup_vector v) in
         assert (n = List.length data);
         let ty_cst =
           match ty with
@@ -229,7 +246,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
       match caisar_op_of_ls engine ls1 with
       | Vector v ->
-        let n = Language.lookup_vector v in
+        let n = Option.value_exn (Language.lookup_vector v) in
         assert (List.length tl1 = n);
         let args =
           List.mapi tl1 ~f:(fun idx t ->
@@ -322,12 +339,17 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
     match vl with
     | [
      Term { t_node = Tconst (ConstStr classifier); _ };
-     Term { t_node = Tapp ({ ls_name = { id_string = "NNet"; _ }; _ }, []); _ };
+     Term { t_node = Tapp ({ ls_name = { id_string; _ }; _ }, []); _ };
     ] ->
-      let cwd = (CRE.user_env engine).cwd in
+      let { env; cwd; _ } = CRE.user_env engine in
       let caisar_op =
         let filename = Caml.Filename.concat cwd classifier in
-        Classifier filename
+        let classifier =
+          if String.equal id_string "NNet"
+          then NNet (Language.create_nnet_classifier env filename)
+          else ONNX (Language.create_onnx_classifier env filename)
+        in
+        Classifier classifier
       in
       term (term_of_caisar_op engine caisar_op ty)
     | _ -> invalid_arg (error_message ls)
diff --git a/src/language.ml b/src/language.ml
index 8899097..3499f00 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -156,27 +156,110 @@ let register_ovo_support () =
   Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
     (fun env _ filename _ -> ovo_parser env filename)
 
+(* -- Vector *)
+
 type vector = Term.lsymbol
 
 let vectors = Term.Hls.create 10
 
+let vector_elt_ty env =
+  let th = Env.read_theory env [ "ieee_float" ] "Float64" in
+  Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) []
+
 let create_vector =
   Env.Wenv.memoize 13 (fun env ->
     let h = Hashtbl.create (module Int) in
-    let ty_elt =
-      let th = Env.read_theory env [ "ieee_float" ] "Float64" in
-      Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) []
-    in
+    let ty_elt = vector_elt_ty env in
     let ty =
       let th = Env.read_theory env [ "interpretation" ] "Vector" in
       Ty.ty_app (Theory.ns_find_ts th.th_export [ "vector" ]) [ ty_elt ]
     in
     Hashtbl.findi_or_add h ~default:(fun length ->
-      let id = Ident.id_fresh "vector" in
       let ls =
+        let id = Ident.id_fresh "vector" in
         Term.create_fsymbol id (List.init length ~f:(fun _ -> ty_elt)) ty
       in
       Term.Hls.add vectors ls length;
       ls))
 
-let lookup_vector = Term.Hls.find vectors
+let lookup_vector = Term.Hls.find_opt vectors
+
+(* -- Classifier *)
+
+type nn = {
+  nn_inputs : int;
+  nn_outputs : int;
+  nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
+  nn_filename : string;
+  nn_nier : Onnx.G.t option; [@opaque]
+}
+[@@deriving show]
+
+type nn_classifier = Term.lsymbol
+
+let nn_classifiers = Term.Hls.create 10
+
+let fresh_classifier_ls env name =
+  let ty =
+    let th = Env.read_theory env [ "interpretation" ] "Classifier" in
+    Ty.ty_app (Theory.ns_find_ts th.th_export [ "classifier" ]) []
+  in
+  let id = Ident.id_fresh name in
+  Term.create_fsymbol id [] ty
+
+let create_nnet_classifier =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module String) in
+    let ty_elt =
+      let th = Env.read_theory env [ "ieee_float" ] "Float64" in
+      Ty.ty_app (Theory.ns_find_ts th.th_export [ "t" ]) []
+    in
+    Hashtbl.findi_or_add h ~default:(fun filename ->
+      let ls = fresh_classifier_ls env "nnet_classifier" in
+      let nn =
+        let model = Nnet.parse ~permissive:true filename in
+        match model with
+        | Error s -> Loc.errorm "%s" s
+        | Ok { n_inputs; n_outputs; _ } ->
+          {
+            nn_inputs = n_inputs;
+            nn_outputs = n_outputs;
+            nn_ty_elt = ty_elt;
+            nn_filename = filename;
+            nn_nier = None;
+          }
+      in
+      Term.Hls.add nn_classifiers ls nn;
+      ls))
+
+let create_onnx_classifier =
+  Env.Wenv.memoize 13 (fun env ->
+    let h = Hashtbl.create (module String) in
+    let ty_elt = vector_elt_ty env in
+    Hashtbl.findi_or_add h ~default:(fun filename ->
+      let ls = fresh_classifier_ls env "onnx_classifier" in
+      let onnx =
+        let model = Onnx.parse filename in
+        match model with
+        | Error s -> Loc.errorm "%s" s
+        | Ok { n_inputs; n_outputs; nier } ->
+          let nier =
+            match nier with
+            | Error msg ->
+              Logs.warn (fun m ->
+                m "Cannot build network intermediate representation:@ %s" msg);
+              None
+            | Ok nier -> Some nier
+          in
+          {
+            nn_inputs = n_inputs;
+            nn_outputs = n_outputs;
+            nn_ty_elt = ty_elt;
+            nn_filename = filename;
+            nn_nier = nier;
+          }
+      in
+      Term.Hls.add nn_classifiers ls onnx;
+      ls))
+
+let lookup_nn_classifier = Term.Hls.find_opt nn_classifiers
diff --git a/src/language.mli b/src/language.mli
index a361fd4..afa4a11 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -63,7 +63,26 @@ val ovo_parser : Env.env -> string -> Pmodule.pmodule Wstdlib.Mstr.t
 (* [ovo_parser env filename] parses and creates the theories corresponding to
    the given ovo [filename]. The result is memoized. *)
 
+(** -- Vector *)
+
 type vector = Term.lsymbol
 
 val create_vector : Env.env -> int -> vector
-val lookup_vector : vector -> int
+val lookup_vector : vector -> int option
+
+(** -- Classifier *)
+
+type nn = private {
+  nn_inputs : int;
+  nn_outputs : int;
+  nn_ty_elt : Ty.ty;
+  nn_filename : string;
+  nn_nier : Onnx.G.t option;
+}
+[@@deriving show]
+
+type nn_classifier = Term.lsymbol
+
+val create_nnet_classifier : Env.env -> string -> nn_classifier
+val create_onnx_classifier : Env.env -> string -> nn_classifier
+val lookup_nn_classifier : nn_classifier -> nn option
diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t
index 253cde0..9d1adea 100644
--- a/tests/interpretation_acasxu.t
+++ b/tests/interpretation_acasxu.t
@@ -130,11 +130,13 @@ Test interpret on acasxu
           [0] (373.9499200000000200816430151462554931640625:t))
          (7.51888402010059753166615337249822914600372314453125:t))
         (1500.0:t)
+  vector, (Interpretation.Vector 5)
   caisar_op,
   (Interpretation.Classifier
-     "$TESTCASE_ROOT/TestNetwork.nnet")
-  vector,
-  (Interpretation.Vector 5)
+     NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
+             nn_filename =
+             "$TESTCASE_ROOT/TestNetwork.nnet";
+             nn_nier = <opaque> })
   P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
         (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\
          le (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) (60760.0:t) ->
@@ -212,8 +214,10 @@ Test interpret on acasxu
              (caisar_op1
               %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
              [4])
+  vector, (Interpretation.Vector 5)
   caisar_op1,
   (Interpretation.Classifier
-     "$TESTCASE_ROOT/TestNetwork.nnet")
-  vector,
-  (Interpretation.Vector 5)
+     NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
+             nn_filename =
+             "$TESTCASE_ROOT/TestNetwork.nnet";
+             nn_nier = <opaque> })
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index 2fd6c14..3ff6693 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -129,15 +129,19 @@ Test interpret on dataset
         (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0])
   caisar_op1,
+  (Interpretation.Data
+     (Interpretation.D_csv
+        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
+  caisar_op2,
   (Interpretation.Data
      (Interpretation.D_csv
         ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
   vector, (Interpretation.Vector 5)
   caisar_op,
   (Interpretation.Classifier
-     "$TESTCASE_ROOT/TestNetwork.nnet")
-  caisar_op2, (Interpretation.Dataset <csv>)
+     NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
+             nn_filename =
+             "$TESTCASE_ROOT/TestNetwork.nnet";
+             nn_nier = <opaque> })
   caisar_op3,
-  (Interpretation.Data
-     (Interpretation.D_csv
-        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
+  (Interpretation.Dataset <csv>)
-- 
GitLab