From 90fc546f8c727256b12bfe9affcba7ccdbef4e83 Mon Sep 17 00:00:00 2001 From: teffland Date: Fri, 25 Sep 2020 19:22:59 -0400 Subject: [PATCH 01/14] starting on full crf --- torch_struct/distributions.py | 146 +++++++++++++++++++++++++++------- torch_struct/full_cky_crf.py | 85 ++++++++++++++++++++ torch_struct/helpers.py | 59 ++++++++++---- 3 files changed, 246 insertions(+), 44 deletions(-) create mode 100644 torch_struct/full_cky_crf.py diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 2065b250..658873b0 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -7,6 +7,7 @@ from .alignment import Alignment from .deptree import DepTree, deptree_nonproj, deptree_part from .cky_crf import CKY_CRF +from .full_cky_crf import Full_CKY_CRF from .semirings import ( LogSemiring, MaxSemiring, @@ -91,9 +92,7 @@ def cross_entropy(self, other): cross entropy (*batch_shape*) """ - return self._struct(CrossEntropySemiring).sum( - [self.log_potentials, other.log_potentials], self.lengths - ) + return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths) def kl(self, other): """ @@ -105,9 +104,7 @@ def kl(self, other): Returns: cross entropy (*batch_shape*) """ - return self._struct(KLDivergenceSemiring).sum( - [self.log_potentials, other.log_potentials], self.lengths - ) + return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths) @lazy_property def max(self): @@ -140,9 +137,7 @@ def kmax(self, k): kmax (*k x batch_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).sum( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True) def topk(self, k): r""" @@ -155,9 +150,7 @@ def topk(self, k): kmax (*k x batch_shape x event_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).marginals( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True) @lazy_property def mode(self): @@ -186,9 +179,7 @@ def count(self): def gumbel_crf(self, temperature=1.0): with torch.enable_grad(): - st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals( - self.log_potentials, self.lengths - ) + st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths) return st_gumbel # @constraints.dependent_property @@ -204,29 +195,104 @@ def partition(self): "Compute the log-partition function." return self._struct(LogSemiring).sum(self.log_potentials, self.lengths) - def sample(self, sample_shape=torch.Size()): + def sample(self, sample_shape=torch.Size(), batch_size=10): r""" Compute structured samples from the distribution :math:`z \sim p(z)`. 
Parameters: sample_shape (int): number of samples + batch_size (int): number of samples to compute at a time Returns: samples (*sample_shape x batch_shape x event_shape*) """ - assert len(sample_shape) == 1 - nsamples = sample_shape[0] + if type(sample_shape) == int: + nsamples = sample_shape + else: + assert len(sample_shape) == 1 + nsamples = sample_shape[0] samples = [] for k in range(nsamples): - if k % 10 == 0: - sample = self._struct(MultiSampledSemiring).marginals( - self.log_potentials, lengths=self.lengths - ) + if k % batch_size == 0: + sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths) sample = sample.detach() - tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1) + tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1) samples.append(tmp_sample) return torch.stack(samples) + def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10): + r""" + Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)` + + This uses gumbel perturbations on the potentials followed by the >zero-temp marginals to get approximate samples. + As temp varies from 0 to inf the samples will vary from being exact onehots from an approximate distribution to + a deterministic distribution that is always uniform over all values. + + The approximation empirically causes a "heavy-hitting" bias where a few configurations are more likely than normal + at the expense of many others, making the tail effectively longer. There is evidence however that temps closer + to 1 reduce this somewhat by smoothing the distribution. + + Parameters: + sample_shape (int): number of samples + temp (float): (default=1.0) relaxation temperature + sample_batch_size (int): size of batches to calculates samples + + Returns: + samples (*sample_shape x batch_shape x event_shape*) + + """ + if type(sample_shape) == int: + nsamples = sample_shape + else: + assert len(sample_shape) == 1 + nsamples = sample_shape[0] + if sample_batch_size > nsamples: + sample_batch_size = nsamples + samples = [] + + if noise_shape is None: + noise_shape = self.log_potentials.shape[1:] + + # print(noise) + assert len(noise_shape) == len(self.log_potentials.shape[1:]) + assert all( + s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:]) + ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}" + + for k in range(nsamples): + if k % sample_batch_size == 0: + shape = self.log_potentials.shape + B = shape[0] + s_log_potentials = ( + self.log_potentials.reshape(1, *shape) + .repeat(sample_batch_size, *tuple(1 for _ in shape)) + .reshape(-1, *shape[1:]) + ) + + s_lengths = self.lengths + if s_lengths is not None: + s_shape = s_lengths.shape + s_lengths = ( + s_lengths.reshape(1, *s_shape) + .repeat(sample_batch_size, *tuple(1 for _ in s_shape)) + .reshape(-1, *s_shape[1:]) + ) + + noise = ( + torch.distributions.Gumbel(0, 1) + .sample((sample_batch_size * B, *noise_shape)) + .expand_as(s_log_potentials) + ) + noisy_potentials = (s_log_potentials + noise) / temp + + r_sample = ( + self._struct(LogSemiring) + .marginals(noisy_potentials, s_lengths) + .reshape(sample_batch_size, B, *shape[1:]) + ) + samples.append(r_sample) + return torch.cat(samples, dim=0)[:nsamples] + def to_event(self, sequence, extra, lengths=None): "Convert simple representation to event." 
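As a rough usage sketch of the batched `sample` and the new relaxed `rsample` added in this commit (hypothetical shapes on a LinearChainCRF; only the method names and arguments come from the patch, the rest is illustrative):

    import torch
    from torch_struct import LinearChainCRF

    dist = LinearChainCRF(torch.randn(2, 6, 4, 4))   # batch=2, N-1=6 positions, C=4 labels
    hard = dist.sample((25,), batch_size=10)         # (25, 2, 6, 4, 4) exact 0/1 part indicators
    soft = dist.rsample((25,), temp=1.0)             # same shape, relaxed (differentiable) samples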
return self.struct.to_parts(sequence, extra, lengths=None) @@ -301,9 +367,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None): super().__init__(log_potentials, lengths) def _struct(self, sr=None): - return self.struct( - sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap - ) + return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap) class HMM(StructDistribution): @@ -411,6 +475,32 @@ class TreeCRF(StructDistribution): struct = CKY_CRF +class FullTreeCRF(StructDistribution): + r""" + Represents a 1st-order span parser with NT nonterminals. Implemented using a + fast CKY algorithm. + + For a description see: + + * Inside-Outside Algorithm, by Michael Collins + + Event shape is of the form: + + Parameters: + log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g. + :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)` + lengths (long tensor) : batch shape integers for length masking. + + Implementation uses width-batched, forward-pass only + + * Parallel Time: :math:`O(N)` parallel merges. + * Forward Memory: :math:`O(N^2)` + + Compact representation: *N x N x N xNT x NT x NT* long tensor (Same) + """ + struct = Full_CKY_CRF + + class SentCFG(StructDistribution): """ Represents a full generative context-free grammar with @@ -440,9 +530,7 @@ def __init__(self, log_potentials, lengths=None): event_shape = log_potentials[0].shape[1:] self.log_potentials = log_potentials self.lengths = lengths - super(StructDistribution, self).__init__( - batch_shape=batch_shape, event_shape=event_shape - ) + super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape) class NonProjectiveDependencyCRF(StructDistribution): diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py new file mode 100644 index 00000000..81c2c048 --- /dev/null +++ b/torch_struct/full_cky_crf.py @@ -0,0 +1,85 @@ +import torch +from .helpers import _Struct, Chart + +A, B = 0, 1 + + +class Full_CKY_CRF(_Struct): + def _check_potentials(self, edge, lengths=None): + batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge) + assert ( + N == N1 == N2 and NT == NT1 == NT2 + ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}" + edge = self.semiring.convert(edge) + semiring_shape = edge.shape[:-7] + if lengths is None: + lengths = torch.LongTensor([N] * batch).to(edge.device) + + return edge, semiring_shape, batch, N, NT, lengths + + def _dp(self, scores, lengths=None, force_grad=False, cache=True): + semiring = self.semiring + scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths) + # scores.shape = *sshape, B, N, N, N, NT, NT, NT + + beta = [Chart((batch, N, N), scores, semiring, cache=cache) for _ in range(2)] + L_DIM, R_DIM = len(sshape) + 1, len(sshape) + 2 # usually 2,3 + + # Initialize + reduced_scores = semiring.sum(scores) + print(reduced_scores.shape) + term = reduced_scores.diagonal(0, L_DIM, R_DIM) + bs, ns, nts = torch.arange(batch), torch.arange(N), torch.arange(NT) + term_scores = scores[:, bs, ns, ns, ns, nts, nts, nts] + print(term_scores.shape) + beta[A][ns, 0] = term + beta[B][ns, N - 1] = term + + # Run + for w in range(1, N): + left = slice(None, N - w) + right = slice(w, None) + Y = beta[A][left, :w] + Z = beta[B][right, N - w :] + score = reduced_scores.diagonal(w, L_DIM, R_DIM) + new = semiring.times(semiring.dot(Y, Z), score) + beta[A][left, w] = new + beta[B][right, N - w - 1] = new + + final = beta[A][0, :] + log_Z = 
final[:, torch.arange(batch), lengths - 1] + return log_Z, [scores], beta + + # For testing + + def enumerate(self, scores, lengths=None): + semiring = self.semiring + batch, N, _, _, NT, _, _ = scores.shape + + def enumerate(x, start, end): + if start + 1 == end: + yield (scores[:, start, start, x], [(start, x)]) + else: + for w in range(start + 1, end): + for y in range(NT): + for z in range(NT): + for m1, y1 in enumerate(y, start, w): + for m2, z1 in enumerate(z, w, end): + yield ( + semiring.times(m1, m2, scores[:, start, end - 1, x]), + [(x, start, w, end)] + y1 + z1, + ) + + ls = [] + for nt in range(NT): + ls += [s for s, _ in enumerate(nt, 0, N)] + + return semiring.sum(torch.stack(ls, dim=-1)), None + + @staticmethod + def _rand(): + batch = torch.randint(2, 5, (1,)) + N = torch.randint(2, 5, (1,)) + NT = torch.randint(2, 5, (1,)) + scores = torch.rand(batch, N, N, N, NT, NT, NT) + return scores, (batch.item(), N.item()) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 3b7c0a1a..7260a518 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -6,11 +6,7 @@ class Chart: def __init__(self, size, potentials, semiring): self.data = semiring.zero_( - torch.zeros( - *((semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ) self.grad = self.data.detach().clone().fill_(0.0) @@ -24,15 +20,45 @@ def __setitem__(self, ind, new): class _Struct: + """`_Struct` is base class used to represent the graphical structure of a model. + + Subclasses should implement a `_dp` method which computes the partition function (under the standard `_BaseSemiring`). + Different `StructDistribution` methods will instantiate the `_Struct` subclasses + """ + def __init__(self, semiring=LogSemiring): self.semiring = semiring + def _dp(self, scores, lengths=None, force_grad=False, cache=True): + """Implement computation equivalent to the computing partition constant Z (if self.semiring == `_BaseSemiring`). + + Params: + scores: torch.FloatTensor, log potential scores for each factor of the model. Shape (* x batch size x *event_shape ) + lengths: torch.LongTensor = None, lengths of batch padded examples. Shape = ( * x batch size ) + force_grad: bool = False + cache: bool = True + + Returns: + v: torch.Tensor, the resulting output of the dynammic program + edges: List[torch.Tensor], the log edge potentials of the model. + When `scores` is already in a log_potential format for the distribution (typical), this will be + [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`. + An exceptional case is the `CKY` struct, which takes log potential parameters from production rules + for a PCFG, which are by definition independent of position in the sequence. + charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the + "backward" dynamic program and compute things like marginals w/o autograd. 
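As a concrete instance of this contract, a minimal sketch of driving the FullTreeCRF distribution added above (shapes follow its class docstring; the top-level export path is an assumption):

    import torch
    from torch_struct import FullTreeCRF          # added by this patch; export path assumed

    B, N, NT = 2, 5, 3
    phi = torch.randn(B, N, N, N, NT, NT, NT)     # log phi(i, j, k, A_i^j -> B_i^k C_{k+1}^j)
    dist = FullTreeCRF(phi, lengths=torch.tensor([5, 4]))
    logZ = dist.partition                         # (B,) log partition from the struct's dynamic program
    marg = dist.marginals                         # same shape as phi: posterior marginals per anchored rule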
+ + """ + raise NotImplementedError + def score(self, potentials, parts, batch_dims=[0]): - score = torch.mul(potentials, parts) + """Score for entire structure is product of potentials for all activated "parts".""" + score = torch.mul(potentials, parts) # mask potentials by activated "parts" batch = tuple((score.shape[b] for b in batch_dims)) - return self.semiring.prod(score.view(batch + (-1,))) + return self.semiring.prod(score.view(batch + (-1,))) # product of all potentialsa def _bin_length(self, length): + """Find least upper bound for lengths that is a power of 2. Used in parallel scans.""" log_N = int(math.ceil(math.log(length, 2))) bin_N = int(math.pow(2, log_N)) return log_N, bin_N @@ -53,11 +79,7 @@ def _make_chart(self, N, size, potentials, force_grad=False): return [ ( self.semiring.zero_( - torch.zeros( - *((self.semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ).requires_grad_(force_grad and not potentials.requires_grad) ) for _ in range(N) @@ -109,9 +131,7 @@ def marginals(self, logpotentials, lengths=None, _raw=False): return torch.stack(all_m, dim=0) else: obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad( - obj, edges, create_graph=True, only_inputs=True, allow_unused=False - ) + marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) a_m = self._arrange_marginals(marg) return self.semiring.unconvert(a_m) @@ -125,3 +145,12 @@ def from_parts(spans): def _arrange_marginals(self, marg): return marg[0] + + # For Testing + def _rand(self, *args, **kwargs): + """TODO:""" + raise NotImplementedError + + def enumerate(self, edge, lengths=None): + """TODO:""" + raise NotImplementedError From e863fbbe93ed1659b98dc9d7965efbf1b6992eda Mon Sep 17 00:00:00 2001 From: Thomas Effland Date: Sun, 27 Sep 2020 09:01:54 -0400 Subject: [PATCH 02/14] working full cky crf --- torch_struct/full_cky_crf.py | 149 +++++++++++++++++++++++----- torch_struct/semirings/semirings.py | 12 +-- 2 files changed, 127 insertions(+), 34 deletions(-) diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py index 81c2c048..2f2e277c 100644 --- a/torch_struct/full_cky_crf.py +++ b/torch_struct/full_cky_crf.py @@ -18,41 +18,136 @@ def _check_potentials(self, edge, lengths=None): return edge, semiring_shape, batch, N, NT, lengths def _dp(self, scores, lengths=None, force_grad=False, cache=True): - semiring = self.semiring + sr = self.semiring + # torch.autograd.set_detect_anomaly(True) + + # Scores.shape = *sshape, B, N, N, N, NT, NT, NT + # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C] + # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths) - # scores.shape = *sshape, B, N, N, N, NT, NT, NT - - beta = [Chart((batch, N, N), scores, semiring, cache=cache) for _ in range(2)] - L_DIM, R_DIM = len(sshape) + 1, len(sshape) + 2 # usually 2,3 - - # Initialize - reduced_scores = semiring.sum(scores) - print(reduced_scores.shape) - term = reduced_scores.diagonal(0, L_DIM, R_DIM) - bs, ns, nts = torch.arange(batch), torch.arange(N), torch.arange(NT) - term_scores = scores[:, bs, ns, ns, ns, nts, nts, nts] - print(term_scores.shape) - beta[A][ns, 0] = term - beta[B][ns, N - 1] = term - - # Run + sshape, sdims = list(sshape), list(range(len(sshape))) # usually [0] + S, b = len(sdims), 
batch + + # Initialize data structs + LEFT, RIGHT = 0, 1 + L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim + # Will store sum of subtrees up to i,j,A from the left and right + # beta[LEFT][i,d,A] = sum of potentials of all subtrees in span i,j=(i+d) with nonterminal A + # indexed from the left endpoint i plus the width d + # . = alpha[i,j=(i+d),A] in a nonvectorized version + # beta[RIGHT][j,d',A] = sum of potentials of all subtrees in span i=(j-(N-d')),j with NT A + # indexed from the right endpoint, from widest to shortest subtrees. + # This gets filled in from right to left. + + # OVERRIDE CACHE + cache = False + # print("cache", cache) + beta = [Chart((b, N, N, NT), scores, sr, cache=cache) for _ in range(2)] + + # Initialize the base cases with scores from diagonal i=j=k, A=B=C + term_scores = ( + scores.diagonal(0, L_DIM, R_DIM) # diag i,j now at dim -1 + .diagonal(0, L_DIM, -1) # diag of k with that gives i=j=k, now at dim -1 + .diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2 + .diagonal(0, -3, -1) # diag of C with that gives A=B=C + ) + assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}" + beta[LEFT][:, 0, :] = term_scores + beta[RIGHT][:, N - 1, :] = term_scores + alpha_left = term_scores + alpha_right = term_scores + + ### old: init with semiring's multiplicative identity, gives zeros mass to leaves + # ns = torch.arange(NT) + # beta[LEFT][:, 0, :] = sr.one_(beta[LEFT][:, 0, :]) + # beta[RIGHT][:, N - 1, :] = sr.one_(beta[RIGHT][:, N - 1, :]) + # alpha_left = sr.one_(torch.ones(sshape + [b, N, NT]).to(scores.device)) + # alpha_right = sr.one_(torch.ones(sshape + [b, N, NT]).to(scores.device)) + + alphas = [[alpha_left], [alpha_right]] + + # Run vectorized inside alg for w in range(1, N): - left = slice(None, N - w) - right = slice(w, None) - Y = beta[A][left, :w] - Z = beta[B][right, N - w :] - score = reduced_scores.diagonal(w, L_DIM, R_DIM) - new = semiring.times(semiring.dot(Y, Z), score) - beta[A][left, w] = new - beta[B][right, N - w - 1] = new - - final = beta[A][0, :] + # print("\nw", w, "N-w", N - w) + # Scores + # What we want is a tensor with: + # shape: *sshape, batch, (N-w), NT, w, NT, NT + # w/ semantics: [...batch, (i,j=i+w), A, k, B, C] + # where (i,j=i+w) means the diagonal of trees nodes with width w + # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)] + score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores + # print("diagonal", score.shape[S:]) + + score = score.permute(sdims + [-6, -1, -4, -5, -3, -2]) # move diag (-1) dim and head NT (-4) dim to front + # print("permute", score.shape[S:]) + score = score[..., :w, :, :] # remove illegal splitpoints + # print("slice", score.shape[S:]) + assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}" + # print("S", score[0, 0, :, 0, :, 0, 0].exp()) + # Sums of left subtrees + # Shape: *sshape, batch, (N-w), w, NT + # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B) + left = slice(None, N - w) # left indices + L1 = beta[LEFT][left, :w] + L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :] + + assert L.isclose(L1).all() + # print("L", L.shape) + + # Sums of right subtrees + # Shape: *sshape, batch, (N-w), w, NT + # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C) + right = slice(w, None) # right indices + R1 = beta[RIGHT][right, N - w :] + R = torch.stack(list(reversed(alphas[RIGHT][:w])), 
dim=-2)[..., right, :, :] + assert R.isclose(R1).all() + # print("R", R.shape) # R[0, 0, :, :, 0].exp()) + + # Broadcast them both to match missing dims in score + # Left B is duplicated for all head and right symbols A C + L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT]) + # Right C is duplicated for all head and left symbols A B + R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1]) + + assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) + # print(score.shape[S + 1 :], L_bcast.shape, R_bcast.shape) + + # Now multiply all the scores and sum over k, B, C dimensions (the last three dims) + assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) + sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) + # print("sum prod w", sum_prod_w.exp()) + assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}" + + # new = sr.times(sr.dot(Y, Z), score) + beta[LEFT][left, w] = sum_prod_w + beta[RIGHT][right, N - w - 1] = sum_prod_w + # pad = sr.zero_(torch.ones_like(sum_prod_w))[..., :w, :] + pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device)) + sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2) + sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2) + # print(sum_prod_w.shape, sum_prod_w_left.shape, sum_prod_w_right.shape) + alphas[LEFT].append(sum_prod_w_left) + alphas[RIGHT].append(sum_prod_w_right) + # for c in range(NT): + # print(f"left c:{c}\n", beta[LEFT][:, :].exp().detach().numpy()) + + # print(f"right c:{c}\n", beta[RIGHT][:, :].exp().detach().numpy()) + + final1 = sr.sum(beta[LEFT][0, :, :]) + final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :] # sum out root symbol + # print(f"f1:{final1.shape}, f:{final.shape}, ls:{lengths}") + assert final.isclose(final1).all(), f"final:\n{final}\nfinal1:\n{final1}" + + # log_Z = final[..., 0, lengths - 1] log_Z = final[:, torch.arange(batch), lengths - 1] + # log_Z.exp().sum().backward() + # print("Z", log_Z.exp()) return log_Z, [scores], beta # For testing def enumerate(self, scores, lengths=None): + raise NotImplementedError semiring = self.semiring batch, N, _, _, NT, _, _ = scores.shape diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index bb7b9ec1..31a4334f 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -91,7 +91,8 @@ def plus(cls, a, b): class _Base(Semiring): - zero = 0 + zero = 0.0 + one = 1.0 @staticmethod def mul(a, b): @@ -112,6 +113,7 @@ def one_(xs): class _BaseLog(Semiring): zero = -1e9 + one = 1.0 @staticmethod def sum(xs, dim=-1): @@ -308,9 +310,7 @@ def sum(xs, dim=-1): ( part_p, part_q, - torch.sum( - xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d - ), + torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d), ) ) @@ -384,9 +384,7 @@ def sum(xs, dim=-1): log_sm_p = xs[0] - part_p.unsqueeze(d) log_sm_q = xs[1] - part_q.unsqueeze(d) sm_p = log_sm_p.exp() - return torch.stack( - (part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)) - ) + return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d))) @staticmethod def mul(a, b): From 498c9648f9dfc973daeb7f6e5c0bdfcb67ae665c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 27 Sep 2020 23:32:19 +0000 Subject: [PATCH 03/14] debugging full cky --- 
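A non-vectorized reference of the inside recursion that the vectorized `_dp` above implements can make the diagonal and permute bookkeeping easier to follow. This is illustrative only, written against the potential semantics stated in the FullTreeCRF docstring, with length masking omitted:

    import torch

    def inside_reference(scores):
        # scores[b, i, j, k, A, B, C] = log phi(A_i^j -> B_i^k C_{k+1}^j); terminals live on i = j = k, A = B = C
        Bsz, N, _, _, NT, _, _ = scores.shape
        alpha = scores.new_full((Bsz, N, N, NT), float("-inf"))
        for i in range(N):                                      # width-0 spans
            for A in range(NT):
                alpha[:, i, i, A] = scores[:, i, i, i, A, A, A]
        for w in range(1, N):                                   # widen spans one step at a time
            for i in range(N - w):
                j = i + w
                for A in range(NT):
                    parts = [alpha[:, i, k, B] + alpha[:, k + 1, j, C] + scores[:, i, j, k, A, B, C]
                             for k in range(i, j) for B in range(NT) for C in range(NT)]
                    alpha[:, i, j, A] = torch.stack(parts, dim=-1).logsumexp(dim=-1)
        return alpha[:, 0, N - 1].logsumexp(dim=-1)             # log Z: sum over the root nonterminal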
torch_struct/full_cky_crf.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py index 2f2e277c..1bff1c14 100644 --- a/torch_struct/full_cky_crf.py +++ b/torch_struct/full_cky_crf.py @@ -1,5 +1,6 @@ import torch from .helpers import _Struct, Chart +from tqdm import tqdm A, B = 0, 1 @@ -18,6 +19,9 @@ def _check_potentials(self, edge, lengths=None): return edge, semiring_shape, batch, N, NT, lengths def _dp(self, scores, lengths=None, force_grad=False, cache=True): + DEBUG = False + if DEBUG: + print("FullCKYCRF DP starting") sr = self.semiring # torch.autograd.set_detect_anomaly(True) @@ -67,7 +71,9 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True): alphas = [[alpha_left], [alpha_right]] # Run vectorized inside alg - for w in range(1, N): + + ws = tqdm(range(1, N), "Calculating marginals at width", N - 1) if DEBUG else range(1, N) + for w in ws: # print("\nw", w, "N-w", N - w) # Scores # What we want is a tensor with: @@ -142,6 +148,8 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True): log_Z = final[:, torch.arange(batch), lengths - 1] # log_Z.exp().sum().backward() # print("Z", log_Z.exp()) + # if DEBUG: + # print("Using autograd to get marginals") return log_Z, [scores], beta # For testing From 049dbd4ef988eed7112434c7d8011ebe25920caa Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Dec 2020 19:44:08 +0000 Subject: [PATCH 04/14] more changes --- torch_struct/distributions.py | 2 +- torch_struct/linearchain.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 658873b0..681c6242 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -282,7 +282,7 @@ def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_ torch.distributions.Gumbel(0, 1) .sample((sample_batch_size * B, *noise_shape)) .expand_as(s_log_potentials) - ) + ).to(s_log_potentials.device) noisy_potentials = (s_log_potentials + noise) / temp r_sample = ( diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 593b2404..31547640 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -122,7 +122,7 @@ def from_parts(edge): batch, N_1, C, _ = edge.shape N = N_1 + 1 labels = torch.zeros(batch, N).long() - on = edge.nonzero() + on = edge.nonzero(as_tuple=False) for i in range(on.shape[0]): if on[i][1] == 0: labels[on[i][0], on[i][1]] = on[i][3] From 3c5dfbc80f3a991252023603448fa62cc964e505 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 19 Jan 2021 16:23:42 +0000 Subject: [PATCH 05/14] [wip] add expectation semiring --- torch_struct/distributions.py | 11 +++++ torch_struct/semirings/semirings.py | 63 +++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 681c6242..ebde94f1 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -70,6 +70,17 @@ def log_prob(self, value): return v - self.partition + @lazy_property + def expectation(self, values): + """ + Compute expectation for distribution :math:`E_z[f(z)]` where f that decomposes additively over the parts. 
+ + Returns: + entropy (*batch_shape*) + """ + + return self._struct(ExpectationSemiring).sum([self.log_potentials, values], self.lengths) + @lazy_property def entropy(self): """ diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index 31a4334f..f7333665 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -482,6 +482,69 @@ def one_(xs): return xs +class ExpectationSemiring(Semiring): + """ + Implements an value expectation semiring where the value function decomposes additively over parts + + Based on descriptions in: + + * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter` + * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first` + """ + + zero = 0 + + @staticmethod + def size(): + return 2 + + @staticmethod + def convert(xs): + values = torch.zeros((2,) + xs.shape).type_as(xs) + values[0] = xs + values[1] = 0 + return values + + @staticmethod + def unconvert(xs): + return xs[1] + + @staticmethod + def sum(xs, dim=-1): + assert dim != 0 + d = dim - 1 if dim > 0 else dim + part = torch.logsumexp(xs[0], dim=d) + log_sm = xs[0] - part.unsqueeze(d) + sm = log_sm.exp() + return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d))) + + @staticmethod + def mul(a, b): + return torch.stack((a[0] + b[0], a[1] + b[1])) + + @classmethod + def prod(cls, xs, dim=-1): + return xs.sum(dim) + + @classmethod + def zero_mask_(cls, xs, mask): + "Fill *ssize x ...* tensor with additive identity." + xs[0].masked_fill_(mask, -1e5) + xs[1].masked_fill_(mask, 0) + + @staticmethod + def zero_(xs): + xs[0].fill_(-1e5) + xs[1].fill_(0) + return xs + + @staticmethod + def one_(xs): + xs[0].fill_(0) + xs[1].fill_(0) + return xs + + def TempMax(alpha): class _TempMax(_BaseLog): """ From 73c8d7b6391f29f99a591049db8a7e1e8a355457 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 00:06:54 +0000 Subject: [PATCH 06/14] add expected value semiring and test --- tests/test_distributions.py | 12 ++ torch_struct/distributions.py | 237 ++++++++++++++++----------- torch_struct/helpers.py | 55 +++---- torch_struct/semirings/semirings.py | 244 +++++++++++++++------------- 4 files changed, 308 insertions(+), 240 deletions(-) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 0e5f8ad3..8cf847c6 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -76,6 +76,18 @@ def test_simple(data, seed): dist.kmax(5) dist.count + val_func = torch.rand(*vals.shape, 10) + E_val = dist.expected_value(val_func) + struct_vals = ( + edges.unsqueeze(-1) + .mul(val_func.unsqueeze(0)) + .reshape(*edges.shape[:2], -1, val_func.shape[-1]) + .sum(2) + ) + assert torch.isclose( + E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0) + ).all(), "Efficient expected value not equal to enumeration" + @given(data(), integers(min_value=1, max_value=20)) @settings(max_examples=50, deadline=None) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index ebde94f1..cd904594 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -18,6 +18,7 @@ KMaxSemiring, StdSemiring, GumbelCRFSemiring, + ValueExpectationSemiring, ) @@ -70,17 +71,6 @@ def log_prob(self, value): return v - self.partition - @lazy_property - def expectation(self, values): - """ - Compute expectation for distribution :math:`E_z[f(z)]` where f that decomposes additively over the 
parts. - - Returns: - entropy (*batch_shape*) - """ - - return self._struct(ExpectationSemiring).sum([self.log_potentials, values], self.lengths) - @lazy_property def entropy(self): """ @@ -103,7 +93,9 @@ def cross_entropy(self, other): cross entropy (*batch_shape*) """ - return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths) + return self._struct(CrossEntropySemiring).sum( + [self.log_potentials, other.log_potentials], self.lengths + ) def kl(self, other): """ @@ -115,7 +107,9 @@ def kl(self, other): Returns: cross entropy (*batch_shape*) """ - return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths) + return self._struct(KLDivergenceSemiring).sum( + [self.log_potentials, other.log_potentials], self.lengths + ) @lazy_property def max(self): @@ -148,7 +142,9 @@ def kmax(self, k): kmax (*k x batch_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True) + return self._struct(KMaxSemiring(k)).sum( + self.log_potentials, self.lengths, _raw=True + ) def topk(self, k): r""" @@ -161,7 +157,9 @@ def topk(self, k): kmax (*k x batch_shape x event_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True) + return self._struct(KMaxSemiring(k)).marginals( + self.log_potentials, self.lengths, _raw=True + ) @lazy_property def mode(self): @@ -183,14 +181,56 @@ def marginals(self): @lazy_property def count(self): - "Compute the log-partition function." + "Compute the total number of parts in structure with non-zero probability." ones = torch.ones_like(self.log_potentials) ones[self.log_potentials.eq(-float("inf"))] = 0 return self._struct(StdSemiring).sum(ones, self.lengths) + def expected_value(self, values): + """ + Compute expectated value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z. + + Params: + * values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part + of the structure. `values` can have 0 or more training dimensions in addition to the `event_shape`, + which allows for computing the expected value of, say, a vector valued function + (or a vector of scalar functions). 
+ Returns: + expected value (*batch_shape, *value_shape) + + + """ + # Handle value function dimensionality + phi_shape = self.log_potentials.shape + extra_dims = len(values.shape) - len(phi_shape) + if extra_dims: + # Extra dims get flattened and put in front + out_val_shape = values.shape[len(phi_shape) :] + values = values.reshape(*phi_shape, -1) + values = values.permute([-1] + list(range(len(phi_shape)))) + k = values.shape[0] + else: + out_val_shape = None + k = 1 + + # Compute expected value + val = self._struct(ValueExpectationSemiring(k)).sum( + [self.log_potentials, values], self.lengths + ) + + # Reformat dimensions to match input dimensions + val = val.permute(list(range(1, len(val.shape))) + [0]) + if out_val_shape is not None: + val = val.reshape(*val.shape[:-1] + out_val_shape) + else: + val = val.squeeze(-1) + return val + def gumbel_crf(self, temperature=1.0): with torch.enable_grad(): - st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths) + st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals( + self.log_potentials, self.lengths + ) return st_gumbel # @constraints.dependent_property @@ -225,84 +265,93 @@ def sample(self, sample_shape=torch.Size(), batch_size=10): samples = [] for k in range(nsamples): if k % batch_size == 0: - sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths) + sample = self._struct(MultiSampledSemiring).marginals( + self.log_potentials, lengths=self.lengths + ) sample = sample.detach() tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1) samples.append(tmp_sample) return torch.stack(samples) - def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10): - r""" - Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)` - - This uses gumbel perturbations on the potentials followed by the >zero-temp marginals to get approximate samples. - As temp varies from 0 to inf the samples will vary from being exact onehots from an approximate distribution to - a deterministic distribution that is always uniform over all values. - - The approximation empirically causes a "heavy-hitting" bias where a few configurations are more likely than normal - at the expense of many others, making the tail effectively longer. There is evidence however that temps closer - to 1 reduce this somewhat by smoothing the distribution. 
- - Parameters: - sample_shape (int): number of samples - temp (float): (default=1.0) relaxation temperature - sample_batch_size (int): size of batches to calculates samples - - Returns: - samples (*sample_shape x batch_shape x event_shape*) - - """ - if type(sample_shape) == int: - nsamples = sample_shape - else: - assert len(sample_shape) == 1 - nsamples = sample_shape[0] - if sample_batch_size > nsamples: - sample_batch_size = nsamples - samples = [] - - if noise_shape is None: - noise_shape = self.log_potentials.shape[1:] - - # print(noise) - assert len(noise_shape) == len(self.log_potentials.shape[1:]) - assert all( - s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:]) - ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}" - - for k in range(nsamples): - if k % sample_batch_size == 0: - shape = self.log_potentials.shape - B = shape[0] - s_log_potentials = ( - self.log_potentials.reshape(1, *shape) - .repeat(sample_batch_size, *tuple(1 for _ in shape)) - .reshape(-1, *shape[1:]) - ) - - s_lengths = self.lengths - if s_lengths is not None: - s_shape = s_lengths.shape - s_lengths = ( - s_lengths.reshape(1, *s_shape) - .repeat(sample_batch_size, *tuple(1 for _ in s_shape)) - .reshape(-1, *s_shape[1:]) - ) - - noise = ( - torch.distributions.Gumbel(0, 1) - .sample((sample_batch_size * B, *noise_shape)) - .expand_as(s_log_potentials) - ).to(s_log_potentials.device) - noisy_potentials = (s_log_potentials + noise) / temp - - r_sample = ( - self._struct(LogSemiring) - .marginals(noisy_potentials, s_lengths) - .reshape(sample_batch_size, B, *shape[1:]) - ) - samples.append(r_sample) - return torch.cat(samples, dim=0)[:nsamples] + # def rsample( + # self, + # sample_shape=torch.Size(), + # temp=1.0, + # noise_shape=None, + # sample_batch_size=10, + # ): + # r""" + # Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)` + + # This uses gumbel perturbations on the potentials followed by the >zero-temp marginals to get approximate samples. + # As temp varies from 0 to inf the samples will vary from being exact onehots from an approximate distribution to + # a deterministic distribution that is always uniform over all values. + + # The approximation empirically causes a "heavy-hitting" bias where a few configurations are more likely than normal + # at the expense of many others, making the tail effectively longer. There is evidence however that temps closer + # to 1 reduce this somewhat by smoothing the distribution. 
+ + # Parameters: + # sample_shape (int): number of samples + # temp (float): (default=1.0) relaxation temperature + # sample_batch_size (int): size of batches to calculates samples + + # Returns: + # samples (*sample_shape x batch_shape x event_shape*) + + # """ + # if type(sample_shape) == int: + # nsamples = sample_shape + # else: + # assert len(sample_shape) == 1 + # nsamples = sample_shape[0] + # if sample_batch_size > nsamples: + # sample_batch_size = nsamples + # samples = [] + + # if noise_shape is None: + # noise_shape = self.log_potentials.shape[1:] + + # # print(noise) + # assert len(noise_shape) == len(self.log_potentials.shape[1:]) + # assert all( + # s1 == 1 or s1 == s2 + # for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:]) + # ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}" + + # for k in range(nsamples): + # if k % sample_batch_size == 0: + # shape = self.log_potentials.shape + # B = shape[0] + # s_log_potentials = ( + # self.log_potentials.reshape(1, *shape) + # .repeat(sample_batch_size, *tuple(1 for _ in shape)) + # .reshape(-1, *shape[1:]) + # ) + + # s_lengths = self.lengths + # if s_lengths is not None: + # s_shape = s_lengths.shape + # s_lengths = ( + # s_lengths.reshape(1, *s_shape) + # .repeat(sample_batch_size, *tuple(1 for _ in s_shape)) + # .reshape(-1, *s_shape[1:]) + # ) + + # noise = ( + # torch.distributions.Gumbel(0, 1) + # .sample((sample_batch_size * B, *noise_shape)) + # .expand_as(s_log_potentials) + # ).to(s_log_potentials.device) + # noisy_potentials = (s_log_potentials + noise) / temp + + # r_sample = ( + # self._struct(LogSemiring) + # .marginals(noisy_potentials, s_lengths) + # .reshape(sample_batch_size, B, *shape[1:]) + # ) + # samples.append(r_sample) + # return torch.cat(samples, dim=0)[:nsamples] def to_event(self, sequence, extra, lengths=None): "Convert simple representation to event." @@ -378,7 +427,9 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None): super().__init__(log_potentials, lengths) def _struct(self, sr=None): - return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap) + return self.struct( + sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap + ) class HMM(StructDistribution): @@ -541,7 +592,9 @@ def __init__(self, log_potentials, lengths=None): event_shape = log_potentials[0].shape[1:] self.log_potentials = log_potentials self.lengths = lengths - super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape) + super(StructDistribution, self).__init__( + batch_shape=batch_shape, event_shape=event_shape + ) class NonProjectiveDependencyCRF(StructDistribution): diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 7260a518..2d0543e5 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -29,14 +29,13 @@ class _Struct: def __init__(self, semiring=LogSemiring): self.semiring = semiring - def _dp(self, scores, lengths=None, force_grad=False, cache=True): - """Implement computation equivalent to the computing partition constant Z (if self.semiring == `_BaseSemiring`). + def logpartition(self, scores, lengths=None, force_grad=False): + """Implement computation equivalent to the computing log partition constant logZ (if self.semiring == `_BaseSemiring`). Params: scores: torch.FloatTensor, log potential scores for each factor of the model. 
Shape (* x batch size x *event_shape ) lengths: torch.LongTensor = None, lengths of batch padded examples. Shape = ( * x batch size ) force_grad: bool = False - cache: bool = True Returns: v: torch.Tensor, the resulting output of the dynammic program @@ -114,26 +113,27 @@ def marginals(self, logpotentials, lengths=None, _raw=False): marginals: b x (N-1) x C x C table """ - v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True) - if _raw: - all_m = [] - for k in range(v.shape[0]): - obj = v[k].sum(dim=0) - - marg = torch.autograd.grad( - obj, - edges, - create_graph=True, - only_inputs=True, - allow_unused=False, - ) - all_m.append(self.semiring.unconvert(self._arrange_marginals(marg))) - return torch.stack(all_m, dim=0) - else: - obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) - a_m = self._arrange_marginals(marg) - return self.semiring.unconvert(a_m) + with torch.autograd.enable_grad(): # in case input potentials don't have grads enabled. + v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True) + if _raw: + all_m = [] + for k in range(v.shape[0]): + obj = v[k].sum(dim=0) + + marg = torch.autograd.grad( + obj, + edges, + create_graph=True, + only_inputs=True, + allow_unused=False, + ) + all_m.append(self.semiring.unconvert(self._arrange_marginals(marg))) + return torch.stack(all_m, dim=0) + else: + obj = self.semiring.unconvert(v).sum(dim=0) + marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) + a_m = self._arrange_marginals(marg) + return self.semiring.unconvert(a_m) @staticmethod def to_parts(spans, extra, lengths=None): @@ -145,12 +145,3 @@ def from_parts(spans): def _arrange_marginals(self, marg): return marg[0] - - # For Testing - def _rand(self, *args, **kwargs): - """TODO:""" - raise NotImplementedError - - def enumerate(self, edge, lengths=None): - """TODO:""" - raise NotImplementedError diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index f7333665..f8af7ae7 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -20,6 +20,9 @@ def matmul(cls, a, b): return c +INF = 1e5 # numerically stable large value + + class Semiring: """ Base semiring class. @@ -28,6 +31,10 @@ class Semiring: * Semiring parsing :cite:`goodman1999semiring` + Attributes: + * zero: the additive identity, subclasses should override + * one: the multiplicative identity, subclasses should override + """ @classmethod @@ -47,6 +54,11 @@ def dot(cls, a, b): b = b.unsqueeze(-1) return cls.matmul(a, b).squeeze(-1).squeeze(-1) + @classmethod + def mul(cls, a, b): + "Multiply a and b under the semirings" + raise NotImplementedError() + @classmethod def times(cls, *ls): "Multiply a list of tensors together" @@ -65,20 +77,20 @@ def unconvert(cls, potentials): "Unconvert from semiring by removing extra first dimension." return potentials.squeeze(0) - @staticmethod - def zero_(xs): + @classmethod + def zero_(cls, xs): "Fill *ssize x ...* tensor with additive identity." - raise NotImplementedError() + return xs.fill_(cls.zero) @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." xs.masked_fill_(mask.unsqueeze(0), cls.zero) - @staticmethod - def one_(xs): + @classmethod + def one_(cls, xs): "Fill *ssize x ...* tensor with multiplicative identity." 
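With this refactor a subclass only declares its `zero`/`one` constants plus `sum` and `mul`; the generic fills come from the base class. A sketch of a max-plus (Viterbi-style) semiring, analogous to the library's existing MaxSemiring, assuming the `Semiring` base and `INF` defined above are in scope:

    import torch

    class _MaxPlus(Semiring):
        zero = -INF          # additive identity for max
        one = 0              # multiplicative identity for +

        @staticmethod
        def sum(xs, dim=-1):
            return torch.max(xs, dim=dim)[0]

        @staticmethod
        def mul(a, b):
            return a + b

        @staticmethod
        def prod(xs, dim=-1):
            return xs.sum(dim)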
- raise NotImplementedError() + return xs.fill_(cls.one) @staticmethod def sum(xs, dim=-1): @@ -91,8 +103,8 @@ def plus(cls, a, b): class _Base(Semiring): - zero = 0.0 - one = 1.0 + zero = 0 + one = 1 @staticmethod def mul(a, b): @@ -102,18 +114,10 @@ def mul(a, b): def prod(a, dim=-1): return torch.prod(a, dim=dim) - @staticmethod - def zero_(xs): - return xs.fill_(0) - - @staticmethod - def one_(xs): - return xs.fill_(1) - class _BaseLog(Semiring): - zero = -1e9 - one = 1.0 + zero = -INF + one = 0 @staticmethod def sum(xs, dim=-1): @@ -123,14 +127,6 @@ def sum(xs, dim=-1): def mul(a, b): return a + b - @staticmethod - def zero_(xs): - return xs.fill_(-1e5) - - @staticmethod - def one_(xs): - return xs.fill_(0.0) - @staticmethod def prod(a, dim=-1): return torch.sum(a, dim=dim) @@ -279,7 +275,8 @@ class KLDivergenceSemiring(Semiring): """ - zero = 0 + zero = (-INF, -INF, 0) + one = (0, 0, 0) @staticmethod def size(): @@ -310,7 +307,9 @@ def sum(xs, dim=-1): ( part_p, part_q, - torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d), + torch.sum( + xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d + ), ) ) @@ -325,22 +324,22 @@ def prod(cls, xs, dim=-1): @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, -1e5) - xs[2].masked_fill_(mask, 0) + xs[0].masked_fill_(mask, cls.zero[0]) + xs[1].masked_fill_(mask, cls.zero[1]) + xs[2].masked_fill_(mask, cls.zero[2]) - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(-1e5) - xs[2].fill_(0) + @classmethod + def zero_(cls, xs): + xs[0].fill_(cls.zero[0]) + xs[1].fill_(cls.zero[1]) + xs[2].fill_(cls.zero[2]) return xs - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - xs[2].fill_(0) + @classmethod + def one_(cls, xs): + xs[0].fill_(cls.one[0]) + xs[1].fill_(cls.one[1]) + xs[2].fill_(cls.one[2]) return xs @@ -357,7 +356,8 @@ class CrossEntropySemiring(Semiring): * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf` """ - zero = 0 + zero = (-INF, -INF, 0) + one = (0, 0, 0) @staticmethod def size(): @@ -384,7 +384,9 @@ def sum(xs, dim=-1): log_sm_p = xs[0] - part_p.unsqueeze(d) log_sm_q = xs[1] - part_q.unsqueeze(d) sm_p = log_sm_p.exp() - return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d))) + return torch.stack( + (part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)) + ) @staticmethod def mul(a, b): @@ -397,22 +399,22 @@ def prod(cls, xs, dim=-1): @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." 
- xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, -1e5) - xs[2].masked_fill_(mask, 0) + xs[0].masked_fill_(mask, cls.zero[0]) + xs[1].masked_fill_(mask, cls.zero[1]) + xs[2].masked_fill_(mask, cls.zero[2]) - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(-1e5) - xs[2].fill_(0) + @classmethod + def zero_(cls, xs): + xs[0].fill_(cls.zero[0]) + xs[1].fill_(cls.zero[1]) + xs[2].fill_(cls.zero[2]) return xs - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - xs[2].fill_(0) + @classmethod + def one_(cls, xs): + xs[0].fill_(cls.one[0]) + xs[1].fill_(cls.one[1]) + xs[2].fill_(cls.one[2]) return xs @@ -429,7 +431,8 @@ class EntropySemiring(Semiring): * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf` """ - zero = 0 + zero = (-INF, 0) + one = (0, 0) @staticmethod def size(): @@ -466,83 +469,92 @@ def prod(cls, xs, dim=-1): @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, 0) + xs[0].masked_fill_(mask, cls.zero[0]) + xs[1].masked_fill_(mask, cls.zero[1]) - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(0) + @classmethod + def zero_(cls, xs): + xs[0].fill_(cls.zero[0]) + xs[1].fill_(cls.zero[1]) return xs - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) + @classmethod + def one_(cls, xs): + xs[0].fill_(cls.one[0]) + xs[1].fill_(cls.one[1]) return xs -class ExpectationSemiring(Semiring): - """ - Implements an value expectation semiring where the value function decomposes additively over parts +def ValueExpectationSemiring(k): + class ValueExpectationSemiring(Semiring): + """ + Implements an value expectation semiring where the value function decomposes additively over parts. 
- Based on descriptions in: + Based on descriptions in: - * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter` - * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first` - """ + * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter` + * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first` - zero = 0 + """ - @staticmethod - def size(): - return 2 + zero = (-INF,) + (0,) * k + one = (0,) * (k + 1) - @staticmethod - def convert(xs): - values = torch.zeros((2,) + xs.shape).type_as(xs) - values[0] = xs - values[1] = 0 - return values + @staticmethod + def size(): + return k + 1 - @staticmethod - def unconvert(xs): - return xs[1] + @staticmethod + def convert(xs): + phis, vals = xs[0], xs[1] + phis = phis + values = torch.zeros((k + 1,) + phis.shape).type_as(vals) + values[0] = phis + for i in range(k): + values[i + 1 :] = vals[i] + return values - @staticmethod - def sum(xs, dim=-1): - assert dim != 0 - d = dim - 1 if dim > 0 else dim - part = torch.logsumexp(xs[0], dim=d) - log_sm = xs[0] - part.unsqueeze(d) - sm = log_sm.exp() - return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d))) + @staticmethod + def unconvert(xs): + return xs[1:] - @staticmethod - def mul(a, b): - return torch.stack((a[0] + b[0], a[1] + b[1])) + @staticmethod + def sum(xs, dim=-1): + assert dim != 0 + d = dim - 1 if dim > 0 else dim + part = torch.logsumexp(xs[0], dim=d) + log_sm = xs[0] - part.unsqueeze(d) + sm = log_sm.exp().unsqueeze(0) + val = torch.sum(xs[1:].mul(sm), dim=d) + return torch.cat((part.unsqueeze(0), val), dim=0) - @classmethod - def prod(cls, xs, dim=-1): - return xs.sum(dim) + @staticmethod + def mul(a, b): + return torch.cat(((a[0] + b[0].unsqueeze(0)), a[1:] + b[1:]), dim=0) - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." - xs[0].masked_fill_(mask, -1e5) - xs[1].masked_fill_(mask, 0) + @classmethod + def prod(cls, xs, dim=-1): + return xs.sum(dim) - @staticmethod - def zero_(xs): - xs[0].fill_(-1e5) - xs[1].fill_(0) - return xs + @classmethod + def zero_mask_(cls, xs, mask): + "Fill *ssize x ...* tensor with additive identity." + xs[0].masked_fill_(mask, cls.zero[0]) + xs[1:].masked_fill_(mask, cls.zero[1]) - @staticmethod - def one_(xs): - xs[0].fill_(0) - xs[1].fill_(0) - return xs + @classmethod + def zero_(cls, xs): + xs[0].fill_(cls.zero[0]) + xs[1:].fill_(cls.zero[1]) + return xs + + @classmethod + def one_(cls, xs): + xs[0].fill_(cls.one[0]) + xs[1].fill_(cls.one[1]) + return xs + + return ValueExpectationSemiring def TempMax(alpha): From 6e1704a1b1321a8e3a3b9ece8ca83d792e2c5ef9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 01:20:37 +0000 Subject: [PATCH 07/14] add full cky crfclear --- torch_struct/distributions.py | 82 -------------------- torch_struct/full_cky_crf.py | 142 ++++++++-------------------------- torch_struct/helpers.py | 24 ++++-- 3 files changed, 52 insertions(+), 196 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index cd904594..ba940738 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -197,8 +197,6 @@ def expected_value(self, values): (or a vector of scalar functions). 
Returns: expected value (*batch_shape, *value_shape) - - """ # Handle value function dimensionality phi_shape = self.log_potentials.shape @@ -273,86 +271,6 @@ def sample(self, sample_shape=torch.Size(), batch_size=10): samples.append(tmp_sample) return torch.stack(samples) - # def rsample( - # self, - # sample_shape=torch.Size(), - # temp=1.0, - # noise_shape=None, - # sample_batch_size=10, - # ): - # r""" - # Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)` - - # This uses gumbel perturbations on the potentials followed by the >zero-temp marginals to get approximate samples. - # As temp varies from 0 to inf the samples will vary from being exact onehots from an approximate distribution to - # a deterministic distribution that is always uniform over all values. - - # The approximation empirically causes a "heavy-hitting" bias where a few configurations are more likely than normal - # at the expense of many others, making the tail effectively longer. There is evidence however that temps closer - # to 1 reduce this somewhat by smoothing the distribution. - - # Parameters: - # sample_shape (int): number of samples - # temp (float): (default=1.0) relaxation temperature - # sample_batch_size (int): size of batches to calculates samples - - # Returns: - # samples (*sample_shape x batch_shape x event_shape*) - - # """ - # if type(sample_shape) == int: - # nsamples = sample_shape - # else: - # assert len(sample_shape) == 1 - # nsamples = sample_shape[0] - # if sample_batch_size > nsamples: - # sample_batch_size = nsamples - # samples = [] - - # if noise_shape is None: - # noise_shape = self.log_potentials.shape[1:] - - # # print(noise) - # assert len(noise_shape) == len(self.log_potentials.shape[1:]) - # assert all( - # s1 == 1 or s1 == s2 - # for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:]) - # ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}" - - # for k in range(nsamples): - # if k % sample_batch_size == 0: - # shape = self.log_potentials.shape - # B = shape[0] - # s_log_potentials = ( - # self.log_potentials.reshape(1, *shape) - # .repeat(sample_batch_size, *tuple(1 for _ in shape)) - # .reshape(-1, *shape[1:]) - # ) - - # s_lengths = self.lengths - # if s_lengths is not None: - # s_shape = s_lengths.shape - # s_lengths = ( - # s_lengths.reshape(1, *s_shape) - # .repeat(sample_batch_size, *tuple(1 for _ in s_shape)) - # .reshape(-1, *s_shape[1:]) - # ) - - # noise = ( - # torch.distributions.Gumbel(0, 1) - # .sample((sample_batch_size * B, *noise_shape)) - # .expand_as(s_log_potentials) - # ).to(s_log_potentials.device) - # noisy_potentials = (s_log_potentials + noise) / temp - - # r_sample = ( - # self._struct(LogSemiring) - # .marginals(noisy_potentials, s_lengths) - # .reshape(sample_batch_size, B, *shape[1:]) - # ) - # samples.append(r_sample) - # return torch.cat(samples, dim=0)[:nsamples] - def to_event(self, sequence, extra, lengths=None): "Convert simple representation to event." 
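Concretely, the expectation semiring carries a (log-weight, value) pair per part: multiplying parts adds both components, and summing over alternatives log-sum-exps the weights while taking a probability-weighted sum of the values. A small usage sketch mirroring the shapes used in the new test (hypothetical numbers):

    import torch
    from torch_struct import LinearChainCRF

    phi = torch.randn(3, 6, 4, 4)            # batch=3, N-1=6 positions, C=4 labels
    dist = LinearChainCRF(phi)
    values = torch.rand(3, 6, 4, 4, 2)       # a 2-dimensional value vector attached to every part
    e_val = dist.expected_value(values)      # (3, 2): E_z[ sum of value vectors over the parts of z ]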
return self.struct.to_parts(sequence, extra, lengths=None) diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py index 1bff1c14..e77e40d0 100644 --- a/torch_struct/full_cky_crf.py +++ b/torch_struct/full_cky_crf.py @@ -18,12 +18,8 @@ def _check_potentials(self, edge, lengths=None): return edge, semiring_shape, batch, N, NT, lengths - def _dp(self, scores, lengths=None, force_grad=False, cache=True): - DEBUG = False - if DEBUG: - print("FullCKYCRF DP starting") + def logpartition(self, scores, lengths=None, force_grad=False, cache=True): sr = self.semiring - # torch.autograd.set_detect_anomaly(True) # Scores.shape = *sshape, B, N, N, N, NT, NT, NT # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C] @@ -35,18 +31,6 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True): # Initialize data structs LEFT, RIGHT = 0, 1 L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim - # Will store sum of subtrees up to i,j,A from the left and right - # beta[LEFT][i,d,A] = sum of potentials of all subtrees in span i,j=(i+d) with nonterminal A - # indexed from the left endpoint i plus the width d - # . = alpha[i,j=(i+d),A] in a nonvectorized version - # beta[RIGHT][j,d',A] = sum of potentials of all subtrees in span i=(j-(N-d')),j with NT A - # indexed from the right endpoint, from widest to shortest subtrees. - # This gets filled in from right to left. - - # OVERRIDE CACHE - cache = False - # print("cache", cache) - beta = [Chart((b, N, N, NT), scores, sr, cache=cache) for _ in range(2)] # Initialize the base cases with scores from diagonal i=j=k, A=B=C term_scores = ( @@ -55,26 +39,16 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True): .diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2 .diagonal(0, -3, -1) # diag of C with that gives A=B=C ) - assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}" - beta[LEFT][:, 0, :] = term_scores - beta[RIGHT][:, N - 1, :] = term_scores + assert term_scores.shape[S + 1 :] == ( + N, + NT, + ), f"{term_scores.shape[S + 1 :]} == {(N, NT)}" alpha_left = term_scores alpha_right = term_scores - - ### old: init with semiring's multiplicative identity, gives zeros mass to leaves - # ns = torch.arange(NT) - # beta[LEFT][:, 0, :] = sr.one_(beta[LEFT][:, 0, :]) - # beta[RIGHT][:, N - 1, :] = sr.one_(beta[RIGHT][:, N - 1, :]) - # alpha_left = sr.one_(torch.ones(sshape + [b, N, NT]).to(scores.device)) - # alpha_right = sr.one_(torch.ones(sshape + [b, N, NT]).to(scores.device)) - alphas = [[alpha_left], [alpha_right]] # Run vectorized inside alg - - ws = tqdm(range(1, N), "Calculating marginals at width", N - 1) if DEBUG else range(1, N) - for w in ws: - # print("\nw", w, "N-w", N - w) + for w in range(1, N): # Scores # What we want is a tensor with: # shape: *sshape, batch, (N-w), NT, w, NT, NT @@ -82,107 +56,59 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True): # where (i,j=i+w) means the diagonal of trees nodes with width w # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)] score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores - # print("diagonal", score.shape[S:]) - score = score.permute(sdims + [-6, -1, -4, -5, -3, -2]) # move diag (-1) dim and head NT (-4) dim to front - # print("permute", score.shape[S:]) + score = score.permute( + sdims + [-6, -1, -4, -5, -3, -2] + ) # move diag (-1) dim and head NT (-4) dim to front score = score[..., :w, :, :] # remove illegal splitpoints - # 
print("slice", score.shape[S:]) - assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}" - # print("S", score[0, 0, :, 0, :, 0, 0].exp()) + assert score.shape[S:] == ( + batch, + N - w, + NT, + w, + NT, + NT, + ), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}" + # Sums of left subtrees # Shape: *sshape, batch, (N-w), w, NT # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B) left = slice(None, N - w) # left indices - L1 = beta[LEFT][left, :w] L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :] - assert L.isclose(L1).all() - # print("L", L.shape) - # Sums of right subtrees # Shape: *sshape, batch, (N-w), w, NT # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C) right = slice(w, None) # right indices - R1 = beta[RIGHT][right, N - w :] R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :] - assert R.isclose(R1).all() - # print("R", R.shape) # R[0, 0, :, :, 0].exp()) # Broadcast them both to match missing dims in score # Left B is duplicated for all head and right symbols A C - L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT]) - # Right C is duplicated for all head and left symbols A B - R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1]) + L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]) - assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) - # print(score.shape[S + 1 :], L_bcast.shape, R_bcast.shape) + # Right C is duplicated for all head and left symbols A B + R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]) # Now multiply all the scores and sum over k, B, C dimensions (the last three dims) - assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) + assert sr.times(score, L_bcast, R_bcast).shape == tuple( + list(sshape) + [b, N - w, NT, w, NT, NT] + ) sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) - # print("sum prod w", sum_prod_w.exp()) - assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}" + # sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3],-1)) + assert sum_prod_w.shape[S:] == ( + b, + N - w, + NT, + ), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}" - # new = sr.times(sr.dot(Y, Z), score) - beta[LEFT][left, w] = sum_prod_w - beta[RIGHT][right, N - w - 1] = sum_prod_w - # pad = sr.zero_(torch.ones_like(sum_prod_w))[..., :w, :] pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device)) sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2) sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2) - # print(sum_prod_w.shape, sum_prod_w_left.shape, sum_prod_w_right.shape) alphas[LEFT].append(sum_prod_w_left) alphas[RIGHT].append(sum_prod_w_right) - # for c in range(NT): - # print(f"left c:{c}\n", beta[LEFT][:, :].exp().detach().numpy()) - - # print(f"right c:{c}\n", beta[RIGHT][:, :].exp().detach().numpy()) - - final1 = sr.sum(beta[LEFT][0, :, :]) - final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :] # sum out root symbol - # print(f"f1:{final1.shape}, f:{final.shape}, ls:{lengths}") - assert final.isclose(final1).all(), f"final:\n{final}\nfinal1:\n{final1}" - # log_Z = final[..., 0, lengths - 1] + final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[ + ..., 0, : + ] # sum out root symbol log_Z = final[:, torch.arange(batch), lengths - 1] - # 
log_Z.exp().sum().backward() - # print("Z", log_Z.exp()) - # if DEBUG: - # print("Using autograd to get marginals") - return log_Z, [scores], beta - - # For testing - - def enumerate(self, scores, lengths=None): - raise NotImplementedError - semiring = self.semiring - batch, N, _, _, NT, _, _ = scores.shape - - def enumerate(x, start, end): - if start + 1 == end: - yield (scores[:, start, start, x], [(start, x)]) - else: - for w in range(start + 1, end): - for y in range(NT): - for z in range(NT): - for m1, y1 in enumerate(y, start, w): - for m2, z1 in enumerate(z, w, end): - yield ( - semiring.times(m1, m2, scores[:, start, end - 1, x]), - [(x, start, w, end)] + y1 + z1, - ) - - ls = [] - for nt in range(NT): - ls += [s for s, _ in enumerate(nt, 0, N)] - - return semiring.sum(torch.stack(ls, dim=-1)), None - - @staticmethod - def _rand(): - batch = torch.randint(2, 5, (1,)) - N = torch.randint(2, 5, (1,)) - NT = torch.randint(2, 5, (1,)) - scores = torch.rand(batch, N, N, N, NT, NT, NT) - return scores, (batch.item(), N.item()) + return log_Z, [scores], alphas \ No newline at end of file diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 2d0543e5..3daab4c6 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -6,7 +6,11 @@ class Chart: def __init__(self, size, potentials, semiring): self.data = semiring.zero_( - torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) + torch.zeros( + *((semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device + ) ) self.grad = self.data.detach().clone().fill_(0.0) @@ -22,7 +26,7 @@ def __setitem__(self, ind, new): class _Struct: """`_Struct` is base class used to represent the graphical structure of a model. - Subclasses should implement a `_dp` method which computes the partition function (under the standard `_BaseSemiring`). + Subclasses should implement a `logpartition` method which computes the partition function (under the standard `_BaseSemiring`). Different `StructDistribution` methods will instantiate the `_Struct` subclasses """ @@ -54,7 +58,7 @@ def score(self, potentials, parts, batch_dims=[0]): """Score for entire structure is product of potentials for all activated "parts".""" score = torch.mul(potentials, parts) # mask potentials by activated "parts" batch = tuple((score.shape[b] for b in batch_dims)) - return self.semiring.prod(score.view(batch + (-1,))) # product of all potentialsa + return self.semiring.prod(score.view(batch + (-1,))) def _bin_length(self, length): """Find least upper bound for lengths that is a power of 2. Used in parallel scans.""" @@ -78,7 +82,11 @@ def _make_chart(self, N, size, potentials, force_grad=False): return [ ( self.semiring.zero_( - torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) + torch.zeros( + *((self.semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device + ) ).requires_grad_(force_grad and not potentials.requires_grad) ) for _ in range(N) @@ -114,7 +122,9 @@ def marginals(self, logpotentials, lengths=None, _raw=False): """ with torch.autograd.enable_grad(): # in case input potentials don't have grads enabled. 
- v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True) + v, edges = self.logpartition( + logpotentials, lengths=lengths, force_grad=True + ) if _raw: all_m = [] for k in range(v.shape[0]): @@ -131,7 +141,9 @@ def marginals(self, logpotentials, lengths=None, _raw=False): return torch.stack(all_m, dim=0) else: obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) + marg = torch.autograd.grad( + obj, edges, create_graph=True, only_inputs=True, allow_unused=False + ) a_m = self._arrange_marginals(marg) return self.semiring.unconvert(a_m) From 2d0abe88ebce993378cdfddcb067529e2557190b Mon Sep 17 00:00:00 2001 From: Tom Effland Date: Wed, 20 Jan 2021 13:08:07 -0500 Subject: [PATCH 08/14] Update full_cky_crf.py --- torch_struct/full_cky_crf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py index e77e40d0..5457e620 100644 --- a/torch_struct/full_cky_crf.py +++ b/torch_struct/full_cky_crf.py @@ -93,8 +93,8 @@ def logpartition(self, scores, lengths=None, force_grad=False, cache=True): assert sr.times(score, L_bcast, R_bcast).shape == tuple( list(sshape) + [b, N - w, NT, w, NT, NT] ) - sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) - # sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3],-1)) +# sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) + sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3],-1)) assert sum_prod_w.shape[S:] == ( b, N - w, @@ -111,4 +111,4 @@ def logpartition(self, scores, lengths=None, force_grad=False, cache=True): ..., 0, : ] # sum out root symbol log_Z = final[:, torch.arange(batch), lengths - 1] - return log_Z, [scores], alphas \ No newline at end of file + return log_Z, [scores] From 297209a7b28ab37eb2dd7954f601906903f0906e Mon Sep 17 00:00:00 2001 From: Sasha Rush Date: Wed, 20 Jan 2021 13:51:29 -0500 Subject: [PATCH 09/14] Update helpers.py --- torch_struct/helpers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 3daab4c6..b4b93c19 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -36,14 +36,14 @@ def __init__(self, semiring=LogSemiring): def logpartition(self, scores, lengths=None, force_grad=False): """Implement computation equivalent to the computing log partition constant logZ (if self.semiring == `_BaseSemiring`). - Params: - scores: torch.FloatTensor, log potential scores for each factor of the model. Shape (* x batch size x *event_shape ) - lengths: torch.LongTensor = None, lengths of batch padded examples. Shape = ( * x batch size ) + Parameters: + scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape ) + lengths (torch.LongTensor) : = None, lengths of batch padded examples. Shape = ( * x batch size ) force_grad: bool = False Returns: - v: torch.Tensor, the resulting output of the dynammic program - edges: List[torch.Tensor], the log edge potentials of the model. + v (torch.Tensor) : the resulting output of the dynammic program + edges (List[torch.Tensor]): the log edge potentials of the model. When `scores` is already in a log_potential format for the distribution (typical), this will be [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`. 
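The contract documented here, that logpartition returns the log-partition value together with the log potentials it was computed from, is what lets the generic marginals method recover marginals by backpropagation. A toy check of that identity in plain PyTorch, using the single-factor case where the partition function is a logsumexp and its gradient is a softmax:

    import torch

    # Single-factor "structure": logZ is a logsumexp over the potentials, and its
    # gradient with respect to the log potentials is exactly the marginal
    # distribution -- the same pattern _Struct.marginals exploits via autograd.
    log_potentials = torch.randn(6, requires_grad=True)
    log_Z = torch.logsumexp(log_potentials, dim=-1)
    (marginals,) = torch.autograd.grad(log_Z, log_potentials, create_graph=True)
    assert torch.allclose(marginals, torch.softmax(log_potentials, dim=-1), atol=1e-6)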
An exceptional case is the `CKY` struct, which takes log potential parameters from production rules From 657fbc653e26a4eaecc41a1e83f18cfc4197adb3 Mon Sep 17 00:00:00 2001 From: Sasha Rush Date: Wed, 20 Jan 2021 13:56:29 -0500 Subject: [PATCH 10/14] Update distributions.py --- torch_struct/distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index ba940738..fdcb96e4 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -181,7 +181,7 @@ def marginals(self): @lazy_property def count(self): - "Compute the total number of parts in structure with non-zero probability." + "Compute the total number of structures in the CRF support set." ones = torch.ones_like(self.log_potentials) ones[self.log_potentials.eq(-float("inf"))] = 0 return self._struct(StdSemiring).sum(ones, self.lengths) From 6edceb0a5415f55c49da80497159dce19469643b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 20:12:32 +0000 Subject: [PATCH 11/14] address review suggestions --- torch_struct/distributions.py | 14 +++++++------- torch_struct/full_cky_crf.py | 9 +++++---- torch_struct/helpers.py | 2 +- torch_struct/semirings/sample.py | 2 ++ 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index fdcb96e4..d12cac2b 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -190,11 +190,10 @@ def expected_value(self, values): """ Compute expectated value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z. - Params: - * values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part - of the structure. `values` can have 0 or more training dimensions in addition to the `event_shape`, - which allows for computing the expected value of, say, a vector valued function - (or a vector of scalar functions). + Parameters: + * values (:class: torch.FloatTensor): (*batch_shape x *event_shape, *value_shape), assigns a value to each + part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`, + which allows for computing the expected value of, say, a vector valued function. Returns: expected value (*batch_shape, *value_shape) """ @@ -244,7 +243,7 @@ def partition(self): "Compute the log-partition function." return self._struct(LogSemiring).sum(self.log_potentials, self.lengths) - def sample(self, sample_shape=torch.Size(), batch_size=10): + def sample(self, sample_shape=torch.Size()): r""" Compute structured samples from the distribution :math:`z \sim p(z)`. @@ -255,6 +254,7 @@ def sample(self, sample_shape=torch.Size(), batch_size=10): Returns: samples (*sample_shape x batch_shape x event_shape*) """ + batch_size = MultiSampledSemiring.batch_size if type(sample_shape) == int: nsamples = sample_shape else: @@ -474,7 +474,7 @@ class FullTreeCRF(StructDistribution): Implementation uses width-batched, forward-pass only * Parallel Time: :math:`O(N)` parallel merges. 
- * Forward Memory: :math:`O(N^2)` + * Forward Memory: :math:`O(N^3)` Compact representation: *N x N x N xNT x NT x NT* long tensor (Same) """ diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py index 5457e620..d42a5105 100644 --- a/torch_struct/full_cky_crf.py +++ b/torch_struct/full_cky_crf.py @@ -1,6 +1,5 @@ import torch -from .helpers import _Struct, Chart -from tqdm import tqdm +from .helpers import _Struct A, B = 0, 1 @@ -93,8 +92,10 @@ def logpartition(self, scores, lengths=None, force_grad=False, cache=True): assert sr.times(score, L_bcast, R_bcast).shape == tuple( list(sshape) + [b, N - w, NT, w, NT, NT] ) -# sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) - sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3],-1)) + # sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) + sum_prod_w = sr.sum( + sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1) + ) assert sum_prod_w.shape[S:] == ( b, N - w, diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index b4b93c19..1335e652 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -43,7 +43,7 @@ def logpartition(self, scores, lengths=None, force_grad=False): Returns: v (torch.Tensor) : the resulting output of the dynammic program - edges (List[torch.Tensor]): the log edge potentials of the model. + logpotentials (List[torch.Tensor]): the log edge potentials of the model. When `scores` is already in a log_potential format for the distribution (typical), this will be [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`. An exceptional case is the `CKY` struct, which takes log potential parameters from production rules diff --git a/torch_struct/semirings/sample.py b/torch_struct/semirings/sample.py index 09ec189c..8628da0a 100644 --- a/torch_struct/semirings/sample.py +++ b/torch_struct/semirings/sample.py @@ -215,6 +215,8 @@ class MultiSampledSemiring(_BaseLog): "Gradients" give up to 16 samples with replacement. """ + batch_size = 10 + @staticmethod def sum(xs, dim=-1): return _MultiSampledLogSumExp.apply(xs, dim) From 19982f71ee96bd603ead76272d0f7870fdee638b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 20:29:53 +0000 Subject: [PATCH 12/14] fix doc string errors --- torch_struct/distributions.py | 8 ++++---- torch_struct/helpers.py | 14 ++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index d12cac2b..36b8a663 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -191,9 +191,10 @@ def expected_value(self, values): Compute expectated value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z. Parameters: - * values (:class: torch.FloatTensor): (*batch_shape x *event_shape, *value_shape), assigns a value to each - part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`, - which allows for computing the expected value of, say, a vector valued function. + values (:class: torch.FloatTensor): (*batch_shape x *event_shape, *value_shape), assigns a value to each + part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`, + which allows for computing the expected value of, say, a vector valued function. 
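The width loop in these hunks collapses the three nested semiring sums over the split point k and the child symbols B and C into a single reduction over a flattened view. In the log semiring the two forms agree; a small self-contained check, with arbitrary shapes standing in for the *batch, (N-w), NT, w, NT, NT layout:

    import torch

    # Three successive logsumexp reductions over the last three dimensions give
    # the same result as one logsumexp over those dimensions flattened together.
    x = torch.randn(2, 4, 3, 5, 3, 3)  # stand-in for *batch, (N - w), NT, w, NT, NT
    nested = x.logsumexp(-1).logsumexp(-1).logsumexp(-1)
    flattened = x.reshape(*x.shape[:-3], -1).logsumexp(-1)
    assert torch.allclose(nested, flattened, atol=1e-6)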
+ Returns: expected value (*batch_shape, *value_shape) """ @@ -249,7 +250,6 @@ def sample(self, sample_shape=torch.Size()): Parameters: sample_shape (int): number of samples - batch_size (int): number of samples to compute at a time Returns: samples (*sample_shape x batch_shape x event_shape*) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 1335e652..9121d204 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -37,22 +37,20 @@ def logpartition(self, scores, lengths=None, force_grad=False): """Implement computation equivalent to the computing log partition constant logZ (if self.semiring == `_BaseSemiring`). Parameters: - scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape ) - lengths (torch.LongTensor) : = None, lengths of batch padded examples. Shape = ( * x batch size ) - force_grad: bool = False + scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape ) + lengths (torch.LongTensor) : = None, lengths of batch padded examples. Shape = ( * x batch size ) + force_grad: bool = False Returns: - v (torch.Tensor) : the resulting output of the dynammic program - logpotentials (List[torch.Tensor]): the log edge potentials of the model. + v (torch.Tensor) : the resulting output of the dynammic program + logpotentials (List[torch.Tensor]): the log edge potentials of the model. When `scores` is already in a log_potential format for the distribution (typical), this will be [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`. An exceptional case is the `CKY` struct, which takes log potential parameters from production rules for a PCFG, which are by definition independent of position in the sequence. - charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the - "backward" dynamic program and compute things like marginals w/o autograd. 
""" - raise NotImplementedError + raise NotImplementedError() def score(self, potentials, parts, batch_dims=[0]): """Score for entire structure is product of potentials for all activated "parts".""" From 71004b221af3ccbfb18177101687d5b00a162e18 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 22 Jan 2021 15:46:10 +0000 Subject: [PATCH 13/14] switch value expectation to elementwise mul and reduce --- torch_struct/distributions.py | 70 ++++++------------------- torch_struct/semirings/semirings.py | 80 +---------------------------- 2 files changed, 19 insertions(+), 131 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 36b8a663..d8c0d3d2 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -18,7 +18,6 @@ KMaxSemiring, StdSemiring, GumbelCRFSemiring, - ValueExpectationSemiring, ) @@ -93,9 +92,7 @@ def cross_entropy(self, other): cross entropy (*batch_shape*) """ - return self._struct(CrossEntropySemiring).sum( - [self.log_potentials, other.log_potentials], self.lengths - ) + return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths) def kl(self, other): """ @@ -107,9 +104,7 @@ def kl(self, other): Returns: cross entropy (*batch_shape*) """ - return self._struct(KLDivergenceSemiring).sum( - [self.log_potentials, other.log_potentials], self.lengths - ) + return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths) @lazy_property def max(self): @@ -142,9 +137,7 @@ def kmax(self, k): kmax (*k x batch_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).sum( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True) def topk(self, k): r""" @@ -157,9 +150,7 @@ def topk(self, k): kmax (*k x batch_shape x event_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).marginals( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True) @lazy_property def mode(self): @@ -191,44 +182,23 @@ def expected_value(self, values): Compute expectated value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z. Parameters: - values (:class: torch.FloatTensor): (*batch_shape x *event_shape, *value_shape), assigns a value to each + values (:class: torch.FloatTensor): (*batch_shape x *event_shape x *value_shape), assigns a value to each part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`, which allows for computing the expected value of, say, a vector valued function. 
Returns: expected value (*batch_shape, *value_shape) """ - # Handle value function dimensionality - phi_shape = self.log_potentials.shape - extra_dims = len(values.shape) - len(phi_shape) - if extra_dims: - # Extra dims get flattened and put in front - out_val_shape = values.shape[len(phi_shape) :] - values = values.reshape(*phi_shape, -1) - values = values.permute([-1] + list(range(len(phi_shape)))) - k = values.shape[0] - else: - out_val_shape = None - k = 1 - - # Compute expected value - val = self._struct(ValueExpectationSemiring(k)).sum( - [self.log_potentials, values], self.lengths - ) - - # Reformat dimensions to match input dimensions - val = val.permute(list(range(1, len(val.shape))) + [0]) - if out_val_shape is not None: - val = val.reshape(*val.shape[:-1] + out_val_shape) - else: - val = val.squeeze(-1) - return val + # For these "part-level" expectations, this can be computed by multiplying the marginals element-wise + # on the values and summing. This is faster than the semiring because of FastLogSemiring. + # (w/o genbmm it's about the same.) + ps = self.marginals + ps_bcast = ps.reshape(*ps.shape, *((1,) * (len(values.shape) - len(ps.shape)))) + return ps_bcast.mul(values).reshape(ps.shape[0], -1, *values.shape[len(ps.shape) :]).sum(1) def gumbel_crf(self, temperature=1.0): with torch.enable_grad(): - st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals( - self.log_potentials, self.lengths - ) + st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths) return st_gumbel # @constraints.dependent_property @@ -263,9 +233,7 @@ def sample(self, sample_shape=torch.Size()): samples = [] for k in range(nsamples): if k % batch_size == 0: - sample = self._struct(MultiSampledSemiring).marginals( - self.log_potentials, lengths=self.lengths - ) + sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths) sample = sample.detach() tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1) samples.append(tmp_sample) @@ -345,9 +313,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None): super().__init__(log_potentials, lengths) def _struct(self, sr=None): - return self.struct( - sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap - ) + return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap) class HMM(StructDistribution): @@ -474,9 +440,9 @@ class FullTreeCRF(StructDistribution): Implementation uses width-batched, forward-pass only * Parallel Time: :math:`O(N)` parallel merges. 
- * Forward Memory: :math:`O(N^3)` + * Forward Memory: :math:`O(N^3 NT^3)` - Compact representation: *N x N x N xNT x NT x NT* long tensor (Same) + Compact representation: *N x N x N x NT x NT x NT* long tensor (Same) """ struct = Full_CKY_CRF @@ -510,9 +476,7 @@ def __init__(self, log_potentials, lengths=None): event_shape = log_potentials[0].shape[1:] self.log_potentials = log_potentials self.lengths = lengths - super(StructDistribution, self).__init__( - batch_shape=batch_shape, event_shape=event_shape - ) + super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape) class NonProjectiveDependencyCRF(StructDistribution): diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index f8af7ae7..9d7dd525 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -307,9 +307,7 @@ def sum(xs, dim=-1): ( part_p, part_q, - torch.sum( - xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d - ), + torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d), ) ) @@ -384,9 +382,7 @@ def sum(xs, dim=-1): log_sm_p = xs[0] - part_p.unsqueeze(d) log_sm_q = xs[1] - part_q.unsqueeze(d) sm_p = log_sm_p.exp() - return torch.stack( - (part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)) - ) + return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d))) @staticmethod def mul(a, b): @@ -485,78 +481,6 @@ def one_(cls, xs): return xs -def ValueExpectationSemiring(k): - class ValueExpectationSemiring(Semiring): - """ - Implements an value expectation semiring where the value function decomposes additively over parts. - - Based on descriptions in: - - * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter` - * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first` - - """ - - zero = (-INF,) + (0,) * k - one = (0,) * (k + 1) - - @staticmethod - def size(): - return k + 1 - - @staticmethod - def convert(xs): - phis, vals = xs[0], xs[1] - phis = phis - values = torch.zeros((k + 1,) + phis.shape).type_as(vals) - values[0] = phis - for i in range(k): - values[i + 1 :] = vals[i] - return values - - @staticmethod - def unconvert(xs): - return xs[1:] - - @staticmethod - def sum(xs, dim=-1): - assert dim != 0 - d = dim - 1 if dim > 0 else dim - part = torch.logsumexp(xs[0], dim=d) - log_sm = xs[0] - part.unsqueeze(d) - sm = log_sm.exp().unsqueeze(0) - val = torch.sum(xs[1:].mul(sm), dim=d) - return torch.cat((part.unsqueeze(0), val), dim=0) - - @staticmethod - def mul(a, b): - return torch.cat(((a[0] + b[0].unsqueeze(0)), a[1:] + b[1:]), dim=0) - - @classmethod - def prod(cls, xs, dim=-1): - return xs.sum(dim) - - @classmethod - def zero_mask_(cls, xs, mask): - "Fill *ssize x ...* tensor with additive identity." 
- xs[0].masked_fill_(mask, cls.zero[0]) - xs[1:].masked_fill_(mask, cls.zero[1]) - - @classmethod - def zero_(cls, xs): - xs[0].fill_(cls.zero[0]) - xs[1:].fill_(cls.zero[1]) - return xs - - @classmethod - def one_(cls, xs): - xs[0].fill_(cls.one[0]) - xs[1].fill_(cls.one[1]) - return xs - - return ValueExpectationSemiring - - def TempMax(alpha): class _TempMax(_BaseLog): """ From cded5e1b95c395ad430e72d61baf9f49be18825c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 22 Jan 2021 15:57:52 +0000 Subject: [PATCH 14/14] darglint ignore logpartition docstring mismatch --- torch_struct/helpers.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 9121d204..58116a36 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -6,11 +6,7 @@ class Chart: def __init__(self, size, potentials, semiring): self.data = semiring.zero_( - torch.zeros( - *((semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ) self.grad = self.data.detach().clone().fill_(0.0) @@ -49,6 +45,7 @@ def logpartition(self, scores, lengths=None, force_grad=False): An exceptional case is the `CKY` struct, which takes log potential parameters from production rules for a PCFG, which are by definition independent of position in the sequence. + # noqa: DAR401, DAR202 """ raise NotImplementedError() @@ -80,11 +77,7 @@ def _make_chart(self, N, size, potentials, force_grad=False): return [ ( self.semiring.zero_( - torch.zeros( - *((self.semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ).requires_grad_(force_grad and not potentials.requires_grad) ) for _ in range(N) @@ -120,9 +113,7 @@ def marginals(self, logpotentials, lengths=None, _raw=False): """ with torch.autograd.enable_grad(): # in case input potentials don't have grads enabled. - v, edges = self.logpartition( - logpotentials, lengths=lengths, force_grad=True - ) + v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True) if _raw: all_m = [] for k in range(v.shape[0]): @@ -139,9 +130,7 @@ def marginals(self, logpotentials, lengths=None, _raw=False): return torch.stack(all_m, dim=0) else: obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad( - obj, edges, create_graph=True, only_inputs=True, allow_unused=False - ) + marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) a_m = self._arrange_marginals(marg) return self.semiring.unconvert(a_m)
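Taken together, these patches add a FullTreeCRF distribution whose compact representation is an N x N x N x NT x NT x NT potential tensor with part semantics (i, j, k, A, B, C). A short usage sketch, assuming the branch from this patch series is installed; the random potentials are purely illustrative:

    import torch
    from torch_struct.distributions import FullTreeCRF

    B, N, NT = 2, 5, 3
    log_potentials = torch.randn(B, N, N, N, NT, NT, NT)
    lengths = torch.tensor([5, 4])

    dist = FullTreeCRF(log_potentials, lengths=lengths)
    print(dist.partition.shape)   # (B,)  log-partition per batch element
    print(dist.marginals.shape)   # (B, N, N, N, NT, NT, NT)  marginals over (i, j, k, A, B, C) parts

    # Expected value of a function that decomposes additively over the same parts.
    values = torch.randn(B, N, N, N, NT, NT, NT)
    print(dist.expected_value(values).shape)  # (B,)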