Skip to content

Commit d272745

Browse files
authored
Rewrite the tests structure (#92)
* update tests * style
1 parent 95b2c19 commit d272745

File tree

10 files changed

+420
-470
lines changed

10 files changed

+420
-470
lines changed

tests/extensions.py

Lines changed: 79 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,28 @@
22
import torch
33
from torch_struct import LogSemiring
44
import itertools
5+
from hypothesis.strategies import integers, composite, floats
6+
from hypothesis.extra.numpy import arrays
7+
import numpy as np
58

69

710
class LinearChainTest:
8-
def __init__(self, semiring=LogSemiring):
9-
self.semiring = semiring
10-
1111
@staticmethod
12-
def _rand(min_n=2):
13-
b = torch.randint(2, 4, (1,))
14-
N = torch.randint(min_n, 4, (1,))
15-
C = torch.randint(2, 4, (1,))
16-
return torch.rand(b, N, C, C), (b.item(), (N + 1).item())
12+
@composite
13+
def logpotentials(draw, min_n=2):
14+
b = draw(integers(min_value=2, max_value=3))
15+
N = draw(integers(min_value=min_n, max_value=3))
16+
C = draw(integers(min_value=2, max_value=3))
17+
logp = draw(
18+
arrays(np.float, (b, N, C, C), floats(min_value=-100.0, max_value=100.0))
19+
)
20+
return torch.tensor(logp), (b, (N + 1))
1721

1822
### Tests
19-
20-
def enumerate(self, edge, lengths=None):
21-
model = torch_struct.LinearChain(self.semiring)
22-
semiring = self.semiring
23+
@staticmethod
24+
def enumerate(semiring, edge, lengths=None):
25+
model = torch_struct.LinearChain(semiring)
26+
semiring = semiring
2327
ssize = semiring.size()
2428
edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
2529
chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
@@ -66,17 +70,18 @@ def enumerate(self, edge, lengths=None):
6670

6771

6872
class DepTreeTest:
69-
def __init__(self, semiring=LogSemiring):
70-
self.semiring = semiring
71-
7273
@staticmethod
73-
def _rand():
74-
b = torch.randint(2, 4, (1,))
75-
N = torch.randint(2, 4, (1,))
76-
return torch.rand(b, N, N), (b.item(), N.item())
74+
@composite
75+
def logpotentials(draw):
76+
b = draw(integers(min_value=2, max_value=3))
77+
N = draw(integers(min_value=2, max_value=3))
78+
logp = draw(
79+
arrays(np.float, (b, N, N), floats(min_value=-10.0, max_value=10.0))
80+
)
81+
return torch.tensor(logp), (b, N)
7782

78-
def enumerate(self, arc_scores, non_proj=False, multi_root=True):
79-
semiring = self.semiring
83+
@staticmethod
84+
def enumerate(semiring, arc_scores, non_proj=False, multi_root=True):
8085
parses = []
8186
q = []
8287
arc_scores = torch_struct.convert(arc_scores)
@@ -101,21 +106,23 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):
101106

102107

103108
class SemiMarkovTest:
104-
def __init__(self, semiring=LogSemiring):
105-
self.semiring = semiring
106109

107110
# Tests
108111

109112
@staticmethod
110-
def _rand():
111-
b = torch.randint(2, 4, (1,))
112-
N = torch.randint(2, 4, (1,))
113-
K = torch.randint(2, 4, (1,))
114-
C = torch.randint(2, 4, (1,))
115-
return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())
113+
@composite
114+
def logpotentials(draw):
115+
b = draw(integers(min_value=2, max_value=3))
116+
N = draw(integers(min_value=2, max_value=3))
117+
K = draw(integers(min_value=2, max_value=3))
118+
C = draw(integers(min_value=2, max_value=3))
119+
logp = draw(
120+
arrays(np.float, (b, N, K, C, C), floats(min_value=-100.0, max_value=100.0))
121+
)
122+
return torch.tensor(logp), (b, (N + 1))
116123

117-
def enumerate(self, edge):
118-
semiring = self.semiring
124+
@staticmethod
125+
def enumerate(semiring, edge):
119126
ssize = semiring.size()
120127
batch, N, K, C, _ = edge.shape
121128
edge = semiring.convert(edge)
@@ -213,12 +220,22 @@ def _is_projective(parse):
213220

214221

215222
class CKY_CRFTest:
216-
def __init__(self, semiring=LogSemiring):
217-
self.semiring = semiring
223+
@staticmethod
224+
@composite
225+
def logpotentials(draw):
226+
batch = draw(integers(min_value=2, max_value=4))
227+
N = draw(integers(min_value=2, max_value=4))
228+
NT = draw(integers(min_value=2, max_value=4))
229+
logp = draw(
230+
arrays(
231+
np.float, (batch, N, N, NT), floats(min_value=-100.0, max_value=100.0)
232+
)
233+
)
234+
return torch.tensor(logp), (batch, N)
218235

219-
# For testing
220-
def enumerate(self, scores):
221-
semiring = self.semiring
236+
@staticmethod
237+
def enumerate(semiring, scores):
238+
semiring = semiring
222239
batch, N, _, NT = scores.shape
223240

224241
def enumerate(x, start, end):
@@ -243,22 +260,36 @@ def enumerate(x, start, end):
243260

244261
return semiring.sum(torch.stack(ls, dim=-1)), None
245262

246-
@staticmethod
247-
def _rand():
248-
batch = torch.randint(2, 5, (1,))
249-
N = torch.randint(2, 5, (1,))
250-
NT = torch.randint(2, 5, (1,))
251-
scores = torch.rand(batch, N, N, NT)
252-
return scores, (batch.item(), N.item())
253-
254263

255264
class CKYTest:
256-
def __init__(self, semiring=LogSemiring):
257-
self.semiring = semiring
265+
@staticmethod
266+
@composite
267+
def logpotentials(draw):
268+
batch = draw(integers(min_value=2, max_value=3))
269+
N = draw(integers(min_value=2, max_value=4))
270+
NT = draw(integers(min_value=2, max_value=3))
271+
T = draw(integers(min_value=2, max_value=3))
272+
terms = draw(
273+
arrays(np.float, (batch, N, T), floats(min_value=-100.0, max_value=100.0))
274+
)
275+
rules = draw(
276+
arrays(
277+
np.float,
278+
(batch, NT, NT + T, NT + T),
279+
floats(min_value=-100.0, max_value=100.0),
280+
)
281+
)
282+
roots = draw(
283+
arrays(np.float, (batch, NT), floats(min_value=-100.0, max_value=100.0))
284+
)
285+
return (torch.tensor(terms), torch.tensor(rules), torch.tensor(roots)), (
286+
batch,
287+
N,
288+
)
258289

259-
def enumerate(self, scores):
290+
@staticmethod
291+
def enumerate(semiring, scores):
260292
terms, rules, roots = scores
261-
semiring = self.semiring
262293
batch, N, T = terms.shape
263294
_, NT, _, _ = rules.shape
264295

@@ -283,17 +314,6 @@ def enumerate(x, start, end):
283314
ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
284315
return semiring.sum(torch.stack(ls, dim=-1)), None
285316

286-
@staticmethod
287-
def _rand():
288-
batch = torch.randint(2, 5, (1,))
289-
N = torch.randint(2, 5, (1,))
290-
NT = torch.randint(2, 5, (1,))
291-
T = torch.randint(2, 5, (1,))
292-
terms = torch.rand(batch, N, T)
293-
rules = torch.rand(batch, NT, (NT + T), (NT + T))
294-
roots = torch.rand(batch, NT)
295-
return (terms, rules, roots), (batch.item(), N.item())
296-
297317

298318
class AlignmentTest:
299319
def __init__(self, semiring=LogSemiring):

0 commit comments

Comments
 (0)