8
8
from os .path import join
9
9
import numpy as np
10
10
import psutil
11
- from torchmdnet .datasets import Custom , HDF5
11
+ from torchmdnet .datasets import Custom , HDF5 , Ace
12
12
from torchmdnet .utils import write_as_hdf5
13
13
import h5py
14
14
import glob
15
15
16
+
16
17
def write_sample_npy_files (energy , forces , tmpdir , num_files ):
17
18
# set up necessary files
18
19
n_atoms = np .random .randint (2 , 10 , size = num_files )
19
20
num_samples = np .random .randint (10 , 100 , size = num_files )
20
- #n_atoms repeated num_samples times for each file
21
+ # n_atoms repeated num_samples times for each file
21
22
for i in range (num_files ):
22
23
n_atoms_i = n_atoms [i ]
23
24
num_samples_i = num_samples [i ]
24
25
np .save (
25
- join (tmpdir , f"coords_{ i } .npy" ), np .random .normal (size = (num_samples_i , n_atoms_i , 3 )).astype (np .float32 )
26
+ join (tmpdir , f"coords_{ i } .npy" ),
27
+ np .random .normal (size = (num_samples_i , n_atoms_i , 3 )).astype (np .float32 ),
28
+ )
29
+ np .save (
30
+ join (tmpdir , f"embed_{ i } .npy" ), np .random .randint (0 , 100 , size = n_atoms_i )
26
31
)
27
- np .save (join (tmpdir , f"embed_{ i } .npy" ), np .random .randint (0 , 100 , size = n_atoms_i ))
28
32
if energy :
29
33
np .save (
30
34
join (tmpdir , f"energy_{ i } .npy" ),
@@ -41,6 +45,7 @@ def write_sample_npy_files(energy, forces, tmpdir, num_files):
41
45
n_atoms_per_sample = np .array (n_atoms_per_sample )
42
46
return n_atoms_per_sample
43
47
48
+
44
49
@mark .parametrize ("energy" , [True , False ])
45
50
@mark .parametrize ("forces" , [True , False ])
46
51
@mark .parametrize ("num_files" , [1 , 3 ])
@@ -78,12 +83,22 @@ def test_custom(energy, forces, num_files, preload, tmpdir):
78
83
# Assert shapes of whole dataset:
79
84
for i in range (len (data )):
80
85
n_atoms_i = n_atoms_per_sample [i ]
81
- assert np .array (data [i ].z ).shape == (n_atoms_i ,), "Dataset has incorrect atom numbers shape"
82
- assert np .array (data [i ].pos ).shape == (n_atoms_i , 3 ), "Dataset has incorrect coords shape"
86
+ assert np .array (data [i ].z ).shape == (
87
+ n_atoms_i ,
88
+ ), "Dataset has incorrect atom numbers shape"
89
+ assert np .array (data [i ].pos ).shape == (
90
+ n_atoms_i ,
91
+ 3 ,
92
+ ), "Dataset has incorrect coords shape"
83
93
if energy :
84
- assert np .array (data [i ].y ).shape == (1 ,), "Dataset has incorrect energy shape"
94
+ assert np .array (data [i ].y ).shape == (
95
+ 1 ,
96
+ ), "Dataset has incorrect energy shape"
85
97
if forces :
86
- assert np .array (data [i ].neg_dy ).shape == (n_atoms_i , 3 ), "Dataset has incorrect forces shape"
98
+ assert np .array (data [i ].neg_dy ).shape == (
99
+ n_atoms_i ,
100
+ 3 ,
101
+ ), "Dataset has incorrect forces shape"
87
102
# Assert sample has the correct values
88
103
89
104
# get the reference values from coords_0.npy and embed_0.npy
@@ -98,18 +113,19 @@ def test_custom(energy, forces, num_files, preload, tmpdir):
98
113
ref_neg_dy = np .load (join (tmpdir , "forces_0.npy" ))[0 ] if forces else None
99
114
assert np .allclose (sample .neg_dy , ref_neg_dy ), "Sample has incorrect forces"
100
115
116
+
101
117
@mark .parametrize (("energy" , "forces" ), [(True , False ), (False , True ), (True , True )])
102
118
def test_write_as_hdf5 (energy , forces , tmpdir ):
103
119
# set up necessary files
104
120
num_files = 3
105
121
write_sample_npy_files (energy , forces , tmpdir , num_files )
106
- files = {}
107
- files ["pos" ]= sorted (glob .glob (join (tmpdir , "coords*" )))
108
- files ["z" ]= sorted (glob .glob (join (tmpdir , "embed*" )))
122
+ files = {}
123
+ files ["pos" ] = sorted (glob .glob (join (tmpdir , "coords*" )))
124
+ files ["z" ] = sorted (glob .glob (join (tmpdir , "embed*" )))
109
125
if energy :
110
- files ["y" ]= sorted (glob .glob (join (tmpdir , "energy*" )))
126
+ files ["y" ] = sorted (glob .glob (join (tmpdir , "energy*" )))
111
127
if forces :
112
- files ["neg_dy" ]= sorted (glob .glob (join (tmpdir , "forces*" )))
128
+ files ["neg_dy" ] = sorted (glob .glob (join (tmpdir , "forces*" )))
113
129
write_as_hdf5 (files , join (tmpdir , "test.hdf5" ))
114
130
# Assert file is present in the disk
115
131
assert os .path .isfile (join (tmpdir , "test.hdf5" )), "HDF5 file was not created"
@@ -120,31 +136,45 @@ def test_write_as_hdf5(energy, forces, tmpdir):
120
136
pos_npy = np .load (files ["pos" ][i ])
121
137
n_samples = pos_npy .shape [0 ]
122
138
n_atoms_i = pos_npy .shape [1 ]
123
- assert np .array (data [str (i )]["types" ]).shape == (n_samples , n_atoms_i ,), "Dataset has incorrect atom numbers shape"
124
- assert np .array (data [str (i )]["pos" ]).shape == (n_samples , n_atoms_i , 3 ), "Dataset has incorrect coords shape"
139
+ assert np .array (data [str (i )]["types" ]).shape == (
140
+ n_samples ,
141
+ n_atoms_i ,
142
+ ), "Dataset has incorrect atom numbers shape"
143
+ assert np .array (data [str (i )]["pos" ]).shape == (
144
+ n_samples ,
145
+ n_atoms_i ,
146
+ 3 ,
147
+ ), "Dataset has incorrect coords shape"
125
148
if energy :
126
- assert np .array (data [str (i )]["energy" ]).shape == (n_samples , 1 ,), "Dataset has incorrect energy shape"
149
+ assert np .array (data [str (i )]["energy" ]).shape == (
150
+ n_samples ,
151
+ 1 ,
152
+ ), "Dataset has incorrect energy shape"
127
153
if forces :
128
- assert np .array (data [str (i )]["forces" ]).shape == (n_samples , n_atoms_i , 3 ), "Dataset has incorrect forces shape"
154
+ assert np .array (data [str (i )]["forces" ]).shape == (
155
+ n_samples ,
156
+ n_atoms_i ,
157
+ 3 ,
158
+ ), "Dataset has incorrect forces shape"
159
+
129
160
130
161
@mark .parametrize ("preload" , [True , False ])
131
162
@mark .parametrize (("energy" , "forces" ), [(True , False ), (False , True ), (True , True )])
132
163
@mark .parametrize ("num_files" , [1 , 3 ])
133
164
def test_hdf5 (preload , energy , forces , num_files , tmpdir ):
134
165
# set up necessary files
135
166
n_atoms_per_sample = write_sample_npy_files (energy , forces , tmpdir , num_files )
136
- files = {}
137
- files ["pos" ]= sorted (glob .glob (join (tmpdir , "coords*" )))
138
- files ["z" ]= sorted (glob .glob (join (tmpdir , "embed*" )))
167
+ files = {}
168
+ files ["pos" ] = sorted (glob .glob (join (tmpdir , "coords*" )))
169
+ files ["z" ] = sorted (glob .glob (join (tmpdir , "embed*" )))
139
170
if energy :
140
- files ["y" ]= sorted (glob .glob (join (tmpdir , "energy*" )))
171
+ files ["y" ] = sorted (glob .glob (join (tmpdir , "energy*" )))
141
172
if forces :
142
- files ["neg_dy" ]= sorted (glob .glob (join (tmpdir , "forces*" )))
173
+ files ["neg_dy" ] = sorted (glob .glob (join (tmpdir , "forces*" )))
143
174
write_as_hdf5 (files , join (tmpdir , "test.hdf5" ))
144
175
# Assert file is present in the disk
145
176
assert os .path .isfile (join (tmpdir , "test.hdf5" )), "HDF5 file was not created"
146
177
147
-
148
178
data = HDF5 (join (tmpdir , "test.hdf5" ), dataset_preload_limit = 256 if preload else 0 )
149
179
150
180
assert len (data ) == len (n_atoms_per_sample ), "Number of samples does not match"
@@ -159,12 +189,22 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir):
159
189
# Assert shapes of whole dataset:
160
190
for i in range (len (data )):
161
191
n_atoms_i = n_atoms_per_sample [i ]
162
- assert np .array (data [i ].z ).shape == (n_atoms_i ,), "Dataset has incorrect atom numbers shape"
163
- assert np .array (data [i ].pos ).shape == (n_atoms_i , 3 ), "Dataset has incorrect coords shape"
192
+ assert np .array (data [i ].z ).shape == (
193
+ n_atoms_i ,
194
+ ), "Dataset has incorrect atom numbers shape"
195
+ assert np .array (data [i ].pos ).shape == (
196
+ n_atoms_i ,
197
+ 3 ,
198
+ ), "Dataset has incorrect coords shape"
164
199
if energy :
165
- assert np .array (data [i ].y ).shape == (1 ,), "Dataset has incorrect energy shape"
200
+ assert np .array (data [i ].y ).shape == (
201
+ 1 ,
202
+ ), "Dataset has incorrect energy shape"
166
203
if forces :
167
- assert np .array (data [i ].neg_dy ).shape == (n_atoms_i , 3 ), "Dataset has incorrect forces shape"
204
+ assert np .array (data [i ].neg_dy ).shape == (
205
+ n_atoms_i ,
206
+ 3 ,
207
+ ), "Dataset has incorrect forces shape"
168
208
# Assert sample has the correct values
169
209
# get the reference values from coords_0.npy and embed_0.npy
170
210
ref_pos = np .load (join (tmpdir , "coords_0.npy" ))[0 ]
@@ -179,7 +219,6 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir):
179
219
assert np .allclose (sample .neg_dy , ref_neg_dy ), "Sample has incorrect forces"
180
220
181
221
182
-
183
222
def test_hdf5_multiprocessing (tmpdir , num_entries = 100 ):
184
223
# generate sample data
185
224
z = np .zeros (num_entries )
@@ -202,3 +241,59 @@ def test_hdf5_multiprocessing(tmpdir, num_entries=100):
202
241
dset = HDF5 (join (tmpdir , "test_hdf5_multiprocessing.h5" ))
203
242
204
243
assert len (proc .open_files ()) == n_open , "creating the dataset object opened a file"
244
+
245
+
246
+ def test_ace (tmpdir ):
247
+ # Test Version 1.0
248
+ tmpfilename = join (tmpdir , "molecule.h5" )
249
+ f = h5py .File (tmpfilename , "w" )
250
+ f .attrs ["layout" ] = "Ace"
251
+ f .attrs ["layout_version" ] = "1.0"
252
+ f .attrs ["name" ] = "sample_molecule_data"
253
+ for m in range (3 ): # Three molecules
254
+ mol = f .create_group (f"mol_{ m + 1 } " )
255
+ mol ["atomic_numbers" ] = [1 , 6 , 8 ] # H, C, O
256
+ mol ["formal_charges" ] = [0 , 0 , 0 ] # Neutral charges
257
+ confs = mol .create_group ("conformations" )
258
+ for i in range (2 ): # Two conformations
259
+ conf = confs .create_group (f"conf_{ i + 1 } " )
260
+ conf ["positions" ] = np .random .random ((3 , 3 ))
261
+ conf ["positions" ].attrs ["units" ] = "Å"
262
+ conf ["formation_energy" ] = np .random .random ()
263
+ conf ["formation_energy" ].attrs ["units" ] = "eV"
264
+ conf ["forces" ] = np .random .random ((3 , 3 ))
265
+ conf ["forces" ].attrs ["units" ] = "eV/Å"
266
+ conf ["partial_charges" ] = np .random .random (3 )
267
+ conf ["partial_charges" ].attrs ["units" ] = "e"
268
+ conf ["dipole_moment" ] = np .random .random (3 )
269
+ conf ["dipole_moment" ].attrs ["units" ] = "e*Å"
270
+
271
+ dataset = Ace (root = tmpdir , paths = tmpfilename )
272
+ assert len (dataset ) == 6
273
+ f .flush ()
274
+ f .close ()
275
+ # Test Version 2.0
276
+ tmpfilename_v2 = join (tmpdir , "molecule_v2.h5" )
277
+ f2 = h5py .File (tmpfilename_v2 , "w" )
278
+ f2 .attrs ["layout" ] = "Ace"
279
+ f2 .attrs ["layout_version" ] = "2.0"
280
+ f2 .attrs ["name" ] = "sample_molecule_data_v2"
281
+ master_mol_group = f2 .create_group ("master_molecule_group" )
282
+ for m in range (3 ): # Three molecules
283
+ mol = master_mol_group .create_group (f"mol_{ m + 1 } " )
284
+ mol ["atomic_numbers" ] = [1 , 6 , 8 ] # H, C, O
285
+ mol ["formal_charges" ] = [0 , 0 , 0 ] # Neutral charges
286
+ mol ["positions" ] = np .random .random ((2 , 3 , 3 )) # Two conformations
287
+ mol ["positions" ].attrs ["units" ] = "Å"
288
+ mol ["formation_energies" ] = np .random .random (2 )
289
+ mol ["formation_energies" ].attrs ["units" ] = "eV"
290
+ mol ["forces" ] = np .random .random ((2 , 3 , 3 ))
291
+ mol ["forces" ].attrs ["units" ] = "eV/Å"
292
+ mol ["partial_charges" ] = np .random .random ((2 , 3 ))
293
+ mol ["partial_charges" ].attrs ["units" ] = "e"
294
+ mol ["dipole_moment" ] = np .random .random ((2 , 3 ))
295
+ mol ["dipole_moment" ].attrs ["units" ] = "e*Å"
296
+ dataset_v2 = Ace (root = tmpdir , paths = tmpfilename_v2 )
297
+ assert len (dataset_v2 ) == 6
298
+ f2 .flush ()
299
+ f2 .close ()
0 commit comments