Skip to content
Snippets Groups Projects
language.ml 9.33 KiB
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2022                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Why3
open Base

(* Support for several model formats: *)
(* - NNet and ONNX for neural networks *)
(* - OVO for SVM *)

(* Metadata recorded for every neural network registered by the NNet and ONNX
   readers below; stored in [loaded_nets], keyed by the logic symbol that
   stands for the network. *)
type nn_shape = {
  (* Input/output counts as reported by the model parser. *)
  nb_inputs : int;
  nb_outputs : int;
  (* Why3 type associated with the registered symbol: [input_type] from
     caisar.NN for the "AsTuple" view, [model] from
     caisar.DatasetClassification for the "AsArray" view. *)
  ty_data : Ty.ty;
  (* Path of the parsed model file. *)
  filename : string;
  (* Network intermediate representation. In this file it is only [Some _]
     for ONNX models whose IR could be built; NNet models, and ONNX models
     whose IR construction failed, get [None]. *)
  nier : Onnx.G.t option;
}

(* Metadata recorded for every SVM registered by the OVO reader; stored in
   [loaded_svms], keyed by the logic symbol that stands for the SVM. Field
   names overlap with [nn_shape]. *)
type svm_shape = {
  nb_inputs : int;
  nb_outputs : int;
  filename : string;
}

(* Registries of the logic symbols introduced by the model readers below,
   mapping each symbol to the metadata of the model it stands for. *)
let loaded_nets = Term.Hls.create 10
let loaded_svms = Term.Hls.create 10

(* [None] when [ls] was not introduced by a network (resp. SVM) reader. *)
let lookup_loaded_nets ls = Term.Hls.find_opt loaded_nets ls
let lookup_loaded_svms ls = Term.Hls.find_opt loaded_svms ls

(* Build a module "AsTuple" that uses/exports caisar.NN and declares a single
   logic function [nn_apply]. It takes [nb_inputs] arguments of type
   [input_type] and returns an [nb_outputs]-tuple of that same type. The
   function symbol is recorded in [loaded_nets] together with the network
   metadata, and the closed module is added to [mstr] under "AsTuple". *)
let register_nn_as_tuple env nb_inputs nb_outputs filename ?nier mstr =
  let module_name = "AsTuple" in
  let uc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let nn_module = Pmodule.read_module env [ "caisar" ] "NN" in
  let uc = Pmodule.use_export uc nn_module in
  let input_ts =
    Theory.ns_find_ts nn_module.Pmodule.mod_theory.Theory.th_export
      [ "input_type" ]
  in
  let ty_data = Ty.ty_app input_ts [] in
  (* [n] copies of the element type, for the argument list and the tuple. *)
  let copies n = List.init n ~f:(Fn.const ty_data) in
  let ls_nn_apply =
    Term.create_fsymbol
      (Ident.id_fresh "nn_apply")
      (copies nb_inputs)
      (Ty.ty_tuple (copies nb_outputs))
  in
  Term.Hls.add loaded_nets ls_nn_apply
    { filename; nb_inputs; nb_outputs; ty_data; nier };
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_nn_apply) in
  let uc = Pmodule.add_pdecl ~vc:false uc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module uc) mstr

(* Build a module "AsArray" that uses/exports caisar.DatasetClassification and
   declares a nullary logic constant [model] of that theory's [model] type.
   The constant is recorded in [loaded_nets] together with the network
   metadata, and the closed module is added to [mstr] under "AsArray". *)
let register_nn_as_array env nb_inputs nb_outputs filename ?nier mstr =
  let module_name = "AsArray" in
  let uc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let dataset = Pmodule.read_module env [ "caisar" ] "DatasetClassification" in
  let uc = Pmodule.use_export uc dataset in
  let model_ts =
    Theory.ns_find_ts dataset.Pmodule.mod_theory.Theory.th_export [ "model" ]
  in
  let ty_data = Ty.ty_app model_ts [] in
  let ls_model = Term.create_fsymbol (Ident.id_fresh "model") [] ty_data in
  Term.Hls.add loaded_nets ls_model
    { filename; nb_inputs; nb_outputs; ty_data; nier };
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_model) in
  let uc = Pmodule.add_pdecl ~vc:false uc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module uc) mstr

(* Build a module "AsArray" that uses/exports caisar.DatasetClassification and
   declares a nullary logic constant [model] of that theory's [model] type,
   standing for an SVM. The constant is recorded in [loaded_svms] together
   with the SVM metadata, and the closed module is added to [mstr] under
   "AsArray". *)
let register_svm_as_array env nb_inputs nb_outputs filename mstr =
  let module_name = "AsArray" in
  let uc = Pmodule.create_module env (Ident.id_fresh module_name) in
  let dataset = Pmodule.read_module env [ "caisar" ] "DatasetClassification" in
  let uc = Pmodule.use_export uc dataset in
  let model_ts =
    Theory.ns_find_ts dataset.Pmodule.mod_theory.Theory.th_export [ "model" ]
  in
  let ls_model =
    Term.create_fsymbol (Ident.id_fresh "model") [] (Ty.ty_app model_ts [])
  in
  Term.Hls.add loaded_svms ls_model { filename; nb_inputs; nb_outputs };
  let decl = Pdecl.create_pure_decl (Decl.create_param_decl ls_model) in
  let uc = Pmodule.add_pdecl ~vc:false uc decl in
  Wstdlib.Mstr.add module_name (Pmodule.close_module uc) mstr

