Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
caisar
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pub
caisar
Commits
be0a54f9
Commit
be0a54f9
authored
1 year ago
by
Michele Alberti
Committed by
François Bobot
10 months ago
Browse files
Options
Downloads
Patches
Plain Diff
[ir] Revise API for creating a gather via matmul encoding.
parent
4b0a91c1
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
lib/ir/nier_simple.ml
+13
-10
13 additions, 10 deletions
lib/ir/nier_simple.ml
lib/ir/nier_simple.mli
+5
-2
5 additions, 2 deletions
lib/ir/nier_simple.mli
src/transformations/native_nn_prover.ml
+1
-1
1 addition, 1 deletion
src/transformations/native_nn_prover.ml
with
19 additions
and
13 deletions
lib/ir/nier_simple.ml
+
13
−
10
View file @
be0a54f9
...
@@ -408,12 +408,6 @@ module Node = struct
...
@@ -408,12 +408,6 @@ module Node = struct
ty
=
compute_ty
descr
;
ty
=
compute_ty
descr
;
}
}
let
gather_int
input
i
=
let
indices
=
create
(
Constant
{
data
=
GenTensor
.
create_1_int64
(
Int64
.
of_int
i
)
})
in
create
(
Gather
{
input
;
indices
;
axis
=
0
})
let
constant_int_array
a
=
let
constant_int_array
a
=
create
(
Constant
{
data
=
GenTensor
.
of_int_array
a
})
create
(
Constant
{
data
=
GenTensor
.
of_int_array
a
})
...
@@ -426,10 +420,10 @@ module Node = struct
...
@@ -426,10 +420,10 @@ module Node = struct
{
input
=
node
;
shape
=
constant_int_array
(
Shape
.
to_array
shape
)
})
{
input
=
node
;
shape
=
constant_int_array
(
Shape
.
to_array
shape
)
})
let
gather_int_as_matmul
input
i
=
let
gather_int_as_matmul
input
i
=
let
input
=
let
input
1
=
reshape
(
Shape
.
of_array
[
|
1
;
Shape
.
size
input
.
shape
|
])
input
reshape
(
Shape
.
of_array
[
|
1
;
Shape
.
size
input
.
shape
|
])
input
in
in
let
selector
=
Array
.
create
~
len
:
(
Shape
.
size
input
.
shape
)
Float
.
zero
in
let
selector
=
Array
.
create
~
len
:
(
Shape
.
size
input
1
.
shape
)
Float
.
zero
in
Array
.
set
selector
i
Float
.
one
;
Array
.
set
selector
i
Float
.
one
;
let
selector
=
let
selector
=
GenTensor
.
Float
GenTensor
.
Float
...
@@ -438,8 +432,17 @@ module Node = struct
...
@@ -438,8 +432,17 @@ module Node = struct
(
Bigarray
.
Array1
.
of_array
Float64
C_layout
selector
))
(
Bigarray
.
Array1
.
of_array
Float64
C_layout
selector
))
in
in
let
input2
=
create
(
Constant
{
data
=
selector
})
in
let
input2
=
create
(
Constant
{
data
=
selector
})
in
let
matmul
=
create
(
Matmul
{
input1
=
input
;
input2
})
in
let
result
=
create
(
Matmul
{
input1
;
input2
})
in
reshape
(
Shape
.
of_array
[
|
1
|
])
matmul
reshape
(
Shape
.
of_array
[
|
1
|
])
result
let
gather_int
?
(
encode
=
true
)
input
i
=
if
encode
then
gather_int_as_matmul
input
i
else
let
indices
=
create
(
Constant
{
data
=
GenTensor
.
create_1_int64
(
Int64
.
of_int
i
)
})
in
create
(
Gather
{
input
;
indices
;
axis
=
0
})
let
concat_0
=
function
let
concat_0
=
function
|
[
n
]
->
n
|
[
n
]
->
n
...
...
This diff is collapsed.
Click to expand it.
lib/ir/nier_simple.mli
+
5
−
2
View file @
be0a54f9
...
@@ -148,8 +148,11 @@ module Node : sig
...
@@ -148,8 +148,11 @@ module Node : sig
include
Base
.
Comparator
.
S
with
type
t
:=
t
include
Base
.
Comparator
.
S
with
type
t
:=
t
val
create
:
descr
->
t
val
create
:
descr
->
t
val
gather_int
:
t
->
int
->
t
val
gather_int_as_matmul
:
t
->
int
->
t
val
gather_int
:
?
encode
:
bool
->
t
->
int
->
t
(** create a node by selection at a given index. *)
(* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a
parameter, rather depend on prover. *)
val
constant_int_array
:
int
array
->
t
val
constant_int_array
:
int
array
->
t
(** create a node for a constant array *)
(** create a node for a constant array *)
...
...
This diff is collapsed.
Click to expand it.
src/transformations/native_nn_prover.ml
+
1
−
1
View file @
be0a54f9
...
@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
...
@@ -54,7 +54,7 @@ let create_new_nn env input_vars outputs : string =
let
get_input
=
let
get_input
=
Why3
.
Term
.
Hls
.
memo
10
(
fun
ls
->
Why3
.
Term
.
Hls
.
memo
10
(
fun
ls
->
let
i
=
Why3
.
Term
.
Mls
.
find_exn
UnknownLogicSymbol
ls
input_vars
in
let
i
=
Why3
.
Term
.
Mls
.
find_exn
UnknownLogicSymbol
ls
input_vars
in
Ir
.
Nier_simple
.
Node
.
gather_int
_as_matmul
input
i
)
Ir
.
Nier_simple
.
Node
.
gather_int
input
i
)
in
in
let
cache
=
Why3
.
Term
.
Hterm
.
create
17
in
let
cache
=
Why3
.
Term
.
Hterm
.
create
17
in
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
(* Instantiate the input of [old_nn] with the [old_nn_args] terms transformed
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment