From 4ad94e14c4db88305459236caa05a81b54370dee Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 20 Apr 2023 16:02:34 +0200
Subject: [PATCH] [interpretation] Declare only logic symbols relative to
 classifiers and vectors.

---
 src/interpretation.ml | 9 ++++++---
 src/language.ml       | 2 ++
 src/language.mli      | 2 ++
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/interpretation.ml b/src/interpretation.ml
index fdbea63..6fa6794 100644
--- a/src/interpretation.ml
+++ b/src/interpretation.ml
@@ -526,11 +526,14 @@ let interpret_task ~cwd env task =
   let f = CRE.normalize ~limit:Int.max_value engine Term.Mvs.empty f in
   let _, task = Task.task_separate_goal task in
   let task =
-    (* Declare all logic symbols related to the introduced [caisar_op]. *)
+    (* Declare logic symbols introduced for classifiers and vectors. *)
     Term.Hls.fold
       (fun ls _ task ->
-        let decl = Decl.create_param_decl ls in
-        Task.add_decl task decl)
+        if Language.mem_vector ls || Language.mem_nn_classifier ls
+        then
+          let decl = Decl.create_param_decl ls in
+          Task.add_decl task decl
+        else task)
       caisar_env.caisar_op_of_ls task
   in
   let task = Task.(add_prop_decl task Pgoal g f) in
diff --git a/src/language.ml b/src/language.ml
index 3499f00..75a9dac 100644
--- a/src/language.ml
+++ b/src/language.ml
@@ -183,6 +183,7 @@ let create_vector =
       ls))
 
 let lookup_vector = Term.Hls.find_opt vectors
+let mem_vector = Term.Hls.mem vectors
 
 (* -- Classifier *)
 
@@ -263,3 +264,4 @@ let create_onnx_classifier =
       ls))
 
 let lookup_nn_classifier = Term.Hls.find_opt nn_classifiers
+let mem_nn_classifier = Term.Hls.mem nn_classifiers
diff --git a/src/language.mli b/src/language.mli
index afa4a11..52ed49a 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -69,6 +69,7 @@ type vector = Term.lsymbol
 
 val create_vector : Env.env -> int -> vector
 val lookup_vector : vector -> int option
+val mem_vector : vector -> bool
 
 (** -- Classifier *)
 
@@ -86,3 +87,4 @@ type nn_classifier = Term.lsymbol
 val create_nnet_classifier : Env.env -> string -> nn_classifier
 val create_onnx_classifier : Env.env -> string -> nn_classifier
 val lookup_nn_classifier : nn_classifier -> nn option
+val mem_nn_classifier : nn_classifier -> bool
-- 
GitLab