From ec002e5fe67bbad5c3c60f8b85891213e111382e Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Tue, 11 Apr 2023 13:14:39 +0200
Subject: [PATCH] [interpretation] Use classifier logic symbols created by
 language module.

---
 src/interpretation.ml          |  5 ++++-
 tests/interpretation_acasxu.t  | 29 +++++++++++++------------
 tests/interpretation_dataset.t | 39 ++++++++++++++++++++--------------
 3 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index b5d97f8..2fa1b82 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -75,7 +75,10 @@ let ls_of_caisar_op engine op ty_args ty =
   Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () ->
     let id = Ident.id_fresh "caisar_op" in
     let ls =
-      match op with Vector v -> v | _ -> Term.create_lsymbol id ty_args ty
+      match op with
+      | Classifier (NNet c | ONNX c) -> c
+      | Vector v -> v
+      | _ -> Term.create_lsymbol id ty_args ty
     in
     (* Fmt.pr "ls: %a@." Pretty.print_ls ls; *)
     Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls;
diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t
index 9d1adea..93fae2a 100644
--- a/tests/interpretation_acasxu.t
+++ b/tests/interpretation_acasxu.t
@@ -126,17 +126,19 @@ Test interpret on acasxu
         le
         (add RNE
          (mul RNE
-          (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+          (nnet_classifier
+           %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
           [0] (373.9499200000000200816430151462554931640625:t))
          (7.51888402010059753166615337249822914600372314453125:t))
         (1500.0:t)
-  vector, (Interpretation.Vector 5)
-  caisar_op,
+  nnet_classifier,
   (Interpretation.Classifier
      NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
              nn_filename =
              "$TESTCASE_ROOT/TestNetwork.nnet";
              nn_nier = <opaque> })
+  vector,
+  (Interpretation.Vector 5)
   P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
         (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\
          le (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) (60760.0:t) ->
@@ -187,37 +189,38 @@ Test interpret on acasxu
          le (900.0:t) (add RNE (mul RNE caisar_v3 (1100.0:t)) (650.0:t)) ->
          le (960.0:t) (add RNE (mul RNE caisar_v4 (1200.0:t)) (600.0:t))) ->
         not (((lt
-               (caisar_op1
+               (nnet_classifier
                 %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [0]
-               (caisar_op1
+               (nnet_classifier
                 %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [1] /\
                lt
-               (caisar_op1
+               (nnet_classifier
                 %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [0]
-               (caisar_op1
+               (nnet_classifier
                 %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
                [2]) /\
               lt
-              (caisar_op1
+              (nnet_classifier
                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
               [0]
-              (caisar_op1
+              (nnet_classifier
                %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
               [3]) /\
              lt
-             (caisar_op1
+             (nnet_classifier
               %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
              [0]
-             (caisar_op1
+             (nnet_classifier
               %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
              [4])
-  vector, (Interpretation.Vector 5)
-  caisar_op1,
+  nnet_classifier,
   (Interpretation.Classifier
      NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
              nn_filename =
              "$TESTCASE_ROOT/TestNetwork.nnet";
              nn_nier = <opaque> })
+  vector,
+  (Interpretation.Vector 5)
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index 3ff6693..de5f086 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -78,14 +78,18 @@ Test interpret on dataset
          (0.776470588000000017103729987866245210170745849609375:t))
         (0.375:t) ->
         lt
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0]
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1] /\
         lt
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [2]
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1]) /\
       (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t.
         ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\
@@ -119,29 +123,32 @@ Test interpret on dataset
          (0.78431372499999996161790249971090815961360931396484375:t))
         (0.375:t) ->
         lt
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [1]
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0] /\
         lt
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [2]
-        (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
+        (nnet_classifier
+         %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0])
-  caisar_op1,
-  (Interpretation.Data
-     (Interpretation.D_csv
-        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-  caisar_op2,
+  caisar_op,
   (Interpretation.Data
      (Interpretation.D_csv
         ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
   vector, (Interpretation.Vector 5)
-  caisar_op,
+  nnet_classifier,
   (Interpretation.Classifier
      NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t;
              nn_filename =
              "$TESTCASE_ROOT/TestNetwork.nnet";
              nn_nier = <opaque> })
-  caisar_op3,
-  (Interpretation.Dataset <csv>)
+  caisar_op1, (Interpretation.Dataset <csv>)
+  caisar_op2,
+  (Interpretation.Data
+     (Interpretation.D_csv
+        ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-- 
GitLab