Skip to content

Commit 00b56dd

Browse files
authored
Merge pull request #18 from codefuse-ai/modelcache_localDB_dev
add timm for embedding
2 parents 6de9968 + 0a1c2f1 commit 00b56dd

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

modelcache/embedding/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb")
66
fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext")
77
paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp")
8+
timm = LazyImport("timm", globals(), "modelcache.embedding.timm")
89

910

1011
def Huggingface(model="sentence-transformers/all-mpnet-base-v2"):
@@ -25,3 +26,7 @@ def FastText(model="en", dim=None):
2526

2627
def PaddleNLP(model="ernie-3.0-medium-zh"):
    """Factory for a PaddleNLP-based text embedding.

    :param model: PaddleNLP model name; defaults to "ernie-3.0-medium-zh".
    :return: a lazily-imported ``paddlenlp.PaddleNLP`` embedding instance.
    """
    return paddlenlp.PaddleNLP(model)
29+
30+
31+
def Timm(model="resnet50", device="default"):
    """Factory for a timm-based image embedding.

    :param model: timm model name; defaults to "resnet50".
    :param device: target device, or "default" to let the class auto-select.
    :return: a lazily-imported ``timm.Timm`` embedding instance.
    """
    return timm.Timm(model=model, device=device)

modelcache/embedding/timm.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*- coding: utf-8 -*-
2+
import numpy as np
3+
4+
from modelcache.utils import import_timm, import_torch, import_pillow
5+
from modelcache.embedding.base import BaseEmbedding
6+
7+
import_torch()
8+
import_timm()
9+
import_pillow()
10+
11+
import torch # pylint: disable=C0413
12+
from timm.models import create_model # pylint: disable=C0413
13+
from timm.data import create_transform, resolve_data_config # pylint: disable=C0413
14+
from PIL import Image # pylint: disable=C0413
15+
16+
17+
class Timm(BaseEmbedding):
    """Image embedding backed by a pretrained timm vision model.

    :param model: timm model name, e.g. "resnet18".
    :param device: "cpu", "cuda", etc., or "default" to pick "cuda" when
        available and fall back to "cpu".
    """

    def __init__(self, model: str = "resnet18", device: str = "default"):
        if device == "default":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model_name = model
        self.model = create_model(model_name=model, pretrained=True)
        # Bug fix: the selected device was computed but never applied, so
        # the `device` argument had no effect. Move the model onto it.
        self.model.to(self.device)
        self.model.eval()

        try:
            self.__dimension = self.model.embed_dim
        except Exception:  # pylint: disable=W0703
            # Not every timm model exposes `embed_dim`; the `dimension`
            # property resolves it lazily with a dummy forward pass.
            self.__dimension = None

    def to_embeddings(self, data, skip_preprocess: bool = False, **_):
        """Embed one image and return a float32 numpy vector.

        :param data: image file path, or an already-preprocessed tensor
            when ``skip_preprocess`` is True.
        :param skip_preprocess: skip file loading/transforms and treat
            ``data`` as a ready tensor.
        :return: 1-D ``numpy.ndarray`` of dtype float32.
        """
        if not skip_preprocess:
            data = self.preprocess(data)
        if data.dim() == 3:
            # Add the batch dimension the model expects.
            data = data.unsqueeze(0)
        # Bug fix: run inference without building an autograd graph, and
        # move the input to the configured device (inputs previously
        # stayed on CPU even when the model was meant to run on CUDA).
        with torch.no_grad():
            feats = self.model.forward_features(data.to(self.device))
        emb = self.post_proc(feats).squeeze(0).detach().numpy()

        return np.array(emb).astype("float32")

    def post_proc(self, features):
        """Reduce raw backbone features to a (batch, dim) tensor on CPU.

        Handles both transformer-style (B, N, C) outputs — taking the
        class token — and CNN-style (B, C, H, W) maps via global average
        pooling.
        """
        features = features.to("cpu")
        if features.dim() == 3:
            # (B, N, C): keep the leading (class) token.
            features = features[:, 0]
        if features.dim() == 4:
            # (B, C, H, W): global average pool to (B, C, 1, 1).
            global_pool = torch.nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
            features = features.flatten(1)
        assert features.dim() == 2, f"Invalid output dim {features.dim()}"
        return features

    def preprocess(self, image_path):
        """Load an image file and apply the model's pretrained transforms.

        :param image_path: path to an image readable by PIL.
        :return: preprocessed image tensor (C, H, W) on CPU.
        """
        data_cfg = resolve_data_config(self.model.pretrained_cfg)
        transform = create_transform(**data_cfg)

        image = Image.open(image_path).convert("RGB")
        image_tensor = transform(image)
        return image_tensor

    @property
    def dimension(self):
        """Embedding dimension.

        :return: embedding dimension
        """
        if not self.__dimension:
            # Lazily infer the dimension by embedding one random dummy
            # image of the model's expected input size.
            input_size = self.model.pretrained_cfg["input_size"]
            dummy_input = torch.rand((1,) + input_size)
            feats = self.to_embeddings(dummy_input, skip_preprocess=True)
            self.__dimension = feats.shape[0]
        return self.__dimension

modelcache/utils/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,11 @@ def import_paddle():
6161

6262
def import_paddlenlp():
6363
_check_library("paddlenlp")
64+
65+
66+
def import_timm():
    """Verify the optional `timm` dependency is importable.

    `package` is presumably the pip distribution name (here identical to
    the module name) — mirrors `import_pillow` below; confirm against
    `_check_library`.
    """
    _check_library("timm", package="timm")
68+
69+
70+
def import_pillow():
    """Verify the optional Pillow dependency is importable.

    The module is imported as `PIL` while the pip distribution is named
    `pillow`, hence the separate `package` argument.
    """
    _check_library("PIL", package="pillow")

0 commit comments

Comments
 (0)