diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py index 71bb271c..c9840c9d 100644 --- a/torch_struct/alignment.py +++ b/torch_struct/alignment.py @@ -45,7 +45,7 @@ def _check_potentials(self, edge, lengths=None): assert max(lengths) == N, "One length must be at least N" return edge, batch, N, M, lengths - def logparition(self, log_potentials, lengths=None, force_grad=False, cache=True): + def logpartition(self, log_potentials, lengths=None, force_grad=False, cache=True): return self._dp_scan(log_potentials, lengths, force_grad) def _dp_scan(self, log_potentials, lengths=None, force_grad=False):