From eb94f8c81c273a6e7c81c45f1afe9dcadda4cbe9 Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Fri, 17 Mar 2023 11:36:00 +0100 Subject: [PATCH] [interpretation] wip2. --- src/interpretation.ml | 117 ++++++++++++++++++++++----------- tests/interpretation_dataset.t | 4 +- 2 files changed, 82 insertions(+), 39 deletions(-) diff --git a/src/interpretation.ml b/src/interpretation.ml index b5e51e6..fbd18c2 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -41,6 +41,14 @@ type caisar_op = [@printer fun fmt (t1, t2) -> Fmt.pf fmt "%a[%a]" Pretty.print_term t1 Pretty.print_term t2] + | EqualShape of Term.term * Term.term + [@printer + fun fmt (t1, t2) -> + Fmt.pf fmt "EqShape %a %a" Pretty.print_term t1 Pretty.print_term t2] + | ValidIndex of Term.term * Term.term + [@printer + fun fmt (t1, t2) -> + Fmt.pf fmt "ValidIdx %a %a" Pretty.print_term t1 Pretty.print_term t2] [@@deriving show] type caisar_env = { @@ -90,21 +98,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = Fmt.(option ~none:nop Pretty.print_ty) ty; match vl with - | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] -> - Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; - Term (term_of_caisar_op engine (VGet (t1, t2)) ty) | [ - Term ({ t_node = Tapp (ls, _); _ } as t1); + Term ({ t_node = Tapp (lsapp, _); _ } as t1); Term ({ t_node = Tconst (ConstInt i); _ } as t2); - ] -> + ] -> ( Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; - let t_features, t_label = - let row = - match caisar_op_of_ls engine ls with - | Dataset (CSV csv) -> List.nth_exn csv (Number.to_small_integer i) - | Data _ | Classifier _ | ClassifierApp (_, _) | VGet (_, _) -> - assert false - in + match caisar_op_of_ls engine lsapp with + | Dataset (CSV csv) -> + let row = List.nth_exn csv (Number.to_small_integer i) in let label, features = match row with | [] | [ _ ] -> assert false @@ -115,10 +116,24 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = | Some { ty_node = Tyapp (_, [ a; _ ]); _ } -> Some a | _ -> assert false in - ( term_of_caisar_op engine (Data features) ty_features, - Term.t_int_const (BigInt.of_int (Int.of_string label)) ) - in - Term (Term.t_tuple [ t_features; t_label ]) + let t_features, t_label = + ( term_of_caisar_op engine (Data features) ty_features, + Term.t_int_const (BigInt.of_int (Int.of_string label)) ) + in + Term (Term.t_tuple [ t_features; t_label ]) + | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ]) + | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ -> + assert false) + | [ + Term ({ t_node = Tapp (lsapp, _); _ } as t1); + Term ({ t_node = Tvar _; _ } as t2); + ] -> ( + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + match caisar_op_of_ls engine lsapp with + | Dataset _ -> assert false + | ClassifierApp (_, _) -> Term (Term.t_app_infer ls [ t1; t2 ]) + | Data _ | Classifier _ | VGet (_, _) | EqualShape _ | ValidIndex _ -> + assert false) | _ -> invalid_arg (error_message ls) in let length : _ CRE.builtin = @@ -130,25 +145,48 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = | [ Term { t_node = Tapp (ls, []); _ } ] -> ( match caisar_op_of_ls engine ls with | Dataset (CSV csv) -> Int (BigInt.of_int (Csv.lines csv)) - | Data _ | Classifier _ | ClassifierApp _ | VGet _ -> assert false) + | Data _ | Classifier _ | ClassifierApp _ | VGet _ | EqualShape _ + | ValidIndex _ -> + assert false) + | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (term_of_caisar_op engine (VGet (t1, t2)) ty) | _ -> invalid_arg (error_message ls) in (* Tensor *) - (* let valid_index : _ CRE.builtin = *) - (* fun _engine ls _vl ty -> *) - (* Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls *) - (* Fmt.(option ~none:nop Pretty.print_ty) *) - (* ty; *) - (* Term Term.t_true *) - (* in *) - (* let equal_shape : _ CRE.builtin = *) - (* fun _engine ls _vl ty -> *) - (* Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls *) - (* Fmt.(option ~none:nop Pretty.print_ty) *) - (* ty; *) - (* Term Term.t_true *) - (* in *) + let _valid_index : _ CRE.builtin = + fun engine ls vl ty -> + Fmt.pr "--@.valid_index: ls:%a , ty:%a@." Pretty.print_ls ls + Fmt.(option ~none:nop Pretty.print_ty) + ty; + match vl with + | [ + Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2); + ] + | [ Term t1; Term ({ t_node = Tvar _; _ } as t2) ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (term_of_caisar_op engine (ValidIndex (t1, t2)) ty) + (* Term Term.t_true *) + | _ -> invalid_arg (error_message ls) + in + let _equal_shape : _ CRE.builtin = + fun engine ls vl ty -> + Fmt.pr "--@.equal_shape: ls:%a , ty:%a@." Pretty.print_ls ls + Fmt.(option ~none:nop Pretty.print_ty) + ty; + match vl with + | [ + Term ({ t_node = Tvar _; _ } as t1); Term ({ t_node = Tvar _; _ } as t2); + ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (Term.t_app_infer ls [ t1; t2 ]) + | [ Term t1; Term t2 ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (term_of_caisar_op engine (EqualShape (t1, t2)) ty) + (* Term Term.t_true *) + | _ -> invalid_arg (error_message ls) + in (* Classifier *) let read_classifier : _ CRE.builtin = @@ -176,6 +214,12 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ty; match vl with | [ Term ({ t_node = Tvar _; _ } as t1); Term t2 ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (Term.t_app_infer ls [ t1; t2 ]) + | [ Term ({ t_node = Tapp (_lsapp, _); _ } as t1); Term t2 ] -> + Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; + Term (Term.t_app_infer ls [ t1; t2 ]) + | [ Term t1; Term t2 ] -> Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; Term (term_of_caisar_op engine (ClassifierApp (t1, t2)) ty) | _ -> invalid_arg (error_message ls) @@ -206,12 +250,11 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = "Vector", [], [ (Ident.op_get "" (* ([]) *), None, vget); ("length", None, length) ] ); - (* ( [ "interpretation" ], *) - (* "Tensor", *) - (* [], *) - (* [ ("valid_index", None, valid_index); ("equal_shape", None, equal_shape) - ] *) - (* ); *) + ( [ "interpretation" ], + "Tensor", + [], + [ (* ("valid_index", None, valid_index); *) + (* ("equal_shape", None, equal_shape); *) ] ); ( [ "interpretation" ], "Classifier", [], diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index a933f78..ca4ded6 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -20,11 +20,11 @@ Test interpret on dataset > predicate valid_image (i: image) = > forall v: index. valid_index i v -> (0.0: t) .<= i#v .<= (1.0: t) > - > predicate valid_label (l: label_) = 0 <= l <= 1 + > predicate valid_label (l: label_) = 0 <= l <= 2 > > predicate advises (c: classifier) (i: image) (l: label_) = > valid_label l -> - > forall j: int. valid_label j /\ j <> l -> (c@@i)[l] .> (c@@i)[j] + > forall j: int. valid_label j -> j <> l -> (c@@i)[l] .> (c@@i)[j] > > predicate bounded_by_epsilon (i: image) (eps: t) = > forall v: index. valid_index i v -> .- eps .<= i#v .<= eps -- GitLab