diff --git a/tests/extensions.py b/tests/extensions.py
index de5b09a0..48ab35c0 100644
--- a/tests/extensions.py
+++ b/tests/extensions.py
@@ -2,24 +2,28 @@
 import torch
 from torch_struct import LogSemiring
 import itertools
+from hypothesis.strategies import integers, composite, floats
+from hypothesis.extra.numpy import arrays
+import numpy as np
 
 
 class LinearChainTest:
-    def __init__(self, semiring=LogSemiring):
-        self.semiring = semiring
-
     @staticmethod
-    def _rand(min_n=2):
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(min_n, 4, (1,))
-        C = torch.randint(2, 4, (1,))
-        return torch.rand(b, N, C, C), (b.item(), (N + 1).item())
+    @composite
+    def logpotentials(draw, min_n=2):
+        b = draw(integers(min_value=2, max_value=3))
+        N = draw(integers(min_value=min_n, max_value=3))
+        C = draw(integers(min_value=2, max_value=3))
+        logp = draw(
+            arrays(np.float64, (b, N, C, C), floats(min_value=-100.0, max_value=100.0))
+        )
+        return torch.tensor(logp), (b, (N + 1))
 
     ### Tests
-
-    def enumerate(self, edge, lengths=None):
-        model = torch_struct.LinearChain(self.semiring)
-        semiring = self.semiring
+    @staticmethod
+    def enumerate(semiring, edge, lengths=None):
+        model = torch_struct.LinearChain(semiring)
+        semiring = semiring
         ssize = semiring.size()
         edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
         chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
@@ -66,17 +70,18 @@ def enumerate(self, edge, lengths=None):
 
 
 class DepTreeTest:
-    def __init__(self, semiring=LogSemiring):
-        self.semiring = semiring
-
     @staticmethod
-    def _rand():
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(2, 4, (1,))
-        return torch.rand(b, N, N), (b.item(), N.item())
+    @composite
+    def logpotentials(draw):
+        b = draw(integers(min_value=2, max_value=3))
+        N = draw(integers(min_value=2, max_value=3))
+        logp = draw(
+            arrays(np.float64, (b, N, N), floats(min_value=-10.0, max_value=10.0))
+        )
+        return torch.tensor(logp), (b, N)
 
-    def enumerate(self, arc_scores, non_proj=False, multi_root=True):
-        semiring = self.semiring
+    @staticmethod
+    def enumerate(semiring, arc_scores, non_proj=False, multi_root=True):
         parses = []
         q = []
         arc_scores = torch_struct.convert(arc_scores)
@@ -101,21 +106,23 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):
 
 
 class SemiMarkovTest:
-    def __init__(self, semiring=LogSemiring):
-        self.semiring = semiring
 
     # Tests
     @staticmethod
-    def _rand():
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(2, 4, (1,))
-        K = torch.randint(2, 4, (1,))
-        C = torch.randint(2, 4, (1,))
-        return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())
+    @composite
+    def logpotentials(draw):
+        b = draw(integers(min_value=2, max_value=3))
+        N = draw(integers(min_value=2, max_value=3))
+        K = draw(integers(min_value=2, max_value=3))
+        C = draw(integers(min_value=2, max_value=3))
+        logp = draw(
+            arrays(np.float64, (b, N, K, C, C), floats(min_value=-100.0, max_value=100.0))
+        )
+        return torch.tensor(logp), (b, (N + 1))
 
-    def enumerate(self, edge):
-        semiring = self.semiring
+    @staticmethod
+    def enumerate(semiring, edge):
         ssize = semiring.size()
         batch, N, K, C, _ = edge.shape
         edge = semiring.convert(edge)
@@ -213,12 +220,22 @@ def _is_projective(parse):
 
 
 class CKY_CRFTest:
-    def __init__(self, semiring=LogSemiring):
-        self.semiring = semiring
+    @staticmethod
+    @composite
+    def logpotentials(draw):
+        batch = draw(integers(min_value=2, max_value=4))
+        N = draw(integers(min_value=2, max_value=4))
+        NT = draw(integers(min_value=2, max_value=4))
+        logp = draw(
+            arrays(
+                np.float64, (batch, N, N, NT), floats(min_value=-100.0, max_value=100.0)
+            )
+        )
+        return torch.tensor(logp), (batch, N)
 
-    # For testing
-    def enumerate(self, scores):
-        semiring = self.semiring
+    @staticmethod
+    def enumerate(semiring, scores):
+        semiring = semiring
         batch, N, _, NT = scores.shape
 
         def enumerate(x, start, end):
@@ -243,22 +260,36 @@ def enumerate(x, start, end):
 
         return semiring.sum(torch.stack(ls, dim=-1)), None
 
-    @staticmethod
-    def _rand():
-        batch = torch.randint(2, 5, (1,))
-        N = torch.randint(2, 5, (1,))
-        NT = torch.randint(2, 5, (1,))
-        scores = torch.rand(batch, N, N, NT)
-        return scores, (batch.item(), N.item())
-
 
 class CKYTest:
-    def __init__(self, semiring=LogSemiring):
-        self.semiring = semiring
+    @staticmethod
+    @composite
+    def logpotentials(draw):
+        batch = draw(integers(min_value=2, max_value=3))
+        N = draw(integers(min_value=2, max_value=4))
+        NT = draw(integers(min_value=2, max_value=3))
+        T = draw(integers(min_value=2, max_value=3))
+        terms = draw(
+            arrays(np.float64, (batch, N, T), floats(min_value=-100.0, max_value=100.0))
+        )
+        rules = draw(
+            arrays(
+                np.float64,
+                (batch, NT, NT + T, NT + T),
+                floats(min_value=-100.0, max_value=100.0),
+            )
+        )
+        roots = draw(
+            arrays(np.float64, (batch, NT), floats(min_value=-100.0, max_value=100.0))
+        )
+        return (torch.tensor(terms), torch.tensor(rules), torch.tensor(roots)), (
+            batch,
+            N,
+        )
 
-    def enumerate(self, scores):
+    @staticmethod
+    def enumerate(semiring, scores):
         terms, rules, roots = scores
-        semiring = self.semiring
         batch, N, T = terms.shape
         _, NT, _, _ = rules.shape
@@ -283,17 +314,6 @@ def enumerate(x, start, end):
         ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
         return semiring.sum(torch.stack(ls, dim=-1)), None
 
-    @staticmethod
-    def _rand():
-        batch = torch.randint(2, 5, (1,))
-        N = torch.randint(2, 5, (1,))
-        NT = torch.randint(2, 5, (1,))
-        T = torch.randint(2, 5, (1,))
-        terms = torch.rand(batch, N, T)
-        rules = torch.rand(batch, NT, (NT + T), (NT + T))
-        roots = torch.rand(batch, NT)
-        return (terms, rules, roots), (batch.item(), N.item())
-
 
 class AlignmentTest:
     def __init__(self, semiring=LogSemiring):
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index b31acae8..ec40e031 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,281 +1,259 @@
-from torch_struct import CKY, CKY_CRF, DepTree, LinearChain, SemiMarkov, Alignment
+from torch_struct import (
+    CKY,
+    CKY_CRF,
+    DepTree,
+    LinearChain,
+    SemiMarkov,
+    Alignment,
+    deptree_nonproj,
+    deptree_part,
+)
 from torch_struct import (
     LogSemiring,
     CheckpointSemiring,
     CheckpointShardSemiring,
-    GumbelCRFSemiring,
     KMaxSemiring,
     SparseMaxSemiring,
     MaxSemiring,
     StdSemiring,
-    SampledSemiring,
     EntropySemiring,
-    MultiSampledSemiring,
 )
-from .extensions import test_lookup
+from .extensions import (
+    LinearChainTest,
+    SemiMarkovTest,
+    DepTreeTest,
+    CKYTest,
+    CKY_CRFTest,
+    test_lookup,
+)
 import torch
-from hypothesis import given, settings
+from hypothesis import given
 from hypothesis.strategies import integers, data, sampled_from
+import pytest
+
+from hypothesis import settings
+
+settings.register_profile("ci", max_examples=50, deadline=None)
+
+settings.load_profile("ci")
+
 
 smint = integers(min_value=2, max_value=4)
 tint = integers(min_value=1, max_value=2)
 lint = integers(min_value=2, max_value=10)
 
+algorithms = {
+    "LinearChain": (LinearChain, LinearChainTest),
+    "SemiMarkov": (SemiMarkov, SemiMarkovTest),
+    "Dep": (DepTree, DepTreeTest),
+    "CKY_CRF": (CKY_CRF, CKY_CRFTest),
+    "CKY": (CKY, CKYTest),
+}
+
+
+class Gen:
+    "Helper class for tests"
+
+    def __init__(self, model_test, data, semiring):
+        model_test = algorithms[model_test]
+        self.data = data
+        self.model = model_test[0]
+        self.struct = self.model(semiring)
+        self.test = model_test[1]
+        self.vals, (self.batch, self.N) = data.draw(self.test.logpotentials())
+        # jitter
+        if not isinstance(self.vals, tuple):
+            self.vals = self.vals + 1e-6 * torch.rand(*self.vals.shape)
+        self.semiring = semiring
+
+    def enum(self, semiring=None):
+        return self.test.enumerate(
+            semiring if semiring is not None else self.semiring, self.vals
+        )
+
+
+# Model-specific tests.
+
 
 @given(smint, smint, smint)
 @settings(max_examples=50, deadline=None)
-def test_simple_a(batch, N, C):
+def test_linear_chain_counting(batch, N, C):
     vals = torch.ones(batch, N, C, C)
     semiring = StdSemiring
     alpha = LinearChain(semiring).sum(vals)
     c = pow(C, N + 1)
-    print(c)
     assert (alpha == c).all()
 
-    LinearChain(SampledSemiring).marginals(vals)
-
-    LinearChain(MultiSampledSemiring).marginals(vals)
-
-
-@given(smint, smint, smint, smint)
-@settings(max_examples=50, deadline=None)
-def test_simple_b(batch, N, K, C):
-    print(N)
-    N = 14
-    vals = torch.ones(batch, N, 5, C, C)
-    SemiMarkov(SampledSemiring).marginals(vals)
-    SemiMarkov(MultiSampledSemiring).marginals(vals)
-
-
-# @given(data())
-# @settings(max_examples=50, deadline=None)
-# def test_networkx(data):
-#     batch = 5
-#     N = 10
-#     NT = 5
-#     T = 5
-
-#     torch.manual_seed(0)
-
-#     terms = torch.rand(batch, N, T)
-#     rules = torch.rand(batch, NT, (NT + T), (NT + T))
-#     roots = torch.rand(batch, NT)
-#     vals = (terms, rules, roots)
-#     model = CKY
-#     lengths = torch.tensor(
-#         [data.draw(integers(min_value=3, max_value=N)) for b in range(batch - 1)] + [N]
-#     )
-#     struct = model(SampledSemiring)
-#     marginals = struct.marginals(vals, lengths=lengths)
-#     spans = CKY.from_parts(marginals)[0]
-#     CKY.to_networkx(spans)
-
-#     struct = model(MultiSampledSemiring)
-#     marginals = struct.marginals(vals, lengths=lengths)
-#     m2 = tuple((MultiSampledSemiring.to_discrete(m, 5) for m in marginals))
-#     spans = CKY.from_parts(m2)[0]
-#     CKY.to_networkx(spans)
+# Semiring tests
 
 
 @given(data())
-def test_entropy(data):
-    model = data.draw(sampled_from([LinearChain, SemiMarkov]))
-    semiring = EntropySemiring
-    struct = model(semiring)
-    test = test_lookup[model](LogSemiring)
-    vals, (batch, N) = test._rand()
-    alpha = struct.sum(vals)
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep"])
+@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
+def test_log_shapes(model_test, semiring, data):
+    gen = Gen(model_test, data, semiring)
+    alpha = gen.struct.sum(gen.vals)
+    count = gen.enum()[0]
+
+    assert alpha.shape[0] == gen.batch
+    assert count.shape[0] == gen.batch
+    assert alpha.shape == count.shape
+    assert torch.isclose(count[0], alpha[0])
 
-    log_z = model(LogSemiring).sum(vals)
-    log_probs = test.enumerate(vals)[1]
+
+@given(data())
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov"])
+def test_entropy(model_test, data):
+    "Test entropy by manual enumeration"
+    gen = Gen(model_test, data, EntropySemiring)
+    alpha = gen.struct.sum(gen.vals)
+    log_z = gen.model(LogSemiring).sum(gen.vals)
+
+    log_probs = gen.enum(LogSemiring)[1]
     log_probs = torch.stack(log_probs, dim=1) - log_z
-    print(log_probs.shape, log_z.shape, log_probs.exp().sum(1))
     entropy = -log_probs.mul(log_probs.exp()).sum(1).squeeze(0)
     assert entropy.shape == alpha.shape
     assert torch.isclose(entropy, alpha).all()
 
 
 @given(data())
-def test_kmax(data):
-    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
+@pytest.mark.parametrize("model_test", ["LinearChain"])
+def test_sparse_max(model_test, data):
+    gen = Gen(model_test, data, SparseMaxSemiring)
+    gen.vals.requires_grad_(True)
+    gen.struct.sum(gen.vals)
+    sparsemax = gen.struct.marginals(gen.vals)
+    sparsemax.sum().backward()
+
+
+@given(data())
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep"])
+def test_kmax(model_test, data):
+    "Test out the k-max semiring"
     K = 2
-    semiring = KMaxSemiring(K)
-    struct = model(semiring)
-    test = test_lookup[model](LogSemiring)
-    vals, (batch, N) = test._rand()
-    max1 = model(MaxSemiring).sum(vals)
-    alpha = struct.sum(vals, _raw=True)
+    gen = Gen(model_test, data, KMaxSemiring(K))
+    max1 = gen.model(MaxSemiring).sum(gen.vals)
+    alpha = gen.struct.sum(gen.vals, _raw=True)
+
+    # The 2-max is at most the max.
     assert (alpha[0] == max1).all()
     assert (alpha[1] <= max1).all()
 
-    topk = struct.marginals(vals, _raw=True)
-    argmax = model(MaxSemiring).marginals(vals)
+    topk = gen.struct.marginals(gen.vals, _raw=True)
+    argmax = gen.model(MaxSemiring).marginals(gen.vals)
+
+    # The argmax is different from the 2nd argmax.
     assert (topk[0] == argmax).all()
-    print(topk[0].nonzero(), topk[1].nonzero())
     assert (topk[1] != topk[0]).any()
 
-    if model != DepTree:
-        log_probs = test_lookup[model](MaxSemiring).enumerate(vals)[1]
+    if model_test != "Dep":
+        log_probs = gen.enum(MaxSemiring)[1]
         tops = torch.topk(torch.cat(log_probs, dim=0), 5, 0)[0]
-        assert torch.isclose(struct.score(topk[1], vals), alpha[1]).all()
+        assert torch.isclose(gen.struct.score(topk[1], gen.vals), alpha[1]).all()
         for k in range(K):
            assert (torch.isclose(alpha[k], tops[k])).all()
 
 
 @given(data())
-@settings(max_examples=50, deadline=None)
-def test_cky(data):
-    model = data.draw(sampled_from([CKY]))
-    semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
-    struct = model(semiring)
-    test = test_lookup[model](semiring)
-    vals, (batch, N) = test._rand()
-    alpha = struct.sum(vals)
-    count = test.enumerate(vals)[0]
-
-    assert alpha.shape[0] == batch
-    assert count.shape[0] == batch
+@pytest.mark.parametrize("model_test", ["CKY"])
+@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
+def test_cky(model_test, semiring, data):
+    gen = Gen(model_test, data, semiring)
+    alpha = gen.struct.sum(gen.vals)
+    count = gen.enum()[0]
+
+    assert alpha.shape[0] == gen.batch
+    assert count.shape[0] == gen.batch
     assert alpha.shape == count.shape
     assert torch.isclose(count[0], alpha[0])
 
 
 @given(data())
-@settings(max_examples=50, deadline=None)
-def test_generic_a(data):
-    model = data.draw(
-        sampled_from(
-            [SemiMarkov]
-        )  # , Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
-    )
-
-    semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
-    struct = model(semiring)
-    test = test_lookup[model](semiring)
-    vals, (batch, N) = test._rand()
-    alpha = struct.sum(vals)
-    count = test.enumerate(vals)[0]
-    # assert(False)
-    assert alpha.shape[0] == batch
-    assert count.shape[0] == batch
-    assert alpha.shape == count.shape
-    assert torch.isclose(count[0], alpha[0])
-
-    vals, _ = test._rand()
-    struct = model(MaxSemiring)
-    score = struct.sum(vals)
-    marginals = struct.marginals(vals)
-    # print(marginals)
-    # # assert(False)
-    assert torch.isclose(score, struct.score(vals, marginals)).all()
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "CKY_CRF", "Dep"])
+def test_max(model_test, data):
+    "Test that argmax score is the same as max"
+    gen = Gen(model_test, data, MaxSemiring)
+    score = gen.struct.sum(gen.vals)
+    marginals = gen.struct.marginals(gen.vals)
+    assert torch.isclose(score, gen.struct.score(gen.vals, marginals)).all()
 
 
 @given(data())
-@settings(max_examples=50, deadline=None)
-def test_labeled_proj_deptree(data):
-    semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
-    struct = DepTree(semiring)
+@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
+@pytest.mark.parametrize("model_test", ["Dep"])
+def test_labeled_proj_deptree(model_test, semiring, data):
+    gen = Gen(model_test, data, semiring)
+
     arc_scores = torch.rand(3, 5, 5, 7)
-    count = test_lookup[DepTree](semiring).enumerate(semiring.sum(arc_scores))[0]
-    alpha = struct.sum(arc_scores)
+    gen.vals = semiring.sum(arc_scores)
+    count = gen.enum()[0]
+    alpha = gen.struct.sum(arc_scores)
     assert torch.isclose(count, alpha).all()
 
-    struct = DepTree(MaxSemiring)
+    struct = gen.model(MaxSemiring)
     max_score = struct.sum(arc_scores)
     argmax = struct.marginals(arc_scores)
     assert torch.isclose(max_score, struct.score(arc_scores, argmax)).all()
 
 
-# @given(data())
-# @settings(max_examples=50, deadline=None)
-# def test_non_proj(data):
-#     model = data.draw(sampled_from([DepTree]))
-#     semiring = data.draw(sampled_from([LogSemiring]))
-#     struct = model(semiring)
-#     vals, (batch, N) = model._rand()
-#     alpha = deptree_part(vals)
-#     count = struct.enumerate(vals, non_proj=True, multi_root=False)[0]

-#     assert alpha.shape[0] == batch
-#     assert count.shape[0] == batch
-#     assert alpha.shape == count.shape
-#     assert torch.isclose(count[0], alpha[0])

-#     marginals = deptree_nonproj(vals)
-#     print(marginals.sum(1))
-#     # assert(False)
-#     # vals, _ = model._rand()
-#     # struct = model(MaxSemiring)
-#     # score = struct.sum(vals)
-#     # marginals = struct.marginals(vals)
-#     # assert torch.isclose(score, struct.score(vals, marginals)).all()
-
-
-@given(data(), integers(min_value=1, max_value=20))
-def test_parts_from_marginals(data, seed):
-    # todo: add CKY, DepTree too?
-    model = data.draw(sampled_from([LinearChain, SemiMarkov]))
-    test = test_lookup[model]()
-    torch.manual_seed(seed)
-    vals, (batch, N) = test._rand()
-
-    edge = model(MaxSemiring).marginals(vals).long()
+# todo: add CKY, DepTree too?
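
Note: the two round-trip tests that follow check that to_parts and from_parts
are mutual inverses. A minimal sketch of that contract for the LinearChain
case (the label sequence and C=3 below are illustrative):

    import torch
    from torch_struct import LinearChain

    sequence = torch.tensor([[0, 2, 1, 1]])      # (batch=1, N=4) label sequence
    edge = LinearChain.to_parts(sequence, 3)     # (batch, N-1, C, C) indicator parts
    sequence_, C_ = LinearChain.from_parts(edge)
    assert (sequence == sequence_).all() and C_ == 3
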
+@given(data())
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep", "CKY_CRF"])
+def test_parts_from_marginals(model_test, data):
+    gen = Gen(model_test, data, MaxSemiring)
 
-    sequence, extra = model.from_parts(edge)
-    edge_ = model.to_parts(sequence, extra)
+    edge = gen.struct.marginals(gen.vals).long()
+    sequence, extra = gen.model.from_parts(edge)
+    edge_ = gen.model.to_parts(sequence, extra)
     assert (torch.isclose(edge, edge_)).all(), edge - edge_
-    sequence_, extra_ = model.from_parts(edge_)
+    sequence_, extra_ = gen.model.from_parts(edge_)
     assert extra == extra_, (extra, extra_)
-
     assert (torch.isclose(sequence, sequence_)).all(), sequence - sequence_
 
 
-@given(data(), integers(min_value=1, max_value=20))
-def test_parts_from_sequence(data, seed):
-    model = data.draw(sampled_from([LinearChain, SemiMarkov]))
-    struct = model()
-    test = test_lookup[model]()
-    torch.manual_seed(seed)
-    vals, (batch, N) = test._rand()
-    C = vals.size(-1)
-    if isinstance(struct, LinearChain):
+@given(data())
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov"])
+def test_parts_from_sequence(model_test, data):
+    gen = Gen(model_test, data, LogSemiring)
+    C = gen.vals.size(-1)
+    if isinstance(gen.struct, LinearChain):
         K = 2
         background = 0
         extra = C
-    elif isinstance(struct, SemiMarkov):
-        K = vals.size(-3)
+    elif isinstance(gen.struct, SemiMarkov):
+        K = gen.vals.size(-3)
         background = -1
         extra = C, K
     else:
         raise NotImplementedError()
 
-    sequence = torch.full((batch, N), background, dtype=int)
-    for b in range(batch):
+    sequence = torch.full((gen.batch, gen.N), background, dtype=int)
+    for b in range(gen.batch):
         i = 0
-        while i < N:
+        while i < gen.N:
             symbol = torch.randint(0, C, (1,)).item()
             sequence[b, i] = symbol
             length = torch.randint(1, K, (1,)).item()
             i += length
 
-    edge = model.to_parts(sequence, extra)
-    sequence_, extra_ = model.from_parts(edge)
+    edge = gen.model.to_parts(sequence, extra)
+    sequence_, extra_ = gen.model.from_parts(edge)
     assert extra == extra_, (extra, extra_)
     assert (torch.isclose(sequence, sequence_)).all(), sequence - sequence_
-    edge_ = model.to_parts(sequence_, extra_)
+    edge_ = gen.model.to_parts(sequence_, extra_)
     assert (torch.isclose(edge, edge_)).all(), edge - edge_
 
 
-@given(data(), integers(min_value=1, max_value=10))
-@settings(max_examples=50, deadline=None)
-def test_generic_lengths(data, seed):
-    model = data.draw(sampled_from([CKY, LinearChain, SemiMarkov, CKY_CRF, DepTree]))
-    struct = model()
-    torch.manual_seed(seed)
-    test = test_lookup[model]()
-    vals, (batch, N) = test._rand()
+@given(data())
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "CKY_CRF", "Dep"])
+def test_generic_lengths(model_test, data):
+    gen = Gen(model_test, data, LogSemiring)
+    model, struct, vals, N, batch = gen.model, gen.struct, gen.vals, gen.N, gen.batch
     lengths = torch.tensor(
         [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
     )
@@ -283,146 +261,49 @@ def test_generic_lengths(data, seed):
     m = model(MaxSemiring).marginals(vals, lengths=lengths)
     maxes = struct.score(vals, m)
     part = model().sum(vals, lengths=lengths)
-    print(maxes, part)
+
+    # Check that the max score is bounded above by the log-partition.
     assert (maxes <= part).all()
 
     m_part = model(MaxSemiring).sum(vals, lengths=lengths)
     assert (torch.isclose(maxes, m_part)).all(), maxes - m_part
 
-    # m2 = deptree(vals, lengths=lengths)
-    # assert (m2 < part).all()
-
     if model == CKY:
         return
 
     seqs, extra = struct.from_parts(m)
-    # assert (seqs.shape == (batch, N))
-    # assert seqs.max().item() <= N
     full = struct.to_parts(seqs, extra, lengths=lengths)
-
-    if isinstance(full, tuple):
-        for i in range(len(full)):
-            if i == 1:
-                p = m[i].sum(1).sum(1)
-            else:
-                p = m[i]
-            assert (full[i] == p.type_as(full[i])).all(), "%s %s %s" % (
-                i,
-                full[i].nonzero(),
-                p.nonzero(),
-            )
-    else:
-        assert (full == m.type_as(full)).all(), "%s %s %s" % (
-            full.shape,
-            m.shape,
-            (full - m.type_as(full)).nonzero(),
-        )
+    assert (full == m.type_as(full)).all(), "%s %s %s" % (
+        full.shape,
+        m.shape,
+        (full - m.type_as(full)).nonzero(),
+    )
 
 
-@settings(max_examples=50, deadline=None)
-@given(data(), integers(min_value=1, max_value=10))
-def test_params(data, seed):
-    model = data.draw(sampled_from([DepTree, SemiMarkov, DepTree, CKY, CKY_CRF]))
-    torch.manual_seed(seed)
-    test = test_lookup[model]()
-    vals, (batch, N) = test._rand()
+@given(data())
+@pytest.mark.parametrize(
+    "model_test", ["LinearChain", "SemiMarkov", "Dep", "CKY", "CKY_CRF"]
+)
+def test_params(model_test, data):
+    gen = Gen(model_test, data, LogSemiring)
+    _, struct, vals, _, _ = gen.model, gen.struct, gen.vals, gen.N, gen.batch
+
     if isinstance(vals, tuple):
         vals = tuple((v.requires_grad_(True) for v in vals))
     else:
         vals.requires_grad_(True)
-    # torch.autograd.set_detect_anomaly(True)
-    semiring = LogSemiring
-    alpha = model(semiring).sum(vals)
+    alpha = struct.sum(vals)
     alpha.sum().backward()
 
-    if not isinstance(vals, tuple):
-        b = vals.grad.detach()
-        vals.grad.zero_()
-        alpha = model(semiring).sum(vals, _autograd=False)
-        alpha.sum().backward()
-        c = vals.grad.detach()
-        assert torch.isclose(b, c).all()
-
 
 @given(data())
-@settings(max_examples=50, deadline=None)
-def ignore_alignment(data):
-
-    # log_potentials = torch.ones(2, 2, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-    # print("FINAL", v)
-    # log_potentials = torch.ones(2, 3, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-    # print("FINAL", v)
-
-    # log_potentials = torch.ones(2, 6, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-    # print("FINAL", v)
-
-    # log_potentials = torch.ones(2, 7, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-    # print("FINAL", v)
-
-    # log_potentials = torch.ones(2, 8, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-    # print("FINAL", v)
-    # assert False
-
-    # model = data.draw(sampled_from([Alignment]))
-    # semiring = data.draw(sampled_from([StdSemiring]))
-    # struct = model(semiring)
-    # vals, (batch, N) = model._rand()
-    # print(batch, N)
-    # struct = model(semiring)
-    # # , max_gap=max(3, abs(vals.shape[1] - vals.shape[2]) + 1))
-    # vals.fill_(1)
-    # alpha = struct.sum(vals)
-
-    model = data.draw(sampled_from([Alignment]))
-    semiring = data.draw(sampled_from([StdSemiring]))
-    test = test_lookup[model](semiring)
-    struct = model(semiring, sparse_rounds=10)
-    vals, (batch, N) = test._rand()
-    alpha = struct.sum(vals)
-    count = test.enumerate(vals)[0]
-    assert torch.isclose(count, alpha).all()
-
-    model = data.draw(sampled_from([Alignment]))
-    semiring = data.draw(sampled_from([LogSemiring]))
-    struct = model(semiring, sparse_rounds=10)
-    vals, (batch, N) = model._rand()
-    alpha = struct.sum(vals)
-    count = test_lookup[model](semiring).enumerate(vals)[0]
-    assert torch.isclose(count, alpha).all()
-
-    # model = data.draw(sampled_from([Alignment]))
-    # semiring = data.draw(sampled_from([MaxSemiring]))
-    # struct = model(semiring)
-    # log_potentials = torch.ones(2, 2, 2, 3)
-    # v = Alignment(StdSemiring).sum(log_potentials)
-
-    log_potentials = torch.ones(2, 2, 8, 3)
-    v = Alignment(MaxSemiring).sum(log_potentials)
-    # print(v)
-    # assert False
-    m = Alignment(MaxSemiring).marginals(log_potentials)
-    score = Alignment(MaxSemiring).score(log_potentials, m)
-    assert torch.isclose(v, score).all()
-
-    semiring = data.draw(sampled_from([MaxSemiring]))
-    struct = model(semiring, local=True)
-    test = test_lookup[model](semiring)
-    vals, (batch, N) = test._rand()
-    vals[..., 0] = -2 * vals[..., 0].abs()
-    vals[..., 1] = vals[..., 1].abs()
-    vals[..., 2] = -2 * vals[..., 2].abs()
-    alpha = struct.sum(vals)
-    count = test.enumerate(vals)[0]
-    mx = struct.marginals(vals)
-    print(alpha, count)
-    print(mx[0].nonzero())
-    # assert torch.isclose(count, alpha).all()
-    struct = model(semiring, max_gap=1)
-    alpha = struct.sum(vals)
+@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep"])
+def test_gumbel(model_test, data):
+    gen = Gen(model_test, data, LogSemiring)
+    gen.vals.requires_grad_(True)
+    alpha = gen.struct.marginals(gen.vals)
+    print(alpha[0])
+    print(torch.autograd.grad(alpha, gen.vals, alpha.detach())[0][0])
 
 
 def test_hmm():
@@ -435,19 +316,6 @@ def test_hmm():
     LinearChain().sum(out)
 
 
-@given(data())
-def test_sparse_max(data):
-    model = data.draw(sampled_from([LinearChain]))
-    semiring = SparseMaxSemiring
-    test = test_lookup[model]()
-    vals, (batch, N) = test._rand()
-    vals.requires_grad_(True)
-    model(semiring).sum(vals)
-    sparsemax = model(semiring).marginals(vals)
-    print(vals.requires_grad)
-    sparsemax.sum().backward()
-
-
 def test_sparse_max2():
     print(LinearChain(SparseMaxSemiring).sum(torch.rand(1, 8, 3, 3)))
     print(LinearChain(SparseMaxSemiring).marginals(torch.rand(1, 8, 3, 3)))
@@ -515,13 +383,119 @@ def test_lc_custom():
 
 
 @given(data())
-def test_gumbel(data):
-    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
-    semiring = GumbelCRFSemiring(1.0)
-    test = test_lookup[model]()
-    struct = model(semiring)
+@pytest.mark.parametrize("model_test", ["Dep"])
+@pytest.mark.parametrize("semiring", [LogSemiring])
+def test_non_proj(model_test, semiring, data):
+    gen = Gen(model_test, data, semiring)
+    alpha = deptree_part(gen.vals, False)
+    count = gen.test.enumerate(LogSemiring, gen.vals, non_proj=True, multi_root=False)[
+        0
+    ]
+
+    assert alpha.shape[0] == gen.batch
+    assert count.shape[0] == gen.batch
+    assert alpha.shape == count.shape
+    # assert torch.isclose(count[0], alpha[0], 1e-2)
+
+    alpha = deptree_part(gen.vals, True)
+    count = gen.test.enumerate(LogSemiring, gen.vals, non_proj=True, multi_root=True)[0]
+
+    assert alpha.shape[0] == gen.batch
+    assert count.shape[0] == gen.batch
+    assert alpha.shape == count.shape
+    # assert torch.isclose(count[0], alpha[0], 1e-2)
+
+    marginals = deptree_nonproj(gen.vals, multi_root=False)
+    print(marginals.sum(1))
+    marginals = deptree_nonproj(gen.vals, multi_root=True)
+    print(marginals.sum(1))
+
+
+# # assert(False)
+# # vals, _ = model._rand()
+# # struct = model(MaxSemiring)
+# # score = struct.sum(vals)
+# # marginals = struct.marginals(vals)
+# # assert torch.isclose(score, struct.score(vals, marginals)).all()
+
+
+@given(data())
+@settings(max_examples=50, deadline=None)
+def ignore_alignment(data):
+
+    # log_potentials = torch.ones(2, 2, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+    # print("FINAL", v)
+    # log_potentials = torch.ones(2, 3, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+    # print("FINAL", v)
+
+    # log_potentials = torch.ones(2, 6, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+    # print("FINAL", v)
+
+    # log_potentials = torch.ones(2, 7, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+    # print("FINAL", v)
+
+    # log_potentials = torch.ones(2, 8, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+    # print("FINAL", v)
+    # assert False
+
+    # model = data.draw(sampled_from([Alignment]))
+    # semiring = data.draw(sampled_from([StdSemiring]))
+    # struct = model(semiring)
+    # vals, (batch, N) = model._rand()
+    # print(batch, N)
+    # struct = model(semiring)
+    # # , max_gap=max(3, abs(vals.shape[1] - vals.shape[2]) + 1))
+    # vals.fill_(1)
+    # alpha = struct.sum(vals)
+
+    model = data.draw(sampled_from([Alignment]))
+    semiring = data.draw(sampled_from([StdSemiring]))
+    test = test_lookup[model](semiring)
+    struct = model(semiring, sparse_rounds=10)
     vals, (batch, N) = test._rand()
-    vals.requires_grad_(True)
-    alpha = struct.marginals(vals)
-    print(alpha[0])
-    print(torch.autograd.grad(alpha, vals, alpha.detach())[0][0])
+    alpha = struct.sum(vals)
+    count = test.enumerate(vals)[0]
+    assert torch.isclose(count, alpha).all()
+
+    model = data.draw(sampled_from([Alignment]))
+    semiring = data.draw(sampled_from([LogSemiring]))
+    struct = model(semiring, sparse_rounds=10)
+    vals, (batch, N) = model._rand()
+    alpha = struct.sum(vals)
+    count = test_lookup[model](semiring).enumerate(vals)[0]
+    assert torch.isclose(count, alpha).all()
+
+    # model = data.draw(sampled_from([Alignment]))
+    # semiring = data.draw(sampled_from([MaxSemiring]))
+    # struct = model(semiring)
+    # log_potentials = torch.ones(2, 2, 2, 3)
+    # v = Alignment(StdSemiring).sum(log_potentials)
+
+    log_potentials = torch.ones(2, 2, 8, 3)
+    v = Alignment(MaxSemiring).sum(log_potentials)
+    # print(v)
+    # assert False
+    m = Alignment(MaxSemiring).marginals(log_potentials)
+    score = Alignment(MaxSemiring).score(log_potentials, m)
+    assert torch.isclose(v, score).all()
+
+    semiring = data.draw(sampled_from([MaxSemiring]))
+    struct = model(semiring, local=True)
+    test = test_lookup[model](semiring)
+    vals, (batch, N) = test._rand()
+    vals[..., 0] = -2 * vals[..., 0].abs()
+    vals[..., 1] = vals[..., 1].abs()
+    vals[..., 2] = -2 * vals[..., 2].abs()
+    alpha = struct.sum(vals)
+    count = test.enumerate(vals)[0]
+    mx = struct.marginals(vals)
+    print(alpha, count)
+    print(mx[0].nonzero())
+    # assert torch.isclose(count, alpha).all()
+    struct = model(semiring, max_gap=1)
+    alpha = struct.sum(vals)
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index 4111520d..0e5f8ad3 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,4 +1,4 @@
-from torch_struct import LinearChainCRF, Autoregressive, KMaxSemiring
+from torch_struct import LinearChainCRF, Autoregressive, KMaxSemiring, LogSemiring
 import torch
 from hypothesis import given, settings
 from hypothesis.strategies import integers, data, sampled_from
@@ -20,7 +20,7 @@ def enumerate_support(dist):
             (enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*)
     """
     _, _, edges, enum_lengths = test_lookup[dist.struct]().enumerate(
-        dist.log_potentials, dist.lengths
+        LogSemiring, dist.log_potentials, dist.lengths
     )
     # if expand:
     #     edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])
diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py
index 075033a8..71bb271c 100644
--- a/torch_struct/alignment.py
+++ b/torch_struct/alignment.py
@@ -45,7 +45,7 @@ def _check_potentials(self, edge, lengths=None):
         assert max(lengths) == N, "One length must be at least N"
         return edge, batch, N, M, lengths
 
-    def _dp(self, log_potentials, lengths=None, force_grad=False, cache=True):
+    def logpartition(self, log_potentials, lengths=None, force_grad=False, cache=True):
         return self._dp_scan(log_potentials, lengths, force_grad)
 
     def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
@@ -192,4 +192,4 @@ def pad(v):
         v = chart[
             ..., 0, Open, Open, Mid, N - 1, M - N + ((chart.shape[-1] - 1) // 2)
         ]
-        return v, [log_potentials], None
+        return v, [log_potentials]
diff --git a/torch_struct/cky.py b/torch_struct/cky.py
index 03b50100..60b4cc14 100644
--- a/torch_struct/cky.py
+++ b/torch_struct/cky.py
@@ -5,7 +5,7 @@
 
 
 class CKY(_Struct):
-    def _dp(self, scores, lengths=None, force_grad=False):
+    def logpartition(self, scores, lengths=None, force_grad=False):
         semiring = self.semiring
 
@@ -76,7 +76,7 @@ def arr(a, b):
         final = beta[A][0, :, NTs]
         top = torch.stack([final[:, i, l - 1] for i, l in enumerate(lengths)], dim=1)
         log_Z = semiring.dot(top, roots)
-        return log_Z, (term_use, rules, roots, span[1:]), beta
+        return log_Z, (term_use, rules, roots, span[1:])
 
     def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
         """
@@ -97,7 +97,7 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
         batch, N, T = terms.shape
         _, NT, _, _ = rules.shape
 
-        v, (term_use, rule_use, root_use, spans), alpha = self._dp(
+        v, (term_use, rule_use, root_use, spans) = self.logpartition(
             scores, lengths=lengths, force_grad=True
         )
diff --git a/torch_struct/cky_crf.py b/torch_struct/cky_crf.py
index c06badbc..f9e16a12 100644
--- a/torch_struct/cky_crf.py
+++ b/torch_struct/cky_crf.py
@@ -13,7 +13,7 @@ def _check_potentials(self, edge, lengths=None):
 
         return edge, batch, N, NT, lengths
 
-    def _dp(self, scores, lengths=None, force_grad=False):
+    def logpartition(self, scores, lengths=None, force_grad=False):
         semiring = self.semiring
         scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
 
@@ -40,4 +40,4 @@ def _dp(self, scores, lengths=None, force_grad=False):
 
         final = beta[A][0, :]
         log_Z = final[:, torch.arange(batch), lengths - 1]
-        return log_Z, [scores], beta
+        return log_Z, [scores]
diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py
index d6a8fe0d..c8cb4baa 100644
--- a/torch_struct/deptree.py
+++ b/torch_struct/deptree.py
@@ -46,7 +46,7 @@ class DepTree(_Struct):
     Note: For single-root case, do not set cache=True for now.
     """
 
-    def _dp(self, arc_scores_in, lengths=None, force_grad=False):
+    def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
         multiroot = getattr(self, "multiroot", True)
         if arc_scores_in.dim() not in (3, 4):
             raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")
@@ -109,7 +109,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
 
         final = alpha[A][C][R][(0,)]
         v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
-        return v, [arc_scores_in], alpha
+        return v, [arc_scores_in]
 
     def _check_potentials(self, arc_scores, lengths=None):
         semiring = self.semiring
@@ -174,7 +174,7 @@ def from_parts(arcs):
         return labels, None
 
 
-def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
+def deptree_part(arc_scores, multi_root, lengths=None, eps=1e-5):
     if lengths is not None:
         batch, N, N = arc_scores.shape
         x = torch.arange(N, device=arc_scores.device).expand(batch, N)
@@ -205,7 +205,7 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
     return lap.logdet()
 
 
-def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
+def deptree_nonproj(arc_scores, multi_root, lengths=None, eps=1e-5):
     """
     Compute the marginals of a non-projective dependency tree using the
     matrix-tree theorem.
diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py
index 78997d77..3b7c0a1a 100644
--- a/torch_struct/helpers.py
+++ b/torch_struct/helpers.py
@@ -1,7 +1,6 @@
 import torch
 import math
 from .semirings import LogSemiring
-from torch.autograd import Function
 
 
 class Chart:
@@ -64,49 +63,24 @@ def _make_chart(self, N, size, potentials, force_grad=False):
             for _ in range(N)
         ]
 
-    def sum(self, logpotentials, lengths=None, _autograd=True, _raw=False):
+    def sum(self, logpotentials, lengths=None, _raw=False):
         """
         Compute the (semiring) sum over all structures model.
 
         Parameters:
             logpotentials : generic params (see class)
             lengths: None or b long tensor mask
+            _raw (bool) : return the unconverted semiring values
 
         Returns:
             v: b tensor of total sum
         """
+        v = self.logpartition(logpotentials, lengths)[0]
+        if _raw:
+            return v
+        return self.semiring.unconvert(v)
 
-        if (
-            _autograd
-            or self.semiring is not LogSemiring
-            or not hasattr(self, "_dp_backward")
-        ):
-
-            v = self._dp(logpotentials, lengths)[0]
-            if _raw:
-                return v
-            return self.semiring.unconvert(v)
-
-        else:
-            v, _, alpha = self._dp(logpotentials, lengths, False)
-
-            class DPManual(Function):
-                @staticmethod
-                def forward(ctx, input):
-                    return v
-
-                @staticmethod
-                def backward(ctx, grad_v):
-                    marginals = self._dp_backward(logpotentials, lengths, alpha)
-                    return marginals.mul(
-                        grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))
-                    )
-
-            return DPManual.apply(logpotentials)
-
-    def marginals(
-        self, logpotentials, lengths=None, _autograd=True, _raw=False, _combine=False
-    ):
+    def marginals(self, logpotentials, lengths=None, _raw=False):
         """
         Compute the marginals of a structured model.
 
@@ -118,43 +92,28 @@ def marginals(
             marginals: b x (N-1) x C x C table
 
         """
-        if (
-            _autograd
-            or self.semiring is not LogSemiring
-            or not hasattr(self, "_dp_backward")
-        ):
-            v, edges, _ = self._dp(logpotentials, lengths=lengths, force_grad=True)
-            if _raw:
-                all_m = []
-                for k in range(v.shape[0]):
-                    obj = v[k].sum(dim=0)
-
-                    marg = torch.autograd.grad(
-                        obj,
-                        edges,
-                        create_graph=True,
-                        only_inputs=True,
-                        allow_unused=False,
-                    )
-                    all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
-                return torch.stack(all_m, dim=0)
-            elif _combine:
-                obj = v.sum(dim=0).sum(dim=0)
-                marg = torch.autograd.grad(
-                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
-                )
-                a_m = self._arrange_marginals(marg)
-                return a_m
-            else:
-                obj = self.semiring.unconvert(v).sum(dim=0)
+        v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
+        if _raw:
+            all_m = []
+            for k in range(v.shape[0]):
+                obj = v[k].sum(dim=0)
+
                 marg = torch.autograd.grad(
-                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
+                    obj,
+                    edges,
+                    create_graph=True,
+                    only_inputs=True,
+                    allow_unused=False,
                 )
-                a_m = self._arrange_marginals(marg)
-                return self.semiring.unconvert(a_m)
+                all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
+            return torch.stack(all_m, dim=0)
         else:
-            v, _, alpha = self._dp(logpotentials, lengths=lengths, force_grad=True)
-            return self._dp_backward(logpotentials, lengths, alpha)
+            obj = self.semiring.unconvert(v).sum(dim=0)
+            marg = torch.autograd.grad(
+                obj, edges, create_graph=True, only_inputs=True, allow_unused=False
+            )
+            a_m = self._arrange_marginals(marg)
+            return self.semiring.unconvert(a_m)
 
     @staticmethod
     def to_parts(spans, extra, lengths=None):
diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py
index 787ab586..593b2404 100644
--- a/torch_struct/linearchain.py
+++ b/torch_struct/linearchain.py
@@ -41,10 +41,7 @@ def _check_potentials(self, edge, lengths=None):
         assert C == C2, "Transition shape doesn't match"
         return edge, batch, N, C, lengths
 
-    def _dp(self, log_potentials, lengths=None, force_grad=False):
-        return self._dp_scan(log_potentials, lengths, force_grad)
-
-    def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
+    def logpartition(self, log_potentials, lengths=None, force_grad=False):
         "Compute forward pass by linear scan"
 
         # Setup
         semiring = self.semiring
@@ -83,7 +80,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
         for n in range(1, log_N + 1):
             chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
         v = semiring.sum(semiring.sum(chart[:, :, 0].contiguous()))
-        return v, [log_potentials], None
+        return v, [log_potentials]
 
     @staticmethod
     def to_parts(sequence, extra, lengths=None):
diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py
index ca4c0bc2..2e802c5a 100644
--- a/torch_struct/semimarkov.py
+++ b/torch_struct/semimarkov.py
@@ -18,7 +18,7 @@ def _check_potentials(self, edge, lengths=None):
         assert C == C2, "Transition shape doesn't match"
         return edge, batch, N, K, C, lengths
 
-    def _dp(self, log_potentials, lengths=None, force_grad=False, cache=True):
+    def logpartition(self, log_potentials, lengths=None, force_grad=False):
         "Compute forward pass by linear scan"
 
         # Setup
@@ -79,7 +79,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False, cache=True):
 
         final = chart.view(-1, batch, K_1, C, K_1, C)
         v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
-        return v, [log_potentials], None
+        return v, [log_potentials]
 
     # def _dp_standard(self, edge, lengths=None, force_grad=False):
     #     semiring = self.semiring
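
Note: the library-side changes above all follow one pattern: the private `_dp`
becomes the public `logpartition`, it now returns only `(log_Z, edges)` rather
than a third chart value, and `sum`/`marginals` in helpers.py are thin wrappers
that treat marginals as the gradient of the log-partition with respect to the
potentials. A minimal usage sketch of that contract, assuming the standard
(batch, N, C, C) linear-chain layout used throughout this suite:

    import torch
    from torch_struct import LinearChain, LogSemiring

    log_potentials = torch.randn(2, 3, 2, 2)      # batch=2, N=3 edge slots, C=2 states
    struct = LinearChain(LogSemiring)
    log_z = struct.sum(log_potentials)            # forward: logpartition()[0], unconverted
    marginals = struct.marginals(log_potentials)  # backward: d log Z / d potentials
    # Each of the N=3 edge slots carries a distribution over (C, C) transitions,
    # so the edge marginals sum to N per batch element.
    assert torch.allclose(marginals.sum(dim=(1, 2, 3)), torch.full((2,), 3.0), atol=1e-4)
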