Skip to content
Snippets Groups Projects
language.ml 1.48 KiB
(**************************************************************************)
(*                                                                        *)
(*  This file is part of Caisar.                                          *)
(*                                                                        *)
(**************************************************************************)

open Base

(* -- Register neural network format. *)

(* Why3 format parser for [.nnet] files.

   The header is read from [filename] through [Nnet.parse]; the second and
   fourth arguments supplied by Why3 are ignored (the fourth is presumably
   the already-opened input channel, which is not used here — TODO confirm).

   On a parse error, the message returned by [Nnet.parse] is reported through
   [Loc.errorm]. On success, a fresh module "AsTuple" is built which
   re-exports the [caisar.NNet] module and declares one pure function symbol

     nnet_apply : input_type -> ... -> input_type   (n_inputs arguments)
                  -> (input_type, ..., input_type)  (tuple of n_outputs)

   where [input_type] is the type symbol exported by [caisar.NNet]. The
   result maps the name "AsTuple" to the closed module. *)
let nnet_parser env _ filename _ =
  let open Why3 in
  match Nnet.parse filename with
  | Error msg -> Loc.errorm "%s" msg
  | Ok nn ->
    let nnet_mod = Pmodule.read_module env [ "caisar" ] "NNet" in
    let elt_ty =
      let input_ts =
        Theory.ns_find_ts nnet_mod.Pmodule.mod_theory.Theory.th_export
          [ "input_type" ]
      in
      Ty.ty_app input_ts []
    in
    (* [copies n] is the list [elt_ty; ...; elt_ty] of length [n]. *)
    let copies n = List.init n ~f:(fun _ -> elt_ty) in
    (* The new module must see everything from NNet, including input_type. *)
    let uc =
      Pmodule.create_module env (Ident.id_fresh "AsTuple")
      |> fun fresh -> Pmodule.use_export fresh nnet_mod
    in
    let apply_ls =
      Term.create_fsymbol
        (Ident.id_fresh "nnet_apply")
        (copies nn.n_inputs)
        (Ty.ty_tuple (copies nn.n_outputs))
    in
    (* A pure (logic-only) declaration: no verification conditions needed. *)
    let decl = Pdecl.create_pure_decl (Decl.create_param_decl apply_ls) in
    let uc = Pmodule.add_pdecl ~vc:false uc decl in
    Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module uc)

(* Register the NNet input format with Why3: files with the [.nnet]
   extension, under the format name "NNet", are turned into mlw modules
   by [nnet_parser]. *)
let register () =
  let open Why3 in
  let extensions = [ "nnet" ] in
  Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
    "NNet" extensions nnet_parser