diff --git a/src/interpretation.ml b/src/interpretation.ml index b5d97f8bf3ad4e4fbc0131b2feaf5a7847bc613c..2fa1b821220308ef8fd66b6063ef4c5685e9d703 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -75,7 +75,10 @@ let ls_of_caisar_op engine op ty_args ty = Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () -> let id = Ident.id_fresh "caisar_op" in let ls = - match op with Vector v -> v | _ -> Term.create_lsymbol id ty_args ty + match op with + | Classifier (NNet c | ONNX c) -> c + | Vector v -> v + | _ -> Term.create_lsymbol id ty_args ty in (* Fmt.pr "ls: %a@." Pretty.print_ls ls; *) Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls; diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t index 9d1adea779c65d4f6194c49716869b474b797752..93fae2a3aa0c72020d0dc7c1d6aa51a8203deb9b 100644 --- a/tests/interpretation_acasxu.t +++ b/tests/interpretation_acasxu.t @@ -126,17 +126,19 @@ Test interpret on acasxu le (add RNE (mul RNE - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] (373.9499200000000200816430151462554931640625:t)) (7.51888402010059753166615337249822914600372314453125:t)) (1500.0:t) - vector, (Interpretation.Vector 5) - caisar_op, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) + vector, + (Interpretation.Vector 5) P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\ le (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) (60760.0:t) -> @@ -187,37 +189,38 @@ Test interpret on acasxu le (900.0:t) (add RNE (mul RNE caisar_v3 (1100.0:t)) (650.0:t)) -> le (960.0:t) (add RNE (mul RNE caisar_v4 (1200.0:t)) (600.0:t))) -> not (((lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2]) /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [3]) /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [4]) - vector, (Interpretation.Vector 5) - caisar_op1, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) + vector, + (Interpretation.Vector 5) diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 3ff66934d28c7e256eaa9c32cf648f4fb3c09bb9..de5f0863335ff23501b3cb0983313575a77be4f1 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -78,14 +78,18 @@ Test interpret on dataset (0.776470588000000017103729987866245210170745849609375:t)) (0.375:t) -> lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1]) /\ (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\ @@ -119,29 +123,32 @@ Test interpret on dataset (0.78431372499999996161790249971090815961360931396484375:t)) (0.375:t) -> lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] /\ lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) - caisar_op1, - (Interpretation.Data - (Interpretation.D_csv - ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) - caisar_op2, + caisar_op, (Interpretation.Data (Interpretation.D_csv ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) vector, (Interpretation.Vector 5) - caisar_op, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) - caisar_op3, - (Interpretation.Dataset <csv>) + caisar_op1, (Interpretation.Dataset <csv>) + caisar_op2, + (Interpretation.Data + (Interpretation.D_csv + ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"]))