From 041bea32224c90520bbf3ef7572b6397799479ab Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Sun, 19 Mar 2023 10:00:36 +0100
Subject: [PATCH] [interpretation] Add ACASXu test case.

---
 src/interpretation.ml          |   7 +-
 stdlib/interpretation.mlw      |   4 +
 tests/interpretation_acasxu.t  | 256 +++++++++++++++++++++++++++++++++
 tests/interpretation_dataset.t |  25 ++--
 4 files changed, 276 insertions(+), 16 deletions(-)
 create mode 100644 tests/interpretation_acasxu.t

diff --git a/src/interpretation.ml b/src/interpretation.ml
index bcb91a3..0a7abf7 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -127,9 +127,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
         in
         Term (Term.t_tuple [ t_features; t_label ])
       | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
-    | [
-     Term ({ t_node = Tapp _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2);
-    ] ->
+    | [ Term t1; Term t2 ] ->
       (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
       Term (Term.t_app_infer ls [ t1; t2 ])
     | _ -> invalid_arg (error_message ls)
@@ -144,6 +142,9 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
       match caisar_op_of_ls engine ls with
       | Dataset (DS_csv csv) -> Int (BigInt.of_int (Csv.lines csv))
       | Data _ | Classifier _ | Tensor _ | Index _ -> assert false)
+    | [ Term t ] ->
+      (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
+      Term (Term.t_app_infer ls [ t ])
     | _ -> invalid_arg (error_message ls)
   in
 
diff --git a/stdlib/interpretation.mlw b/stdlib/interpretation.mlw
index 1fed08d..48b4998 100644
--- a/stdlib/interpretation.mlw
+++ b/stdlib/interpretation.mlw
@@ -28,7 +28,10 @@ theory Vector
   function ([]) (v: vector 'a) (i: int) : 'a
   function length (v: vector 'a) : int
 
+  predicate has_length (v: vector 'a) (i: int)
+
   function map (v: vector 'a) (f: 'a -> 'b) : vector 'b
+  function mapi (v: vector 'a) (f: int -> 'a -> 'b) : vector 'b
   function map2 (v1: vector 'a) (v2: vector 'b) (f: 'a -> 'b -> 'c) : vector 'c
 
   scope L
@@ -69,6 +72,7 @@ theory Classifier
 
   function read_classifier (f: string) (k: kind) : classifier
   function (@@) (c: classifier) (t: tensor 'a) : vector 'a
+  function (%%) (c: classifier) (v: vector 'a) : vector 'a
 end
 
 theory Dataset
diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t
new file mode 100644
index 0000000..da2effa
--- /dev/null
+++ b/tests/interpretation_acasxu.t
@@ -0,0 +1,256 @@
+Test interpret on acasxu
+  $ caisar interpret -L . --format whyml - 2>&1 <<EOF | ./filter_tmpdir.sh
+  > theory T
+  >   use ieee_float.Float64
+  >   use bool.Bool
+  >   use int.Int
+  >   use interpretation.Vector
+  >   use interpretation.Classifier
+  > 
+  >   constant classifier: classifier = read_classifier "TestNetwork.nnet" NNet
+  > 
+  >   type input = vector t
+  > 
+  >   constant distance_to_intruder: int = 0
+  >   constant angle_to_intruder: int = 1
+  >   constant intruder_heading: int = 2
+  >   constant speed: int = 3
+  >   constant intruder_speed: int = 4
+  > 
+  >   type action = int
+  > 
+  >   constant clear_of_conflict: action = 0
+  >   constant weak_left: action = 1
+  >   constant weak_right: action = 2
+  >   constant strong_left: action = 3
+  >   constant strong_right: action = 4
+  > 
+  >   constant pi: t = 3.141592999999999857863031138549558818340301513671875
+  > 
+  >   predicate valid_input (i: input) =
+  >     (0.0:t) .<= i[distance_to_intruder] .<= (60760.0:t) ->
+  >     .- pi .<= i[angle_to_intruder] .<= pi ->
+  >     .- pi .<= i[intruder_heading] .<= pi ->
+  >     (100.0:t) .<= i[speed] .<= (1200.0:t) ->
+  >     (0.0:t) .<= i[intruder_speed] .<= (1200.0:t)
+  > 
+  >   predicate valid_action (a: action) = 0 <= a <= 4
+  > 
+  >   predicate not_same_action (a1: action) (a2: action) = a1 <> a2
+  > 
+  >   predicate advises (c: classifier) (i: input) (a: action) =
+  >     valid_action a ->
+  >     forall j: action.
+  >       valid_action j ->  not_same_action a j -> (c%%i)[a] .< (c%%i)[j]
+  > 
+  >   predicate intruder_distant_and_slow (i: input) =
+  >     i[distance_to_intruder] .>= (55947.6909999999988940544426441192626953125:t) ->
+  >     i[speed] .>= (1145.0:t) ->
+  >     i[intruder_speed] .<= (60.0:t)
+  > 
+  >   function denormalize_t (i: t) (mean: t) (range: t) : t = (i .* range) .+ mean
+  > 
+  >   function denormalize_input (i:input) : input =
+  >     mapi i (fun idx t -> if idx = distance_to_intruder then denormalize_t t (19791.0:t) (60261.0:t) else t)
+  > 
+  >   function denormalize_output (o: t) : t =
+  >     denormalize_t o (7.51888402010059753166615337249822914600372314453125:t) (373.9499200000000200816430151462554931640625:t)
+  > 
+  >   goal P1:
+  >     forall i: input.
+  >       has_length i 5 ->
+  >       let j = denormalize_input i in
+  >       valid_input j ->
+  >       intruder_distant_and_slow j ->
+  >       let o = (classifier%%i)[clear_of_conflict] in
+  >       (denormalize_output o) .<= (1500.0:t)
+  > 
+  >   predicate directly_ahead (i: input) =
+  >     (1500.0:t) .<= i[distance_to_intruder] .<= (1800.0:t) ->
+  >     .- (0.059999999999999997779553950749686919152736663818359375:t) .<= i[angle_to_intruder] .<= (0.059999999999999997779553950749686919152736663818359375:t)
+  > 
+  >   predicate moving_towards (i: input) =
+  >     i[intruder_heading] .>= (3.100000000000000088817841970012523233890533447265625:t) ->
+  >     i[speed] .>= (900.0:t) ->
+  >     i[intruder_speed] .>= (960.0:t)
+  > 
+  >   goal P2:
+  >     forall i: input.
+  >       has_length i 5 ->
+  >       let j = denormalize_input i in
+  >       valid_input j ->
+  >       directly_ahead j ->
+  >       moving_towards j ->
+  >       not (advises classifier i clear_of_conflict)
+  > end
+  > EOF
+  P1 : forall i:vector t.
+        has_length i 5 ->
+        (le (0.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] (60760.0:t) ->
+         le (neg (3.141592999999999857863031138549558818340301513671875:t))
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] (3.141592999999999857863031138549558818340301513671875:t) ->
+         le (neg (3.141592999999999857863031138549558818340301513671875:t))
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [2] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [2] (3.141592999999999857863031138549558818340301513671875:t) ->
+         le (100.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] (1200.0:t) ->
+         le (0.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4] (1200.0:t)) ->
+        (le (55947.6909999999988940544426441192626953125:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] ->
+         le (1145.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] ->
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4] (60.0:t)) ->
+        le
+        (add RNE
+         (mul RNE (caisar_op %% i)[0]
+          (373.9499200000000200816430151462554931640625:t))
+         (7.51888402010059753166615337249822914600372314453125:t))
+        (1500.0:t)
+  caisar_op,
+  (Interpretation.Classifier
+     "$TESTCASE_ROOT/TestNetwork.nnet")
+  P2 : forall i:vector t.
+        has_length i 5 ->
+        (le (0.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] (60760.0:t) ->
+         le (neg (3.141592999999999857863031138549558818340301513671875:t))
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] (3.141592999999999857863031138549558818340301513671875:t) ->
+         le (neg (3.141592999999999857863031138549558818340301513671875:t))
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [2] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [2] (3.141592999999999857863031138549558818340301513671875:t) ->
+         le (100.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] (1200.0:t) ->
+         le (0.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4] (1200.0:t)) ->
+        (le (1500.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [0] (1800.0:t) ->
+         le (neg (0.059999999999999997779553950749686919152736663818359375:t))
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] /\
+         le
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [1] (0.059999999999999997779553950749686919152736663818359375:t)) ->
+        (le (3.100000000000000088817841970012523233890533447265625:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [2] ->
+         le (900.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [3] ->
+         le (960.0:t)
+         (mapi i
+          (fun (idx:int) (t:t) ->
+            if idx = 0 then denormalize_t t (19791.0:t) (60261.0:t) else t))
+         [4]) ->
+        not (((lt (caisar_op1 %% i)[0] (caisar_op1 %% i)[1] /\
+               lt (caisar_op1 %% i)[0] (caisar_op1 %% i)[2]) /\
+              lt (caisar_op1 %% i)[0] (caisar_op1 %% i)[3]) /\
+             lt (caisar_op1 %% i)[0] (caisar_op1 %% i)[4])
+  caisar_op1,
+  (Interpretation.Classifier
+     "$TESTCASE_ROOT/TestNetwork.nnet")
diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t
index 0b1a8f2..94430b2 100644
--- a/tests/interpretation_dataset.t
+++ b/tests/interpretation_dataset.t
@@ -137,22 +137,21 @@ Test interpret on dataset
         (caisar_op
          @@ caisar_op1 caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4)
         [0])
-  caisar_op2,
+  caisar_op,
+  (Interpretation.Classifier
+     "$TESTCASE_ROOT/TestNetwork.nnet")
+  caisar_op1, (Interpretation.Tensor 5)
+  caisar_op2, (Interpretation.Dataset <csv>)
+  caisar_op3, (Interpretation.Index (Interpretation.I_csv 4))
+  caisar_op4, (Interpretation.Index (Interpretation.I_csv 3))
+  caisar_op5,
   (Interpretation.Data
      (Interpretation.D_csv
         ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"]))
-  caisar_op3, (Interpretation.Index (Interpretation.I_csv 2))
-  caisar_op4, (Interpretation.Index (Interpretation.I_csv 1))
-  caisar_op5, (Interpretation.Index (Interpretation.I_csv 0))
-  caisar_op6,
+  caisar_op6, (Interpretation.Index (Interpretation.I_csv 2))
+  caisar_op7, (Interpretation.Index (Interpretation.I_csv 1))
+  caisar_op8, (Interpretation.Index (Interpretation.I_csv 0))
+  caisar_op9,
   (Interpretation.Data
      (Interpretation.D_csv
         ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))
-  caisar_op,
-  (Interpretation.Classifier
-     "$TESTCASE_ROOT/TestNetwork.nnet")
-  caisar_op1, (Interpretation.Tensor 5)
-  caisar_op7, (Interpretation.Dataset <csv>)
-  caisar_op8, (Interpretation.Index (Interpretation.I_csv 4))
-  caisar_op9,
-  (Interpretation.Index (Interpretation.I_csv 3))
-- 
GitLab