From 42e757da38e2af3e50eab84e5bee403cac96f2c6 Mon Sep 17 00:00:00 2001 From: Nikhil Shenoy Date: Sat, 2 Mar 2024 00:17:36 +0000 Subject: [PATCH 1/8] Added vector output from skew-symmetric tensor --- torchmdnet/models/tensornet.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e7..6abdfafc0 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -48,6 +48,10 @@ def vector_to_symtensor(vector): S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I return S +def skewtensor_to_vector(tensor): + '''Converts a skew-symmetric tensor to a vector.''' + return torch.stack((tensor[:, :, 1, 2], tensor[:, :, 2, 0], tensor[:, :, 0, 1]), dim=-1) + def decompose_tensor(tensor): """Full tensor decomposition into irreducible components.""" @@ -265,10 +269,13 @@ def forward( x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) x = self.act(self.linear((x))) + v = skewtensor_to_vector(A) + v = v.transpose(1, 2) # # Remove the extra atom if self.static_shapes: x = x[:-1] - return x, None, z, pos, batch + v = v[:-1] + return x, v, z, pos, batch class TensorEmbedding(nn.Module): From 6925fc05c028c7bf3016a1178b41ed68ba81c5f2 Mon Sep 17 00:00:00 2001 From: Nikhil Shenoy Date: Thu, 7 Mar 2024 01:36:38 +0000 Subject: [PATCH 2/8] Added vector output option and added equivariance test --- tests/test_equivariance.py | 6 ++++-- tests/utils.py | 2 +- torchmdnet/models/model.py | 11 +++++++++++ torchmdnet/models/tensornet.py | 20 ++++++++++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index 1492a9f07..920486254 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -2,6 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) +import pytest import torch from torchmdnet.models.model import create_model from utils import load_example_args @@ -27,7 +28,8 @@ def test_scalar_invariance(): torch.testing.assert_allclose(y, y_rot) -def test_vector_equivariance(): +@pytest.mark.parametrize("model_name", ["equivariant-transformer", "equivariant-tensornet"]) +def test_vector_equivariance(model_name): torch.manual_seed(1234) rotate = torch.tensor( [ @@ -39,7 +41,7 @@ def test_vector_equivariance(): model = create_model( load_example_args( - "equivariant-transformer", + model_name, prior_model=None, output_model="VectorOutput", ) diff --git a/tests/utils.py b/tests/utils.py index ef8bcddb9..0de2337c8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,7 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs): if config_file is None: - if model_name == "tensornet": + if model_name == "tensornet" or model_name == "equivariant-tensornet": config_file = join(dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml") else: config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f901..b8f658088 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -103,6 +103,17 @@ def create_model(args, prior_model=None, mean=None, std=None): static_shapes=args["static_shapes"], **shared_args, ) + elif args["model"] == "equivariant-tensornet": + from torchmdnet.models.tensornet import TensorNet + + # returns an equivariant vector + is_equivariant = True + representation_model = TensorNet( + equivariance_invariance_group=args["equivariance_invariance_group"], + static_shapes=args["static_shapes"], + vector_output=True, + **shared_args, + ) else: raise ValueError(f'Unknown architecture: {args["model"]}') diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 6abdfafc0..1a15abfbf 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -124,6 +124,7 @@ class TensorNet(nn.Module): (default: :obj:`True`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + vector_output (bool, optional): Whether to return """ def __init__( @@ -143,6 +144,7 @@ def __init__( check_errors=True, dtype=torch.float32, box_vecs=None, + vector_output=False ): super(TensorNet, self).__init__() @@ -214,6 +216,7 @@ def __init__( box=box_vecs, long_edge_index=True, ) + self.vector_output = vector_output self.reset_parameters() @@ -269,12 +272,21 @@ def forward( x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) x = self.act(self.linear((x))) - v = skewtensor_to_vector(A) - v = v.transpose(1, 2) - # # Remove the extra atom + # Remove the extra atom if self.static_shapes: x = x[:-1] - v = v[:-1] + + # calculate vector_output if needed + v = None + if self.vector_output: + # (n_atoms, hidden_channels, 3, 3) -> (n_atoms, hidden_channels, 3) + v = skewtensor_to_vector(A) + # (n_atoms, hidden_channels, 3) -> (n_atoms, 3, hidden_channels) + v = v.transpose(1, 2) + + if self.static_shapes: + v = v[:-1] + return x, v, z, pos, batch From 66521a4ca96dc9d413687def71dbfdb109a5755a Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:57:27 +0200 Subject: [PATCH 3/8] Update test_equivariance.py --- tests/test_equivariance.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index 920486254..111c093f3 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -28,7 +28,7 @@ def test_scalar_invariance(): torch.testing.assert_allclose(y, y_rot) -@pytest.mark.parametrize("model_name", ["equivariant-transformer", "equivariant-tensornet"]) +@pytest.mark.parametrize("model_name", ["equivariant-transformer", "tensornet"]) def test_vector_equivariance(model_name): torch.manual_seed(1234) rotate = torch.tensor( @@ -38,14 +38,23 @@ def test_vector_equivariance(model_name): [-0.0626055, 0.3134752, 0.9475304], ] ) - - model = create_model( - load_example_args( - model_name, - prior_model=None, - output_model="VectorOutput", + if model_name == "equivariant_transformer" + model = create_model( + load_example_args( + model_name, + prior_model=None, + output_model="VectorOutput", + ) + ) + if model_name == "tensornet" + model = create_model( + load_example_args( + model_name, + prior_model=None, + vector_output=True, + output_model="VectorOutput", + ) ) - ) z = torch.ones(100, dtype=torch.long) pos = torch.randn(100, 3) batch = torch.arange(50, dtype=torch.long).repeat_interleave(2) From 433d59cda2e1511ba5be4513478aa27018f5435d Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:58:06 +0200 Subject: [PATCH 4/8] Update utils.py --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 0de2337c8..ef8bcddb9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,7 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs): if config_file is None: - if model_name == "tensornet" or model_name == "equivariant-tensornet": + if model_name == "tensornet": config_file = join(dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml") else: config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") From 56acf753dd6f011c0414482b162d7d468e9d15ee Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:59:03 +0200 Subject: [PATCH 5/8] Update model.py --- torchmdnet/models/model.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index b8f658088..e4f8dd653 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -101,17 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None): representation_model = TensorNet( equivariance_invariance_group=args["equivariance_invariance_group"], static_shapes=args["static_shapes"], - **shared_args, - ) - elif args["model"] == "equivariant-tensornet": - from torchmdnet.models.tensornet import TensorNet - - # returns an equivariant vector - is_equivariant = True - representation_model = TensorNet( - equivariance_invariance_group=args["equivariance_invariance_group"], - static_shapes=args["static_shapes"], - vector_output=True, + vector_output=args["vector_output"], **shared_args, ) else: From 060cc09fc68bea5ce449a83f5c71617abe211855 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:59:39 +0200 Subject: [PATCH 6/8] Update tensornet.py --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 1a15abfbf..0d9ad1ce5 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -124,7 +124,7 @@ class TensorNet(nn.Module): (default: :obj:`True`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - vector_output (bool, optional): Whether to return + vector_output (bool, optional): Whether to return vector features per atom """ def __init__( From b26f45bd4c1e23e6dcaab0b848e9a8bc214e9947 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 11:02:16 +0200 Subject: [PATCH 7/8] Update train.py --- torchmdnet/scripts/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index a51cfe45f..be744bf57 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -103,6 +103,7 @@ def get_argparse(): `a[1] = a[2] = b[2] = 0`;`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`;`a[0] >= 2*b[0]`;`a[0] >= 2*c[0]`;`b[1] >= 2*c[1]`; These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the requirement that the cutoff be no larger than half the box width. Example: [[1,0,0],[0,1,0],[0,0,1]]""") + parser.add_argument('--vector-output', type=bool, default=False, help='If true, returns vector features per atom on top of scalars') parser.add_argument('--static_shapes', type=bool, default=False, help='If true, TensorNet will use statically shaped tensors for the network, making it capturable into a CUDA graphs. In some situations static shapes can lead to a speedup, but it increases memory usage.') # other args From d47493a4c0761e68ba0dce92de15d4618d72ba7f Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Fri, 19 Apr 2024 11:05:40 +0200 Subject: [PATCH 8/8] Update test_equivariance.py fix --- tests/test_equivariance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index 111c093f3..7e6b178ea 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -38,7 +38,7 @@ def test_vector_equivariance(model_name): [-0.0626055, 0.3134752, 0.9475304], ] ) - if model_name == "equivariant_transformer" + if model_name == "equivariant_transformer": model = create_model( load_example_args( model_name, @@ -46,7 +46,7 @@ def test_vector_equivariance(model_name): output_model="VectorOutput", ) ) - if model_name == "tensornet" + if model_name == "tensornet": model = create_model( load_example_args( model_name,