(* Reader for NNet files, memoized per environment and per filename. The first
   request for [filename] parses it (permissively) and builds both the
   "AsTuple" and "AsArray" modules. Later requests for the same file return
   the cached module map. Parse errors are reported with [Loc.errorm]. *)
let nnet_parser =
  Env.Wenv.memoize 13 (fun env ->
    let cache = Hashtbl.create (module String) in
    let build filename =
      match Nnet.parse ~permissive:true filename with
      | Error s -> Loc.errorm "%s" s
      | Ok { n_inputs; n_outputs; _ } ->
        let with_tuple =
          register_nn_as_tuple env n_inputs n_outputs filename
            Wstdlib.Mstr.empty
        in
        register_nn_as_array env n_inputs n_outputs filename with_tuple
    in
    fun filename -> Hashtbl.findi_or_add cache filename ~default:build)

(* Reader for ONNX files, memoized per environment and per filename. The first
   request for [filename] parses it and builds both the "AsTuple" and
   "AsArray" modules. If the network intermediate representation cannot be
   built, a warning is logged and the modules are registered without it.
   Parse errors are reported with [Loc.errorm]. *)
let onnx_parser =
  Env.Wenv.memoize 13 (fun env ->
    let cache = Hashtbl.create (module String) in
    let build filename =
      match Onnx.parse filename with
      | Error s -> Loc.errorm "%s" s
      | Ok { n_inputs; n_outputs; nier } ->
        let nier =
          match nier with
          | Ok g -> Some g
          | Error msg ->
            Logs.warn (fun m ->
              m "Cannot build network intermediate representation:@ %s" msg);
            None
        in
        let with_tuple =
          register_nn_as_tuple env n_inputs n_outputs filename ?nier
            Wstdlib.Mstr.empty
        in
        register_nn_as_array env n_inputs n_outputs filename ?nier with_tuple
    in
    fun filename -> Hashtbl.findi_or_add cache filename ~default:build)

(* Reader for OVO (SVM) files, memoized per environment and per filename. The
   first request for [filename] parses it and builds the "AsArray" module.
   Later requests return the cached module map. Parse errors are reported
   with [Loc.errorm]. *)
let ovo_parser =
  Env.Wenv.memoize 13 (fun env ->
    let cache = Hashtbl.create (module String) in
    let build filename =
      match Ovo.parse filename with
      | Error s -> Loc.errorm "%s" s
      | Ok { n_inputs; n_outputs } ->
        let empty = Wstdlib.Mstr.empty in
        register_svm_as_array env n_inputs n_outputs filename empty
    in
    fun filename -> Hashtbl.findi_or_add cache filename ~default:build)

(* Register the "NNet" format (extension .nnet) with Why3, backed by
   [nnet_parser]; only the environment and the file name are used. *)
let register_nnet_support () =
  let read env _path filename _channel = nnet_parser env filename in
  Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
    "NNet" [ "nnet" ] read

(* Register the "ONNX" format (extension .onnx) with Why3, backed by
   [onnx_parser]. *)
let register_onnx_support () =
  let read env _path filename _channel = onnx_parser env filename in
  Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ]
    read

(* Register the "OVO" format (extension .ovo) with Why3, backed by
   [ovo_parser]. *)
let register_ovo_support () =
  let read env _path filename _channel = ovo_parser env filename in
  Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
    read

(* -- Vector *)

(* Registry mapping each vector-literal symbol built by [create_vector] to its
   length. *)
let vectors = Term.Hls.create 10

(* The type [t] of the ieee_float.Float64 theory, used as the element type of
   vectors and networks. *)
let vector_elt_ty env =
  let float64 = Env.read_theory env [ "ieee_float" ] "Float64" in
  let ts_t = Theory.ns_find_ts float64.Theory.th_export [ "t" ] in
  Ty.ty_app ts_t []

(* Memoized (per environment, per length) constructor of vector-literal
   symbols. For a given [length] it returns a logic function "vector" that
   takes [length] Float64 arguments and yields an
   [interpretation.Vector.vector] over Float64. Each new symbol is recorded
   in [vectors] with its length. *)
let create_vector =
  Env.Wenv.memoize 13 (fun env ->
    let by_length = Hashtbl.create (module Int) in
    let ty_elt = vector_elt_ty env in
    let ty_vector =
      let vector_th = Env.read_theory env [ "interpretation" ] "Vector" in
      let ts = Theory.ns_find_ts vector_th.Theory.th_export [ "vector" ] in
      Ty.ty_app ts [ ty_elt ]
    in
    let make length =
      let args = List.init length ~f:(Fn.const ty_elt) in
      let ls = Term.create_fsymbol (Ident.id_fresh "vector") args ty_vector in
      Term.Hls.add vectors ls length;
      ls
    in
    fun length -> Hashtbl.findi_or_add by_length length ~default:make)

(* Length of the vector-literal symbol [ls], if it was built by
   [create_vector]. *)
let lookup_vector ls = Term.Hls.find_opt vectors ls

(* Whether [ls] is a vector-literal symbol built by [create_vector]. *)
let mem_vector ls = Term.Hls.mem vectors ls

(* -- Classifier *)

(* Description of a neural network made available through [create_nn_nnet] or
   [create_nn_onnx]; stored in [nets], keyed by the nullary symbol that
   stands for the network. *)
type nn = {
  (* Input/output counts as reported by the model parser. *)
  nn_nb_inputs : int;
  nn_nb_outputs : int;
  (* Element type associated with the network (presumably of its inputs and
     outputs; confirm with users). Both constructors in this file set it to
     ieee_float.Float64's [t]. *)
  nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
  (* Path of the parsed model file. *)
  nn_filename : string;
  (* Intermediate representation: [None] for NNet models and for ONNX models
     whose IR could not be built. Marked [@opaque], so the derived printer
     does not print its contents. *)
  nn_nier : Onnx.G.t option; [@opaque]
}
[@@deriving show]

(* Registry mapping each network symbol built by [create_nn_nnet] or
   [create_nn_onnx] to its [nn] description. *)
let nets = Term.Hls.create 10

(* A fresh nullary logic symbol named [name] whose type is [nn] from the
   interpretation.NeuralNetwork theory. *)
let fresh_nn_ls env name =
  let nn_th = Env.read_theory env [ "interpretation" ] "NeuralNetwork" in
  let ty_nn =
    let ts_nn = Theory.ns_find_ts nn_th.Theory.th_export [ "nn" ] in
    Ty.ty_app ts_nn []
  in
  Term.create_fsymbol (Ident.id_fresh name) [] ty_nn

(* Memoized (per environment, per filename) registration of an NNet network.
   The first request for [filename] creates a fresh symbol "nnet_nn" of type
   interpretation.NeuralNetwork.nn and parses the file (permissively). The
   symbol is then recorded in [nets] with the network description and
   returned. Later requests for the same file return the same symbol. Parse
   errors are reported with [Loc.errorm]. NNet models carry no intermediate
   representation, hence [nn_nier = None]. *)
let create_nn_nnet =
  Env.Wenv.memoize 13 (fun env ->
    let h = Hashtbl.create (module String) in
    (* Reuse the shared Float64 element type, like [create_nn_onnx] does,
       instead of duplicating the theory lookup here. *)
    let ty_elt = vector_elt_ty env in
    Hashtbl.findi_or_add h ~default:(fun filename ->
      let ls = fresh_nn_ls env "nnet_nn" in
      let nn =
        match Nnet.parse ~permissive:true filename with
        | Error s -> Loc.errorm "%s" s
        | Ok { n_inputs; n_outputs; _ } ->
          {
            nn_nb_inputs = n_inputs;
            nn_nb_outputs = n_outputs;
            nn_ty_elt = ty_elt;
            nn_filename = filename;
            nn_nier = None;
          }
      in
      Term.Hls.add nets ls nn;
      ls))

(* Memoized (per environment, per filename) registration of an ONNX network.
   The first request for [filename] creates a fresh symbol "onnx_nn" of type
   interpretation.NeuralNetwork.nn, parses the file, and records the symbol
   in [nets] with the network description. If the intermediate
   representation cannot be built, a warning is logged and [nn_nier] is
   [None]. Parse errors are reported with [Loc.errorm]. *)
let create_nn_onnx =
  Env.Wenv.memoize 13 (fun env ->
    let cache = Hashtbl.create (module String) in
    let ty_elt = vector_elt_ty env in
    let register filename =
      let ls = fresh_nn_ls env "onnx_nn" in
      let nn =
        match Onnx.parse filename with
        | Error s -> Loc.errorm "%s" s
        | Ok { n_inputs; n_outputs; nier } ->
          let nn_nier =
            match nier with
            | Ok g -> Some g
            | Error msg ->
              Logs.warn (fun m ->
                m "Cannot build network intermediate representation:@ %s" msg);
              None
          in
          {
            nn_nb_inputs = n_inputs;
            nn_nb_outputs = n_outputs;
            nn_ty_elt = ty_elt;
            nn_filename = filename;
            nn_nier;
          }
      in
      Term.Hls.add nets ls nn;
      ls
    in
    fun filename -> Hashtbl.findi_or_add cache filename ~default:register)

(* Description of the network symbol [ls], if it was built by
   [create_nn_nnet] or [create_nn_onnx]. *)
let lookup_nn ls = Term.Hls.find_opt nets ls

(* Whether [ls] was built by [create_nn_nnet] or [create_nn_onnx]. *)
let mem_nn ls = Term.Hls.mem nets ls