diff --git a/examples/arithmetic/data.npy b/examples/arithmetic/data.npy new file mode 100644 index 0000000000000000000000000000000000000000..f630227a9a01ce9b10d6887e55d60181a36294f7 Binary files /dev/null and b/examples/arithmetic/data.npy differ diff --git a/examples/arithmetic/generate_dataset.py b/examples/arithmetic/generate_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0173b63250c42b53ece6e6590059eb245be53a27 --- /dev/null +++ b/examples/arithmetic/generate_dataset.py @@ -0,0 +1,17 @@ +import numpy as np + + +# Given a np array row, return a linear combination of the output +def f(x: np.ndarray) -> np.ndarray: + return np.array((x[0] - x[1] - x[2])) + + +if __name__ == "__main__": + arr1 = np.random.normal(size=(10000, 3)) + arr2 = np.apply_along_axis(f, 1, arr1) + arr3 = np.random.normal(size=(10000, 3)) + arr4 = np.apply_along_axis(f, 1, arr1) + np.save(file="data.npy", arr=arr1) + np.save(file="target.npy", arr=arr2) + np.save(file="test_data.npy", arr=arr3) + np.save(file="test_target.npy", arr=arr4) diff --git a/examples/arithmetic/target.npy b/examples/arithmetic/target.npy new file mode 100644 index 0000000000000000000000000000000000000000..d43f059ec3e7138fcd35f6d9c884d40115419ccf Binary files /dev/null and b/examples/arithmetic/target.npy differ diff --git a/examples/arithmetic/test_data.npy b/examples/arithmetic/test_data.npy new file mode 100644 index 0000000000000000000000000000000000000000..c4b7bd23c002d353eb5c85f186dbe3dd9d3845da Binary files /dev/null and b/examples/arithmetic/test_data.npy differ diff --git a/examples/arithmetic/test_target.npy b/examples/arithmetic/test_target.npy new file mode 100644 index 0000000000000000000000000000000000000000..f92ef12581a8bb43090b230658eaa4761828d231 Binary files /dev/null and b/examples/arithmetic/test_target.npy differ