Commit ea2945c2 authored by jyong's avatar jyong

knowledge add update dataset embedding model

parent a267969f
......@@ -186,6 +186,10 @@ class DatasetApi(Resource):
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
args = parser.parse_args()
......
......@@ -169,9 +169,36 @@ class DatasetService:
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
......
......@@ -64,6 +64,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
elif action == 'update':
# clean index
index_processor.clean(dataset, None, with_keywords=False)
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
# add new index
if dataset_documents:
documents = []
for dataset_document in dataset_documents:
# delete from vector index
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
).order_by(DocumentSegment.position.asc()).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
end_at = time.perf_counter()
logging.info(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment