From d0aa71074e8dc9113b0f6eeb09767b4b6668cd23 Mon Sep 17 00:00:00 2001 From: Alban Grastien <alban.grastien@cea.fr> Date: Fri, 12 Jul 2024 14:43:32 +0200 Subject: [PATCH] Fixing bug `<` -> `>`. --- lib/nir/node.ml | 22 ++-------------------- lib/nir/node.mli | 4 ++-- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/lib/nir/node.ml b/lib/nir/node.ml index f09d211..b162f2f 100644 --- a/lib/nir/node.ml +++ b/lib/nir/node.ml @@ -511,7 +511,7 @@ let partial_dot_product ?shp arr1 arr2 first last = let ioob str = failwith @@ "Index out of bound for arr" ^ str in if last > Array.length arr1 then ioob "1" else if last > Array.length arr2 then ioob "2" - else if last < first then + else if last > first then let rec aux index acc = if index = last then acc else let acc = acc + (arr1.(index) * arr2.(index)) @@ -519,7 +519,7 @@ let partial_dot_product ?shp arr1 arr2 first last = aux Int.(index+1) acc in aux Int.(first+1) (arr1.(first) * arr2.(first)) - else + else (* nothing to include, returns a tensor of 0s. *) let actual_shape = if Array.length arr1 <> 0 then compute_shape arr1.(0) @@ -530,21 +530,3 @@ let partial_dot_product ?shp arr1 arr2 first last = | None -> failwith "Cannot determine shape of tensor" in create @@ (Constant { data = Gentensor.create_const_float actual_shape 0.0}) - (* - if Array.length arr1 < last && Array.length arr2 < last then - failwith "Index out of bound" - else - let zero_node = - create @@ Constant { data = Gentensor.create_1_float 0.0 } - in - let rec aux index acc = - if index = last - then acc - else - let prod = arr1.(index) * arr2.(index) in - let new_acc = acc + prod in - aux Int.(index + 1) new_acc - in - aux first zero_node -*) - \ No newline at end of file diff --git a/lib/nir/node.mli b/lib/nir/node.mli index 68918fe..83f88c8 100644 --- a/lib/nir/node.mli +++ b/lib/nir/node.mli @@ -182,12 +182,12 @@ If [ns] is empty, this returns a tensor of shape [shp] filled with 0s. By default, [shp] is a single float. *) val partial_dot_product : ?shp:Shape.t -> t array -> t array -> int -> int -> t -(** [partial_dot_product arr1 arr2 first last] +(** [partial_dot_product shp arr1 arr2 first last] where [arr1 = [|n11,n12,...,n1k1|]] and [arr2 = [|n21,n22,...,n2k2|]] is a node corresponding to [(n1first * n2first) + (n1first+1 * n2first+1) + ... + (n1last-1 * n2last-1)] if this exists. -It is assumed that [arr1] and [arr2] contain tensors of similar shape. +It is assumed that [arr1] and [arr2] contain tensors with same shape. Edge cases include: {ul -- GitLab