2
2
import torch
3
3
from torch_struct import LogSemiring
4
4
import itertools
5
+ from hypothesis .strategies import integers , composite , floats
6
+ from hypothesis .extra .numpy import arrays
7
+ import numpy as np
5
8
6
9
7
10
class LinearChainTest :
8
- def __init__ (self , semiring = LogSemiring ):
9
- self .semiring = semiring
10
-
11
11
@staticmethod
12
- def _rand (min_n = 2 ):
13
- b = torch .randint (2 , 4 , (1 ,))
14
- N = torch .randint (min_n , 4 , (1 ,))
15
- C = torch .randint (2 , 4 , (1 ,))
16
- return torch .rand (b , N , C , C ), (b .item (), (N + 1 ).item ())
12
+ @composite
13
+ def logpotentials (draw , min_n = 2 ):
14
+ b = draw (integers (min_value = 2 , max_value = 3 ))
15
+ N = draw (integers (min_value = min_n , max_value = 3 ))
16
+ C = draw (integers (min_value = 2 , max_value = 3 ))
17
+ logp = draw (
18
+ arrays (np .float , (b , N , C , C ), floats (min_value = - 100.0 , max_value = 100.0 ))
19
+ )
20
+ return torch .tensor (logp ), (b , (N + 1 ))
17
21
18
22
### Tests
19
-
20
- def enumerate (self , edge , lengths = None ):
21
- model = torch_struct .LinearChain (self . semiring )
22
- semiring = self . semiring
23
+ @ staticmethod
24
+ def enumerate (semiring , edge , lengths = None ):
25
+ model = torch_struct .LinearChain (semiring )
26
+ semiring = semiring
23
27
ssize = semiring .size ()
24
28
edge , batch , N , C , lengths = model ._check_potentials (edge , lengths )
25
29
chains = [[([c ], semiring .one_ (torch .zeros (ssize , batch ))) for c in range (C )]]
@@ -66,17 +70,18 @@ def enumerate(self, edge, lengths=None):
66
70
67
71
68
72
class DepTreeTest :
69
- def __init__ (self , semiring = LogSemiring ):
70
- self .semiring = semiring
71
-
72
73
@staticmethod
73
- def _rand ():
74
- b = torch .randint (2 , 4 , (1 ,))
75
- N = torch .randint (2 , 4 , (1 ,))
76
- return torch .rand (b , N , N ), (b .item (), N .item ())
74
+ @composite
75
+ def logpotentials (draw ):
76
+ b = draw (integers (min_value = 2 , max_value = 3 ))
77
+ N = draw (integers (min_value = 2 , max_value = 3 ))
78
+ logp = draw (
79
+ arrays (np .float , (b , N , N ), floats (min_value = - 10.0 , max_value = 10.0 ))
80
+ )
81
+ return torch .tensor (logp ), (b , N )
77
82
78
- def enumerate ( self , arc_scores , non_proj = False , multi_root = True ):
79
- semiring = self . semiring
83
+ @ staticmethod
84
+ def enumerate ( semiring , arc_scores , non_proj = False , multi_root = True ):
80
85
parses = []
81
86
q = []
82
87
arc_scores = torch_struct .convert (arc_scores )
@@ -101,21 +106,23 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):
101
106
102
107
103
108
class SemiMarkovTest :
104
- def __init__ (self , semiring = LogSemiring ):
105
- self .semiring = semiring
106
109
107
110
# Tests
108
111
109
112
@staticmethod
110
- def _rand ():
111
- b = torch .randint (2 , 4 , (1 ,))
112
- N = torch .randint (2 , 4 , (1 ,))
113
- K = torch .randint (2 , 4 , (1 ,))
114
- C = torch .randint (2 , 4 , (1 ,))
115
- return torch .rand (b , N , K , C , C ), (b .item (), (N + 1 ).item ())
113
+ @composite
114
+ def logpotentials (draw ):
115
+ b = draw (integers (min_value = 2 , max_value = 3 ))
116
+ N = draw (integers (min_value = 2 , max_value = 3 ))
117
+ K = draw (integers (min_value = 2 , max_value = 3 ))
118
+ C = draw (integers (min_value = 2 , max_value = 3 ))
119
+ logp = draw (
120
+ arrays (np .float , (b , N , K , C , C ), floats (min_value = - 100.0 , max_value = 100.0 ))
121
+ )
122
+ return torch .tensor (logp ), (b , (N + 1 ))
116
123
117
- def enumerate ( self , edge ):
118
- semiring = self . semiring
124
+ @ staticmethod
125
+ def enumerate ( semiring , edge ):
119
126
ssize = semiring .size ()
120
127
batch , N , K , C , _ = edge .shape
121
128
edge = semiring .convert (edge )
@@ -213,12 +220,22 @@ def _is_projective(parse):
213
220
214
221
215
222
class CKY_CRFTest :
216
- def __init__ (self , semiring = LogSemiring ):
217
- self .semiring = semiring
223
+ @staticmethod
224
+ @composite
225
+ def logpotentials (draw ):
226
+ batch = draw (integers (min_value = 2 , max_value = 4 ))
227
+ N = draw (integers (min_value = 2 , max_value = 4 ))
228
+ NT = draw (integers (min_value = 2 , max_value = 4 ))
229
+ logp = draw (
230
+ arrays (
231
+ np .float , (batch , N , N , NT ), floats (min_value = - 100.0 , max_value = 100.0 )
232
+ )
233
+ )
234
+ return torch .tensor (logp ), (batch , N )
218
235
219
- # For testing
220
- def enumerate (self , scores ):
221
- semiring = self . semiring
236
+ @ staticmethod
237
+ def enumerate (semiring , scores ):
238
+ semiring = semiring
222
239
batch , N , _ , NT = scores .shape
223
240
224
241
def enumerate (x , start , end ):
@@ -243,22 +260,36 @@ def enumerate(x, start, end):
243
260
244
261
return semiring .sum (torch .stack (ls , dim = - 1 )), None
245
262
246
- @staticmethod
247
- def _rand ():
248
- batch = torch .randint (2 , 5 , (1 ,))
249
- N = torch .randint (2 , 5 , (1 ,))
250
- NT = torch .randint (2 , 5 , (1 ,))
251
- scores = torch .rand (batch , N , N , NT )
252
- return scores , (batch .item (), N .item ())
253
-
254
263
255
264
class CKYTest :
256
- def __init__ (self , semiring = LogSemiring ):
257
- self .semiring = semiring
265
+ @staticmethod
266
+ @composite
267
+ def logpotentials (draw ):
268
+ batch = draw (integers (min_value = 2 , max_value = 3 ))
269
+ N = draw (integers (min_value = 2 , max_value = 4 ))
270
+ NT = draw (integers (min_value = 2 , max_value = 3 ))
271
+ T = draw (integers (min_value = 2 , max_value = 3 ))
272
+ terms = draw (
273
+ arrays (np .float , (batch , N , T ), floats (min_value = - 100.0 , max_value = 100.0 ))
274
+ )
275
+ rules = draw (
276
+ arrays (
277
+ np .float ,
278
+ (batch , NT , NT + T , NT + T ),
279
+ floats (min_value = - 100.0 , max_value = 100.0 ),
280
+ )
281
+ )
282
+ roots = draw (
283
+ arrays (np .float , (batch , NT ), floats (min_value = - 100.0 , max_value = 100.0 ))
284
+ )
285
+ return (torch .tensor (terms ), torch .tensor (rules ), torch .tensor (roots )), (
286
+ batch ,
287
+ N ,
288
+ )
258
289
259
- def enumerate (self , scores ):
290
+ @staticmethod
291
+ def enumerate (semiring , scores ):
260
292
terms , rules , roots = scores
261
- semiring = self .semiring
262
293
batch , N , T = terms .shape
263
294
_ , NT , _ , _ = rules .shape
264
295
@@ -283,17 +314,6 @@ def enumerate(x, start, end):
283
314
ls += [semiring .times (s , roots [:, nt ]) for s , _ in enumerate (nt , 0 , N )]
284
315
return semiring .sum (torch .stack (ls , dim = - 1 )), None
285
316
286
- @staticmethod
287
- def _rand ():
288
- batch = torch .randint (2 , 5 , (1 ,))
289
- N = torch .randint (2 , 5 , (1 ,))
290
- NT = torch .randint (2 , 5 , (1 ,))
291
- T = torch .randint (2 , 5 , (1 ,))
292
- terms = torch .rand (batch , N , T )
293
- rules = torch .rand (batch , NT , (NT + T ), (NT + T ))
294
- roots = torch .rand (batch , NT )
295
- return (terms , rules , roots ), (batch .item (), N .item ())
296
-
297
317
298
318
class AlignmentTest :
299
319
def __init__ (self , semiring = LogSemiring ):
0 commit comments