Skip to content

Commit 560c862

Browse files
committed
add redis vector store
1 parent 59def83 commit 560c862

File tree

10 files changed

+133
-73
lines changed

10 files changed

+133
-73
lines changed

examples/flask/register.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
register index for redis
4+
"""
5+
import json
6+
import requests
7+
8+
9+
def run():
    """Register the Redis vector index for model CODEGPT-1117 via the local modelcache service."""
    url = 'http://127.0.0.1:5000/modelcache'
    request_type = 'register'  # renamed from `type` to avoid shadowing the builtin
    scope = {"model": "CODEGPT-1117"}
    data = {'type': request_type, 'scope': scope}
    headers = {"Content-Type": "application/json"}
    # BUG FIX: pass the dict directly — `json=` serializes it exactly once.
    # The original `json=json.dumps(data)` double-encoded the body, so the
    # server received a JSON *string* instead of a JSON object.
    res = requests.post(url, headers=headers, json=data)
    res_text = res.text
    print('res_text: {}'.format(res_text))
18+
# Script entry point: fire the one-off registration request when run directly.
if __name__ == '__main__':
    run()

flask4modelcache.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22
import time
3-
from datetime import datetime
43
from flask import Flask, request
54
import logging
65
import configparser
@@ -15,7 +14,6 @@
1514
from modelcache.utils.model_filter import model_blacklist_filter
1615
from modelcache.embedding import Data2VecAudio
1716

18-
1917
# 创建一个Flask实例
2018
app = Flask(__name__)
2119

@@ -36,13 +34,19 @@ def response_hitquery(cache_resp):
3634
data2vec = Data2VecAudio()
3735
mysql_config = configparser.ConfigParser()
3836
mysql_config.read('modelcache/config/mysql_config.ini')
37+
3938
milvus_config = configparser.ConfigParser()
4039
milvus_config.read('modelcache/config/milvus_config.ini')
41-
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
42-
# VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
40+
41+
# redis_config = configparser.ConfigParser()
42+
# redis_config.read('modelcache/config/redis_config.ini')
43+
4344

4445
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
45-
VectorBase("redis", dimension=data2vec.dimension, milvus_config=milvus_config))
46+
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
47+
48+
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
49+
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
4650

4751

4852
cache.init(
@@ -88,9 +92,9 @@ def user_backend():
8892
model = model.replace('.', '_')
8993
query = param_dict.get("query")
9094
chat_info = param_dict.get("chat_info")
91-
if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
95+
if request_type is None or request_type not in ['query', 'insert', 'remove', 'register']:
9296
result = {"errorCode": 102,
93-
"errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
97+
"errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
9498
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
9599
cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
96100
return json.dumps(result)
@@ -173,6 +177,17 @@ def user_backend():
173177
result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
174178
return json.dumps(result)
175179

180+
if request_type == 'register':
181+
# iat_type = param_dict.get("iat_type")
182+
response = adapter.ChatCompletion.create_register(
183+
model=model
184+
)
185+
if response in ['create_success', 'already_exists']:
186+
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
187+
else:
188+
result = {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
189+
return json.dumps(result)
190+
176191

177192
if __name__ == '__main__':
178193
app.run(host='0.0.0.0', port=5000, debug=True)

modelcache/adapter/adapter.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from modelcache.adapter.adapter_query import adapt_query
66
from modelcache.adapter.adapter_insert import adapt_insert
77
from modelcache.adapter.adapter_remove import adapt_remove
8+
from modelcache.adapter.adapter_register import adapt_register
89

910

1011
class ChatCompletion(openai.ChatCompletion):
@@ -44,6 +45,16 @@ def create_remove(cls, *args, **kwargs):
4445
logging.info('adapt_remove_e: {}'.format(e))
4546
return str(e)
4647

48+
@classmethod
49+
def create_register(cls, *args, **kwargs):
50+
try:
51+
return adapt_register(
52+
*args,
53+
**kwargs
54+
)
55+
except Exception as e:
56+
return str(e)
57+
4758

4859
def construct_resp_from_cache(return_message, return_query):
4960
return {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# -*- coding: utf-8 -*-
2+
from modelcache import cache
3+
4+
5+
def adapt_register(*args, **kwargs):
6+
chat_cache = kwargs.pop("cache_obj", cache)
7+
model = kwargs.pop("model", None)
8+
if model is None or len(model) == 0:
9+
return ValueError('')
10+
11+
register_resp = chat_cache.data_manager.create_index(model)
12+
print('register_resp: {}'.format(register_resp))
13+
return register_resp

modelcache/manager/data_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ def delete(self, id_list, **kwargs):
256256
return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count),
257257
'mysql': 'delete_count: '+str(s_delete_count)}
258258

259+
    def create_index(self, model, **kwargs):
        """Create the vector-store index for `model` (delegates to the vector backend).

        **kwargs is accepted for interface symmetry with the other manager
        methods but is currently unused.
        """
        return self.v.create(model)
261+
259262
def truncate(self, model_name):
260263
# model = kwargs.pop("model", None)
261264
# drop milvus data

modelcache/manager/vector_data/manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ def get(name, **kwargs):
6868
local_mode=local_mode,
6969
local_data=local_data
7070
)
71+
elif name == "redis":
72+
from modelcache.manager.vector_data.redis import RedisVectorStore
73+
dimension = kwargs.get("dimension", DIMENSION)
74+
VectorBase.check_dimension(dimension)
75+
76+
redis_config = kwargs.get("redis_config")
77+
host = redis_config.get('redis', 'host')
78+
port = redis_config.get('redis', 'port')
79+
user = redis_config.get('redis', 'user')
80+
password = redis_config.get('redis', 'password')
81+
namespace = kwargs.get("namespace", "")
82+
# collection_name = kwargs.get("collection_name", COLLECTION_NAME)
83+
84+
vector_base = RedisVectorStore(
85+
host=host,
86+
port=port,
87+
username=user,
88+
password=password,
89+
namespace=namespace,
90+
top_k=top_k,
91+
dimension=dimension,
92+
)
7193
elif name == "faiss":
7294
from modelcache.manager.vector_data.faiss import Faiss
7395

modelcache/manager/vector_data/redis.py

Lines changed: 26 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from redis.commands.search.field import TagField, VectorField, NumericField
77
from redis.client import Redis
88

9-
from gptcache.manager.vector_data.base import VectorBase, VectorData
10-
from gptcache.utils import import_redis
11-
from gptcache.utils.log import gptcache_log
12-
from gptcache.utils.collection_util import get_collection_name
13-
from gptcache.utils.collection_util import get_collection_prefix
9+
from modelcache.manager.vector_data.base import VectorBase, VectorData
10+
from modelcache.utils import import_redis
11+
from modelcache.utils.log import modelcache_log
12+
from modelcache.utils.index_util import get_index_name
13+
from modelcache.utils.index_util import get_index_prefix
1414
import_redis()
1515

1616

@@ -21,9 +21,7 @@ def __init__(
2121
port: str = "6379",
2222
username: str = "",
2323
password: str = "",
24-
table_suffix: str = "",
2524
dimension: int = 0,
26-
collection_prefix: str = "gptcache",
2725
top_k: int = 1,
2826
namespace: str = "",
2927
):
@@ -36,33 +34,28 @@ def __init__(
3634
)
3735
self.top_k = top_k
3836
self.dimension = dimension
39-
self.collection_prefix = collection_prefix
40-
self.table_suffix = table_suffix
4137
self.namespace = namespace
42-
self.doc_prefix = f"{self.namespace}doc:" # Prefix with the specified namespace
43-
# self._create_collection(collection_name)
38+
self.doc_prefix = f"{self.namespace}doc:"
4439

4540
def _check_index_exists(self, index_name: str) -> bool:
4641
"""Check if Redis index exists."""
4742
try:
4843
self._client.ft(index_name).info()
49-
except: # pylint: disable=W0702
50-
gptcache_log.info("Index does not exist")
44+
except:
45+
modelcache_log.info("Index does not exist")
5146
return False
52-
gptcache_log.info("Index already exists")
47+
modelcache_log.info("Index already exists")
5348
return True
5449

55-
def create_collection(self, collection_name, index_prefix):
50+
def create_index(self, index_name, index_prefix):
5651
dimension = self.dimension
5752
print('dimension: {}'.format(dimension))
58-
if self._check_index_exists(collection_name):
59-
gptcache_log.info(
60-
"The %s already exists, and it will be used directly", collection_name
53+
if self._check_index_exists(index_name):
54+
modelcache_log.info(
55+
"The %s already exists, and it will be used directly", index_name
6156
)
6257
return 'already_exists'
6358
else:
64-
# id_field_name = collection_name + '_' + "id"
65-
# embedding_field_name = collection_name + '_' + "vec"
6659
id_field_name = "data_id"
6760
embedding_field_name = "data_vector"
6861

@@ -76,11 +69,10 @@ def create_collection(self, collection_name, index_prefix):
7669
}
7770
)
7871
fields = [id, embedding]
79-
# definition = IndexDefinition(index_type=IndexType.HASH)
8072
definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)
8173

8274
# create Index
83-
self._client.ft(collection_name).create_index(
75+
self._client.ft(index_name).create_index(
8476
fields=fields, definition=definition
8577
)
8678
return 'create_success'
@@ -90,23 +82,14 @@ def mul_add(self, datas: List[VectorData], model=None):
9082
for data in datas:
9183
id: int = data.id
9284
embedding = data.data.astype(np.float32).tobytes()
93-
# id_field_name = collection_name + '_' + "id"
94-
# embedding_field_name = collection_name + '_' + "vec"
9585
id_field_name = "data_id"
9686
embedding_field_name = "data_vector"
9787
obj = {id_field_name: id, embedding_field_name: embedding}
98-
index_prefix = get_collection_prefix(model, self.table_suffix)
88+
index_prefix = get_index_prefix(model)
9989
self._client.hset(f"{index_prefix}{id}", mapping=obj)
10090

101-
# obj = {
102-
# "vector": data.data.astype(np.float32).tobytes(),
103-
# }
104-
# pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)
105-
# pipe.execute()
106-
10791
def search(self, data: np.ndarray, top_k: int = -1, model=None):
108-
collection_name = get_collection_name(model, self.table_suffix)
109-
print('collection_name: {}'.format(collection_name))
92+
index_name = get_index_name(model)
11093
id_field_name = "data_id"
11194
embedding_field_name = "data_vector"
11295

@@ -119,63 +102,40 @@ def search(self, data: np.ndarray, top_k: int = -1, model=None):
119102
)
120103

121104
query_params = {"vector": data.astype(np.float32).tobytes()}
122-
# print('query_params: {}'.format(query_params))
123105
results = (
124-
self._client.ft(collection_name)
106+
self._client.ft(index_name)
125107
.search(query, query_params=query_params)
126108
.docs
127109
)
128-
print('results: {}'.format(results))
129-
for i, doc in enumerate(results):
130-
print('doc: {}'.format(doc))
131-
print("id_field_name", getattr(doc, id_field_name), ", distance: ", doc.distance)
132110
return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
133111

134112
    def rebuild(self, ids=None) -> None:
        # Intentional no-op stub kept for interface parity with the other
        # vector backends; nothing to rebuild for Redis here.
        # NOTE(review): annotation corrected from `-> bool` — the body returns None.
        pass

137115
    def rebuild_col(self, model):
        """Drop (if present) and recreate the Redis index for `model`.

        Returns the literal string 'rebuild success' on completion.

        Raises:
            ValueError: wrapping any underlying redis-py error from the
                drop or create step.
        """
        index_name_model = get_index_name(model)
        if self._check_index_exists(index_name_model):
            try:
                # Drop the index together with its indexed documents.
                self._client.ft(index_name_model).dropindex(delete_documents=True)
            except Exception as e:
                raise ValueError(str(e))
        try:
            # Recreate (or create for the first time) under the model's prefix.
            index_prefix = get_index_prefix(model)
            self.create_index(index_name_model, index_prefix)
        except Exception as e:
            raise ValueError(str(e))
        return 'rebuild success'
155128

156-
# print('remove collection_name_model: {}'.format(collection_name_model))
157-
# try:
158-
# self._client.ft(collection_name_model).dropindex(delete_documents=True)
159-
# resp_info = 'rebuild success'
160-
# except Exception as e:
161-
# print('exception: {}'.format(e))
162-
# resp_info = 'create only'
163-
# try:
164-
# self.create_collection(collection_name_model)
165-
# except Exception as e:
166-
# raise ValueError(str(e))
167-
# return resp_info
168-
169129
    def delete(self, ids) -> None:
        """Delete the documents with the given ids from Redis (pipelined).

        NOTE(review): keys here are built with `self.doc_prefix`, but
        mul_add() stores entries under `get_index_prefix(model)` — these two
        keying schemes look inconsistent; confirm which prefix deletions
        should target.
        """
        pipe = self._client.pipeline()
        for data_id in ids:
            pipe.delete(f"{self.doc_prefix}{data_id}")
        pipe.execute()
174134

175135
    def create(self, model=None):
        """Create the Redis search index (and its key prefix) for `model`.

        Returns:
            'create_success' or 'already_exists', as produced by create_index().
        """
        index_name = get_index_name(model)
        index_prefix = get_index_prefix(model)
        return self.create_index(index_name, index_prefix)
179139

180-
def get_collection_by_name(self, collection_name, table_suffix):
140+
    def get_index_by_name(self, index_name):
        # Stub: not implemented for the Redis backend yet.
        pass

modelcache/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ def import_timm():
6969

7070
def import_pillow():
7171
_check_library("PIL", package="pillow")
72+
73+
74+
def import_redis():
    # Verify the `redis` client package is importable, deferring to _check_library.
    _check_library("redis")

modelcache/utils/index_util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
def get_index_name(model):
    """Return the Redis search-index name for `model`, namespaced under 'modelcache'."""
    return '_'.join(['modelcache', model])
6+
7+
8+
def get_index_prefix(model):
    """Return the Redis document-key prefix used for entries of `model`."""
    return '_'.join(['prefix', model])

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ Requests==2.31.0
1010
torch==2.1.0
1111
transformers==4.34.1
1212
faiss-cpu==1.7.4
13+
redis==5.0.1
14+

0 commit comments

Comments
 (0)