Unverified Commit 409e0c8e authored by Jyong's avatar Jyong Committed by GitHub

update qdrant migrate command (#2260)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent 7076d41b
...@@ -339,26 +339,7 @@ def create_qdrant_indexes(): ...@@ -339,26 +339,7 @@ def create_qdrant_indexes():
) )
except Exception: except Exception:
try: continue
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.SYSTEM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
...@@ -405,7 +386,7 @@ def update_qdrant_indexes(): ...@@ -405,7 +386,7 @@ def update_qdrant_indexes():
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound: except NotFound:
break break
model_manager = ModelManager()
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
if dataset.index_struct_dict: if dataset.index_struct_dict:
...@@ -413,23 +394,15 @@ def update_qdrant_indexes(): ...@@ -413,23 +394,15 @@ def update_qdrant_indexes():
try: try:
click.echo('Update dataset qdrant index: {}'.format(dataset.id)) click.echo('Update dataset qdrant index: {}'.format(dataset.id))
try: try:
embedding_model = ModelFactory.get_embedding_model( embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
except Exception: except Exception:
provider = Provider( continue
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
...@@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: ...@@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count:
try: try:
click.echo('restore dataset index: {}'.format(dataset.id)) click.echo('restore dataset index: {}'.format(dataset.id))
try: try:
embedding_model = ModelFactory.get_embedding_model( model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
except Exception: except Exception:
provider = Provider( pass
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name, filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment