Skip to content

Commit 8ca7f60

Browse files
add arg to MemmappedDataset (#295)
* add arg to MemmappedDataset * remove_ref_enegy not needed * add Ace to test_dataset.py * to black
1 parent fdd4dac commit 8ca7f60

File tree

2 files changed

+123
-29
lines changed

2 files changed

+123
-29
lines changed

tests/test_datasets.py

Lines changed: 123 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,27 @@
88
from os.path import join
99
import numpy as np
1010
import psutil
11-
from torchmdnet.datasets import Custom, HDF5
11+
from torchmdnet.datasets import Custom, HDF5, Ace
1212
from torchmdnet.utils import write_as_hdf5
1313
import h5py
1414
import glob
1515

16+
1617
def write_sample_npy_files(energy, forces, tmpdir, num_files):
1718
# set up necessary files
1819
n_atoms = np.random.randint(2, 10, size=num_files)
1920
num_samples = np.random.randint(10, 100, size=num_files)
20-
#n_atoms repeated num_samples times for each file
21+
# n_atoms repeated num_samples times for each file
2122
for i in range(num_files):
2223
n_atoms_i = n_atoms[i]
2324
num_samples_i = num_samples[i]
2425
np.save(
25-
join(tmpdir, f"coords_{i}.npy"), np.random.normal(size=(num_samples_i, n_atoms_i, 3)).astype(np.float32)
26+
join(tmpdir, f"coords_{i}.npy"),
27+
np.random.normal(size=(num_samples_i, n_atoms_i, 3)).astype(np.float32),
28+
)
29+
np.save(
30+
join(tmpdir, f"embed_{i}.npy"), np.random.randint(0, 100, size=n_atoms_i)
2631
)
27-
np.save(join(tmpdir, f"embed_{i}.npy"), np.random.randint(0, 100, size=n_atoms_i))
2832
if energy:
2933
np.save(
3034
join(tmpdir, f"energy_{i}.npy"),
@@ -41,6 +45,7 @@ def write_sample_npy_files(energy, forces, tmpdir, num_files):
4145
n_atoms_per_sample = np.array(n_atoms_per_sample)
4246
return n_atoms_per_sample
4347

48+
4449
@mark.parametrize("energy", [True, False])
4550
@mark.parametrize("forces", [True, False])
4651
@mark.parametrize("num_files", [1, 3])
@@ -78,12 +83,22 @@ def test_custom(energy, forces, num_files, preload, tmpdir):
7883
# Assert shapes of whole dataset:
7984
for i in range(len(data)):
8085
n_atoms_i = n_atoms_per_sample[i]
81-
assert np.array(data[i].z).shape == (n_atoms_i,), "Dataset has incorrect atom numbers shape"
82-
assert np.array(data[i].pos).shape == (n_atoms_i, 3), "Dataset has incorrect coords shape"
86+
assert np.array(data[i].z).shape == (
87+
n_atoms_i,
88+
), "Dataset has incorrect atom numbers shape"
89+
assert np.array(data[i].pos).shape == (
90+
n_atoms_i,
91+
3,
92+
), "Dataset has incorrect coords shape"
8393
if energy:
84-
assert np.array(data[i].y).shape == (1,), "Dataset has incorrect energy shape"
94+
assert np.array(data[i].y).shape == (
95+
1,
96+
), "Dataset has incorrect energy shape"
8597
if forces:
86-
assert np.array(data[i].neg_dy).shape == (n_atoms_i, 3), "Dataset has incorrect forces shape"
98+
assert np.array(data[i].neg_dy).shape == (
99+
n_atoms_i,
100+
3,
101+
), "Dataset has incorrect forces shape"
87102
# Assert sample has the correct values
88103

89104
# get the reference values from coords_0.npy and embed_0.npy
@@ -98,18 +113,19 @@ def test_custom(energy, forces, num_files, preload, tmpdir):
98113
ref_neg_dy = np.load(join(tmpdir, "forces_0.npy"))[0] if forces else None
99114
assert np.allclose(sample.neg_dy, ref_neg_dy), "Sample has incorrect forces"
100115

116+
101117
@mark.parametrize(("energy", "forces"), [(True, False), (False, True), (True, True)])
102118
def test_write_as_hdf5(energy, forces, tmpdir):
103119
# set up necessary files
104120
num_files = 3
105121
write_sample_npy_files(energy, forces, tmpdir, num_files)
106-
files={}
107-
files["pos"]=sorted(glob.glob(join(tmpdir, "coords*")))
108-
files["z"]=sorted(glob.glob(join(tmpdir, "embed*")))
122+
files = {}
123+
files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
124+
files["z"] = sorted(glob.glob(join(tmpdir, "embed*")))
109125
if energy:
110-
files["y"]=sorted(glob.glob(join(tmpdir, "energy*")))
126+
files["y"] = sorted(glob.glob(join(tmpdir, "energy*")))
111127
if forces:
112-
files["neg_dy"]=sorted(glob.glob(join(tmpdir, "forces*")))
128+
files["neg_dy"] = sorted(glob.glob(join(tmpdir, "forces*")))
113129
write_as_hdf5(files, join(tmpdir, "test.hdf5"))
114130
# Assert file is present in the disk
115131
assert os.path.isfile(join(tmpdir, "test.hdf5")), "HDF5 file was not created"
@@ -120,31 +136,45 @@ def test_write_as_hdf5(energy, forces, tmpdir):
120136
pos_npy = np.load(files["pos"][i])
121137
n_samples = pos_npy.shape[0]
122138
n_atoms_i = pos_npy.shape[1]
123-
assert np.array(data[str(i)]["types"]).shape == (n_samples, n_atoms_i,), "Dataset has incorrect atom numbers shape"
124-
assert np.array(data[str(i)]["pos"]).shape == (n_samples, n_atoms_i, 3), "Dataset has incorrect coords shape"
139+
assert np.array(data[str(i)]["types"]).shape == (
140+
n_samples,
141+
n_atoms_i,
142+
), "Dataset has incorrect atom numbers shape"
143+
assert np.array(data[str(i)]["pos"]).shape == (
144+
n_samples,
145+
n_atoms_i,
146+
3,
147+
), "Dataset has incorrect coords shape"
125148
if energy:
126-
assert np.array(data[str(i)]["energy"]).shape == (n_samples, 1,), "Dataset has incorrect energy shape"
149+
assert np.array(data[str(i)]["energy"]).shape == (
150+
n_samples,
151+
1,
152+
), "Dataset has incorrect energy shape"
127153
if forces:
128-
assert np.array(data[str(i)]["forces"]).shape == (n_samples, n_atoms_i, 3), "Dataset has incorrect forces shape"
154+
assert np.array(data[str(i)]["forces"]).shape == (
155+
n_samples,
156+
n_atoms_i,
157+
3,
158+
), "Dataset has incorrect forces shape"
159+
129160

130161
@mark.parametrize("preload", [True, False])
131162
@mark.parametrize(("energy", "forces"), [(True, False), (False, True), (True, True)])
132163
@mark.parametrize("num_files", [1, 3])
133164
def test_hdf5(preload, energy, forces, num_files, tmpdir):
134165
# set up necessary files
135166
n_atoms_per_sample = write_sample_npy_files(energy, forces, tmpdir, num_files)
136-
files={}
137-
files["pos"]=sorted(glob.glob(join(tmpdir, "coords*")))
138-
files["z"]=sorted(glob.glob(join(tmpdir, "embed*")))
167+
files = {}
168+
files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
169+
files["z"] = sorted(glob.glob(join(tmpdir, "embed*")))
139170
if energy:
140-
files["y"]=sorted(glob.glob(join(tmpdir, "energy*")))
171+
files["y"] = sorted(glob.glob(join(tmpdir, "energy*")))
141172
if forces:
142-
files["neg_dy"]=sorted(glob.glob(join(tmpdir, "forces*")))
173+
files["neg_dy"] = sorted(glob.glob(join(tmpdir, "forces*")))
143174
write_as_hdf5(files, join(tmpdir, "test.hdf5"))
144175
# Assert file is present in the disk
145176
assert os.path.isfile(join(tmpdir, "test.hdf5")), "HDF5 file was not created"
146177

147-
148178
data = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=256 if preload else 0)
149179

150180
assert len(data) == len(n_atoms_per_sample), "Number of samples does not match"
@@ -159,12 +189,22 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir):
159189
# Assert shapes of whole dataset:
160190
for i in range(len(data)):
161191
n_atoms_i = n_atoms_per_sample[i]
162-
assert np.array(data[i].z).shape == (n_atoms_i,), "Dataset has incorrect atom numbers shape"
163-
assert np.array(data[i].pos).shape == (n_atoms_i, 3), "Dataset has incorrect coords shape"
192+
assert np.array(data[i].z).shape == (
193+
n_atoms_i,
194+
), "Dataset has incorrect atom numbers shape"
195+
assert np.array(data[i].pos).shape == (
196+
n_atoms_i,
197+
3,
198+
), "Dataset has incorrect coords shape"
164199
if energy:
165-
assert np.array(data[i].y).shape == (1,), "Dataset has incorrect energy shape"
200+
assert np.array(data[i].y).shape == (
201+
1,
202+
), "Dataset has incorrect energy shape"
166203
if forces:
167-
assert np.array(data[i].neg_dy).shape == (n_atoms_i, 3), "Dataset has incorrect forces shape"
204+
assert np.array(data[i].neg_dy).shape == (
205+
n_atoms_i,
206+
3,
207+
), "Dataset has incorrect forces shape"
168208
# Assert sample has the correct values
169209
# get the reference values from coords_0.npy and embed_0.npy
170210
ref_pos = np.load(join(tmpdir, "coords_0.npy"))[0]
@@ -179,7 +219,6 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir):
179219
assert np.allclose(sample.neg_dy, ref_neg_dy), "Sample has incorrect forces"
180220

181221

182-
183222
def test_hdf5_multiprocessing(tmpdir, num_entries=100):
184223
# generate sample data
185224
z = np.zeros(num_entries)
@@ -202,3 +241,59 @@ def test_hdf5_multiprocessing(tmpdir, num_entries=100):
202241
dset = HDF5(join(tmpdir, "test_hdf5_multiprocessing.h5"))
203242

204243
assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"
244+
245+
246+
def test_ace(tmpdir):
247+
# Test Version 1.0
248+
tmpfilename = join(tmpdir, "molecule.h5")
249+
f = h5py.File(tmpfilename, "w")
250+
f.attrs["layout"] = "Ace"
251+
f.attrs["layout_version"] = "1.0"
252+
f.attrs["name"] = "sample_molecule_data"
253+
for m in range(3): # Three molecules
254+
mol = f.create_group(f"mol_{m+1}")
255+
mol["atomic_numbers"] = [1, 6, 8] # H, C, O
256+
mol["formal_charges"] = [0, 0, 0] # Neutral charges
257+
confs = mol.create_group("conformations")
258+
for i in range(2): # Two conformations
259+
conf = confs.create_group(f"conf_{i+1}")
260+
conf["positions"] = np.random.random((3, 3))
261+
conf["positions"].attrs["units"] = "Å"
262+
conf["formation_energy"] = np.random.random()
263+
conf["formation_energy"].attrs["units"] = "eV"
264+
conf["forces"] = np.random.random((3, 3))
265+
conf["forces"].attrs["units"] = "eV/Å"
266+
conf["partial_charges"] = np.random.random(3)
267+
conf["partial_charges"].attrs["units"] = "e"
268+
conf["dipole_moment"] = np.random.random(3)
269+
conf["dipole_moment"].attrs["units"] = "e*Å"
270+
271+
dataset = Ace(root=tmpdir, paths=tmpfilename)
272+
assert len(dataset) == 6
273+
f.flush()
274+
f.close()
275+
# Test Version 2.0
276+
tmpfilename_v2 = join(tmpdir, "molecule_v2.h5")
277+
f2 = h5py.File(tmpfilename_v2, "w")
278+
f2.attrs["layout"] = "Ace"
279+
f2.attrs["layout_version"] = "2.0"
280+
f2.attrs["name"] = "sample_molecule_data_v2"
281+
master_mol_group = f2.create_group("master_molecule_group")
282+
for m in range(3): # Three molecules
283+
mol = master_mol_group.create_group(f"mol_{m+1}")
284+
mol["atomic_numbers"] = [1, 6, 8] # H, C, O
285+
mol["formal_charges"] = [0, 0, 0] # Neutral charges
286+
mol["positions"] = np.random.random((2, 3, 3)) # Two conformations
287+
mol["positions"].attrs["units"] = "Å"
288+
mol["formation_energies"] = np.random.random(2)
289+
mol["formation_energies"].attrs["units"] = "eV"
290+
mol["forces"] = np.random.random((2, 3, 3))
291+
mol["forces"].attrs["units"] = "eV/Å"
292+
mol["partial_charges"] = np.random.random((2, 3))
293+
mol["partial_charges"].attrs["units"] = "e"
294+
mol["dipole_moment"] = np.random.random((2, 3))
295+
mol["dipole_moment"].attrs["units"] = "e*Å"
296+
dataset_v2 = Ace(root=tmpdir, paths=tmpfilename_v2)
297+
assert len(dataset_v2) == 6
298+
f2.flush()
299+
f2.close()

torchmdnet/datasets/ace.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def __init__(
146146
transform,
147147
pre_transform,
148148
pre_filter,
149-
remove_ref_energy=False,
150149
properties=("y", "neg_dy", "q", "pq", "dp"),
151150
)
152151

0 commit comments

Comments
 (0)