(* language.ml -- authored by Michele Alberti *)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(* Copyright (C) 2022 *)
(* CEA (Commissariat à l'énergie atomique et aux énergies *)
(* alternatives) *)
(* *)
(* You can redistribute it and/or modify it under the terms of the GNU *)
(* Lesser General Public License as published by the Free Software *)
(* Foundation, version 2.1. *)
(* *)
(* It is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *)
(* GNU Lesser General Public License for more details. *)
(* *)
(* See the GNU Lesser General Public License version 2.1 *)
(* for more details (enclosed in the file licenses/LGPLv2.1). *)
(* *)
(**************************************************************************)
open Why3
open Base
(* Support for several model formats: *)
(* - NNet and ONNX for neural networks *)
(* - OVO for SVM *)
(* Information stored in [loaded_nets] for every logic symbol that stands for
   a neural network registered by [register_nn_as_tuple] or
   [register_nn_as_array]. *)
type nn_shape = {
(* Number of inputs, as given to the registration function (the parsers pass
   the [n_inputs] of the parsed model). *)
nb_inputs : int;
(* Number of outputs, as given to the registration function (the parsers pass
   the [n_outputs] of the parsed model). *)
nb_outputs : int;
(* Why3 type built by the registration function: the [input_type] type of
   module [caisar.NN] for the tuple view, the [model] type of module
   [caisar.DatasetClassification] for the array view. *)
ty_data : Ty.ty;
(* File name handed to the format parser. *)
filename : string;
(* Network intermediate representation, when one was supplied. It is [None]
   when the optional [?nier] argument is omitted (as the NNet parser does) or
   when the ONNX parser could not build it. *)
nier : Onnx.G.t option;
}
(* Information stored in [loaded_svms] for every logic symbol registered by
   [register_svm_as_array]. *)
type svm_shape = {
(* Number of inputs, as given to [register_svm_as_array] (the OVO parser
   passes the [n_inputs] of the parsed model). *)
nb_inputs : int;
(* Number of outputs, as given to [register_svm_as_array] (the OVO parser
   passes the [n_outputs] of the parsed model). *)
nb_outputs : int;
(* File name handed to the format parser. *)
filename : string;
}
(* Tables keyed by logic symbol. [loaded_nets] is filled by
   [register_nn_as_tuple] and [register_nn_as_array]; [loaded_svms] is filled
   by [register_svm_as_array]. *)
let loaded_nets = Term.Hls.create 10
let loaded_svms = Term.Hls.create 10

(* [lookup_loaded_nets ls] is the [nn_shape] recorded for [ls], if any. *)
let lookup_loaded_nets ls = Term.Hls.find_opt loaded_nets ls

(* [lookup_loaded_svms ls] is the [svm_shape] recorded for [ls], if any. *)
let lookup_loaded_svms ls = Term.Hls.find_opt loaded_svms ls
(* [register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr] builds
   a module named "AsTuple" that re-exports module [caisar.NN] and declares a
   fresh function symbol [nn_apply]. The symbol takes [nb_inputs] arguments of
   the [input_type] type found in [caisar.NN] and returns a tuple of
   [nb_outputs] components of that same type.

   As a side effect, the symbol is recorded in [loaded_nets] together with
   the shape information. The result is [mstr] extended with a binding from
   "AsTuple" to the closed module. *)
let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr =
  let module_name = "AsTuple" in
  let muc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let nn_module = Pmodule.read_module env [ "caisar" ] "NN" in
  let muc = Pmodule.use_export muc nn_module in
  let ty_data =
    let ts =
      Theory.ns_find_ts nn_module.Pmodule.mod_theory.Theory.th_export
        [ "input_type" ]
    in
    Ty.ty_app ts []
  in
  (* [replicate n] is a list of [n] copies of the data type. *)
  let replicate n = List.init n ~f:(fun _ -> ty_data) in
  let ls_nn_apply =
    Term.create_fsymbol
      (Ident.id_fresh "nn_apply")
      (replicate nb_inputs)
      (Ty.ty_tuple (replicate nb_outputs))
  in
  let shape : nn_shape = { filename; nb_inputs; nb_outputs; ty_data; nier } in
  Term.Hls.add loaded_nets ls_nn_apply shape;
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_nn_apply) in
  let muc = Pmodule.add_pdecl ~vc:false muc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module muc) mstr
(* [register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr] builds
   a module named "AsArray" that re-exports module
   [caisar.DatasetClassification] and declares a fresh constant symbol [model]
   whose type is the [model] type found in that module.

   As a side effect, the symbol is recorded in [loaded_nets] together with
   the shape information. The result is [mstr] extended with a binding from
   "AsArray" to the closed module. *)
let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr =
  let module_name = "AsArray" in
  let muc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let dataset_module =
    Pmodule.read_module env [ "caisar" ] "DatasetClassification"
  in
  let muc = Pmodule.use_export muc dataset_module in
  let ty_data =
    let ts =
      Theory.ns_find_ts dataset_module.Pmodule.mod_theory.Theory.th_export
        [ "model" ]
    in
    Ty.ty_app ts []
  in
  let ls_model = Term.create_fsymbol (Ident.id_fresh "model") [] ty_data in
  let shape : nn_shape = { filename; nb_inputs; nb_outputs; ty_data; nier } in
  Term.Hls.add loaded_nets ls_model shape;
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_model) in
  let muc = Pmodule.add_pdecl ~vc:false muc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module muc) mstr
(* [register_svm_as_array env nb_inputs nb_outputs filename mstr] builds a
   module named "AsArray" that re-exports module
   [caisar.DatasetClassification] and declares a fresh constant symbol [model]
   whose type is the [model] type found in that module.

   As a side effect, the symbol is recorded in [loaded_svms] together with
   the shape information. The result is [mstr] extended with a binding from
   "AsArray" to the closed module. *)
let register_svm_as_array env nb_inputs nb_outputs filename mstr =
  let module_name = "AsArray" in
  let muc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let dataset_module =
    Pmodule.read_module env [ "caisar" ] "DatasetClassification"
  in
  let muc = Pmodule.use_export muc dataset_module in
  let ty_model =
    let ts =
      Theory.ns_find_ts dataset_module.Pmodule.mod_theory.Theory.th_export
        [ "model" ]
    in
    Ty.ty_app ts []
  in
  let ls_model = Term.create_fsymbol (Ident.id_fresh "model") [] ty_model in
  let shape : svm_shape = { filename; nb_inputs; nb_outputs } in
  Term.Hls.add loaded_svms ls_model shape;
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_model) in
  let muc = Pmodule.add_pdecl ~vc:false muc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module muc) mstr
(* [nnet_parser env filename] parses [filename] as an NNet model (in
   permissive mode) and returns a module map holding both the "AsTuple" and
   the "AsArray" views of the network. Results are cached per file name
   inside a table that is itself memoized per environment. A parse error is
   reported through [Loc.errorm]. *)
let nnet_parser =
  let build env =
    let cache = Hashtbl.create (module String) in
    let load filename =
      match Nnet.parse ~permissive:true filename with
      | Error msg -> Loc.errorm "%s" msg
      | Ok { n_inputs; n_outputs; _ } ->
        (* The tuple view is registered first, then the array view. *)
        let with_tuple =
          register_nn_as_tuple env n_inputs n_outputs filename
            Wstdlib.Mstr.empty
        in
        register_nn_as_array env n_inputs n_outputs filename with_tuple
    in
    Hashtbl.findi_or_add cache ~default:load
  in
  Env.Wenv.memoize 13 build
(* [onnx_parser env filename] parses [filename] as an ONNX model and returns
   a module map holding both the "AsTuple" and the "AsArray" views of the
   network. Results are cached per file name inside a table that is itself
   memoized per environment. A parse error is reported through [Loc.errorm];
   a failure to build the intermediate representation only logs a warning
   and the network is registered without it. *)
let onnx_parser =
  let build env =
    let cache = Hashtbl.create (module String) in
    let load filename =
      match Onnx.parse filename with
      | Error msg -> Loc.errorm "%s" msg
      | Ok { n_inputs; n_outputs; nier = nier_result } ->
        let nier =
          match nier_result with
          | Ok nier -> Some nier
          | Error msg ->
            Logs.warn (fun m ->
              m "Cannot build network intermediate representation:@ %s" msg);
            None
        in
        (* The tuple view is registered first, then the array view. *)
        let with_tuple =
          register_nn_as_tuple env n_inputs n_outputs filename ?nier
            Wstdlib.Mstr.empty
        in
        register_nn_as_array env n_inputs n_outputs filename ?nier with_tuple
    in
    Hashtbl.findi_or_add cache ~default:load
  in
  Env.Wenv.memoize 13 build
(* [ovo_parser env filename] parses [filename] as an OVO model and returns a
   module map holding the "AsArray" view of the SVM. Results are cached per
   file name inside a table that is itself memoized per environment. A parse
   error is reported through [Loc.errorm]. *)
let ovo_parser =
  let build env =
    let cache = Hashtbl.create (module String) in
    let load filename =
      match Ovo.parse filename with
      | Error msg -> Loc.errorm "%s" msg
      | Ok { n_inputs; n_outputs } ->
        register_svm_as_array env n_inputs n_outputs filename
          Wstdlib.Mstr.empty
    in
    Hashtbl.findi_or_add cache ~default:load
  in
  Env.Wenv.memoize 13 build
(* Registers the "NNet" format (file extension "nnet") for the mlw language,
   with [nnet_parser] as its parser. Only the environment and the file name
   handed to the callback are used. *)
let register_nnet_support () =
  let parse env _ filename _ = nnet_parser env filename in
  Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
    "NNet" [ "nnet" ] parse
(* Registers the "ONNX" format (file extension "onnx") for the mlw language,
   with [onnx_parser] as its parser. Only the environment and the file name
   handed to the callback are used. *)
let register_onnx_support () =
  let parse env _ filename _ = onnx_parser env filename in
  Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX"
    [ "onnx" ] parse
(* Registers the "OVO" format (file extension "ovo") for the mlw language,
   with [ovo_parser] as its parser. Only the environment and the file name
   handed to the callback are used. *)
let register_ovo_support () =
  let parse env _ filename _ = ovo_parser env filename in
  Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
    parse
(* -- Vector *)
(* Table mapping each vector symbol created by [create_vector] to the length
   it was created for. *)
let vectors = Term.Hls.create 10
(* [vector_elt_ty env] is the Why3 type [t] exported by theory
   [ieee_float.Float64], applied to no type argument. *)
let vector_elt_ty env =
  let float64 = Env.read_theory env [ "ieee_float" ] "Float64" in
  let ts = Theory.ns_find_ts float64.Theory.th_export [ "t" ] in
  Ty.ty_app ts []
(* [create_vector env length] is a function symbol named "vector" that takes
   [length] arguments of type [vector_elt_ty env] and returns the [vector]
   type of theory [interpretation.Vector] instantiated with that element
   type. Symbols are cached per length inside a table that is itself memoized
   per environment; each newly created symbol is recorded in [vectors] along
   with its length. *)
let create_vector =
  let build env =
    let cache = Hashtbl.create (module Int) in
    let ty_elt = vector_elt_ty env in
    let ty_vector =
      let th = Env.read_theory env [ "interpretation" ] "Vector" in
      let ts = Theory.ns_find_ts th.Theory.th_export [ "vector" ] in
      Ty.ty_app ts [ ty_elt ]
    in
    let make length =
      let id = Ident.id_fresh "vector" in
      let ty_args = List.init length ~f:(fun _ -> ty_elt) in
      let ls = Term.create_fsymbol id ty_args ty_vector in
      Term.Hls.add vectors ls length;
      ls
    in
    Hashtbl.findi_or_add cache ~default:make
  in
  Env.Wenv.memoize 13 build
(* [lookup_vector ls] is the length recorded in [vectors] for [ls], if any. *)
let lookup_vector ls = Term.Hls.find_opt vectors ls

(* [mem_vector ls] tells whether [ls] is recorded in [vectors]. *)
let mem_vector ls = Term.Hls.mem vectors ls
(* -- Classifier *)
(* Description stored in [nets] for every neural network symbol created by
   [create_nn_nnet] or [create_nn_onnx]. A printer is derived through
   [@@deriving show]. *)
type nn = {
(* Number of inputs ([n_inputs] of the parsed model). *)
nn_nb_inputs : int;
(* Number of outputs ([n_outputs] of the parsed model). *)
nn_nb_outputs : int;
(* Element type chosen by the creating function; both creators use the type
   [t] of theory [ieee_float.Float64]. Printed with [Pretty.print_ty]. *)
nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
(* File name the network was parsed from. *)
nn_filename : string;
(* Network intermediate representation, when available: always [None] for
   NNet files, and [None] for ONNX files when it could not be built. Marked
   [@opaque] for the derived printer. *)
nn_nier : Onnx.G.t option; [@opaque]
}
[@@deriving show]
(* Table mapping each symbol created by [create_nn_nnet] or [create_nn_onnx]
   to its [nn] description. *)
let nets = Term.Hls.create 10
(* [fresh_nn_ls env name] is a fresh constant symbol named [name] whose type
   is the [nn] type exported by theory [interpretation.NeuralNetwork]. *)
let fresh_nn_ls env name =
  let nn_theory = Env.read_theory env [ "interpretation" ] "NeuralNetwork" in
  let ts_nn = Theory.ns_find_ts nn_theory.Theory.th_export [ "nn" ] in
  let ty_nn = Ty.ty_app ts_nn [] in
  Term.create_fsymbol (Ident.id_fresh name) [] ty_nn
(* [create_nn_nnet env filename] is a fresh symbol named "nnet_nn" that
   stands for the NNet model stored in [filename]. The file is parsed in
   permissive mode; a parse error is reported through [Loc.errorm]. Symbols
   are cached per file name inside a table that is itself memoized per
   environment; each newly created symbol is recorded in [nets] with its
   description (no intermediate representation is attached to NNet
   models). *)
let create_nn_nnet =
  Env.Wenv.memoize 13 (fun env ->
    let h = Hashtbl.create (module String) in
    (* Reuse the shared helper, as [create_nn_onnx] does, rather than
       repeating the Float64 lookup here. *)
    let ty_elt = vector_elt_ty env in
    Hashtbl.findi_or_add h ~default:(fun filename ->
      let ls = fresh_nn_ls env "nnet_nn" in
      let nn =
        match Nnet.parse ~permissive:true filename with
        | Error s -> Loc.errorm "%s" s
        | Ok { n_inputs; n_outputs; _ } ->
          {
            nn_nb_inputs = n_inputs;
            nn_nb_outputs = n_outputs;
            nn_ty_elt = ty_elt;
            nn_filename = filename;
            nn_nier = None;
          }
      in
      Term.Hls.add nets ls nn;
      ls))
(* [create_nn_onnx env filename] is a fresh symbol named "onnx_nn" that
   stands for the ONNX model stored in [filename]. A parse error is reported
   through [Loc.errorm]; a failure to build the intermediate representation
   only logs a warning and leaves [nn_nier] empty. Symbols are cached per
   file name inside a table that is itself memoized per environment; each
   newly created symbol is recorded in [nets] with its description. *)
let create_nn_onnx =
  let build env =
    let cache = Hashtbl.create (module String) in
    let ty_elt = vector_elt_ty env in
    let load filename =
      (* The symbol is created before the file is parsed. *)
      let ls = fresh_nn_ls env "onnx_nn" in
      let descr =
        match Onnx.parse filename with
        | Error msg -> Loc.errorm "%s" msg
        | Ok { n_inputs; n_outputs; nier = nier_result } ->
          let nn_nier =
            match nier_result with
            | Ok nier -> Some nier
            | Error msg ->
              Logs.warn (fun m ->
                m "Cannot build network intermediate representation:@ %s" msg);
              None
          in
          {
            nn_nb_inputs = n_inputs;
            nn_nb_outputs = n_outputs;
            nn_ty_elt = ty_elt;
            nn_filename = filename;
            nn_nier;
          }
      in
      Term.Hls.add nets ls descr;
      ls
    in
    Hashtbl.findi_or_add cache ~default:load
  in
  Env.Wenv.memoize 13 build
(* [lookup_nn ls] is the [nn] description recorded in [nets] for [ls], if
   any. *)
let lookup_nn ls = Term.Hls.find_opt nets ls

(* [mem_nn ls] tells whether [ls] is recorded in [nets]. *)
let mem_nn ls = Term.Hls.mem nets ls