Skip to content
Snippets Groups Projects
Commit f623a6bf authored by Julien Girard-Satabin's avatar Julien Girard-Satabin Committed by Aymeric Varasse
Browse files

[exps] Add data (and generation script) for arithmetic

parent ba38ee16
No related branches found
No related tags found
No related merge requests found
File added
import numpy as np
# Given a np array row, return a linear combination of the output
def f(x: np.ndarray) -> np.ndarray:
return np.array((x[0] - x[1] - x[2]))
if __name__ == "__main__":
arr1 = np.random.normal(size=(10000, 3))
arr2 = np.apply_along_axis(f, 1, arr1)
arr3 = np.random.normal(size=(10000, 3))
arr4 = np.apply_along_axis(f, 1, arr1)
np.save(file="data.npy", arr=arr1)
np.save(file="target.npy", arr=arr2)
np.save(file="test_data.npy", arr=arr3)
np.save(file="test_target.npy", arr=arr4)
File added
File added
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment