Unverified Commit 86ad1d8e authored by Yeuoly's avatar Yeuoly

Merge branch 'main' into feat/agent-image

parents 60e625cc c97b7f67
...@@ -7,7 +7,7 @@ LABEL maintainer="takatost@gmail.com" ...@@ -7,7 +7,7 @@ LABEL maintainer="takatost@gmail.com"
FROM base as packages FROM base as packages
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
COPY requirements.txt /requirements.txt COPY requirements.txt /requirements.txt
...@@ -32,7 +32,7 @@ ENV TZ UTC ...@@ -32,7 +32,7 @@ ENV TZ UTC
WORKDIR /app/api WORKDIR /app/api
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y --no-install-recommends curl wget vim nodejs ffmpeg \ && apt-get install -y --no-install-recommends curl wget vim nodejs ffmpeg libgmp-dev libmpfr-dev libmpc-dev \
&& apt-get autoremove \ && apt-get autoremove \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
......
...@@ -339,26 +339,7 @@ def create_qdrant_indexes(): ...@@ -339,26 +339,7 @@ def create_qdrant_indexes():
) )
except Exception: except Exception:
try: continue
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.SYSTEM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
...@@ -405,7 +386,7 @@ def update_qdrant_indexes(): ...@@ -405,7 +386,7 @@ def update_qdrant_indexes():
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound: except NotFound:
break break
model_manager = ModelManager()
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
if dataset.index_struct_dict: if dataset.index_struct_dict:
...@@ -413,23 +394,15 @@ def update_qdrant_indexes(): ...@@ -413,23 +394,15 @@ def update_qdrant_indexes():
try: try:
click.echo('Update dataset qdrant index: {}'.format(dataset.id)) click.echo('Update dataset qdrant index: {}'.format(dataset.id))
try: try:
embedding_model = ModelFactory.get_embedding_model( embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
except Exception: except Exception:
provider = Provider( continue
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
...@@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: ...@@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count:
try: try:
click.echo('restore dataset index: {}'.format(dataset.id)) click.echo('restore dataset index: {}'.format(dataset.id))
try: try:
embedding_model = ModelFactory.get_embedding_model( model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
) )
except Exception: except Exception:
provider = Provider( pass
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name, filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
......
...@@ -40,17 +40,11 @@ DEFAULTS = { ...@@ -40,17 +40,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False', 'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False', 'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_OPENAI_PAID_MIN_QUANTITY': 1,
'HOSTED_OPENAI_PAID_MAX_QUANTITY': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False', 'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False', 'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 1,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 1,
'HOSTED_MODERATION_ENABLED': 'False', 'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '', 'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30, 'CLEAN_DAY_SETTING': 30,
...@@ -93,7 +87,7 @@ class Config: ...@@ -93,7 +87,7 @@ class Config:
# ------------------------ # ------------------------
# General Configurations. # General Configurations.
# ------------------------ # ------------------------
self.CURRENT_VERSION = "0.5.0" self.CURRENT_VERSION = "0.5.2"
self.COMMIT_SHA = get_env('COMMIT_SHA') self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED" self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV') self.DEPLOY_ENV = get_env('DEPLOY_ENV')
...@@ -262,10 +256,6 @@ class Config: ...@@ -262,10 +256,6 @@ class Config:
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED') self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_OPENAI_PAID_MIN_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MIN_QUANTITY'))
self.HOSTED_OPENAI_PAID_MAX_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MAX_QUANTITY'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
...@@ -277,10 +267,6 @@ class Config: ...@@ -277,10 +267,6 @@ class Config:
self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED') self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED') self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED')
self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED') self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED')
......
...@@ -19,5 +19,3 @@ from .explore import audio, completion, conversation, installed_app, message, pa ...@@ -19,5 +19,3 @@ from .explore import audio, completion, conversation, installed_app, message, pa
from .workspace import account, members, model_providers, models, tool_providers, workspace from .workspace import account, members, model_providers, models, tool_providers, workspace
# Import billing controllers # Import billing controllers
from .billing import billing from .billing import billing
# Import operation controllers
from .operation import operation
...@@ -61,9 +61,7 @@ class BaseApiKeyListResource(Resource): ...@@ -61,9 +61,7 @@ class BaseApiKeyListResource(Resource):
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, _get_resource(resource_id, current_user.current_tenant_id,
self.resource_model) self.resource_model)
if not current_user.is_admin_or_owner:
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
current_key_count = db.session.query(ApiToken). \ current_key_count = db.session.query(ApiToken). \
...@@ -102,7 +100,7 @@ class BaseApiKeyResource(Resource): ...@@ -102,7 +100,7 @@ class BaseApiKeyResource(Resource):
self.resource_model) self.resource_model)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
key = db.session.query(ApiToken). \ key = db.session.query(ApiToken). \
......
...@@ -21,7 +21,7 @@ class AnnotationReplyActionApi(Resource): ...@@ -21,7 +21,7 @@ class AnnotationReplyActionApi(Resource):
@cloud_edition_billing_resource_check('annotation') @cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action): def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -45,7 +45,7 @@ class AppAnnotationSettingDetailApi(Resource): ...@@ -45,7 +45,7 @@ class AppAnnotationSettingDetailApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -59,7 +59,7 @@ class AppAnnotationSettingUpdateApi(Resource): ...@@ -59,7 +59,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id, annotation_setting_id): def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -80,7 +80,7 @@ class AnnotationReplyActionStatusApi(Resource): ...@@ -80,7 +80,7 @@ class AnnotationReplyActionStatusApi(Resource):
@cloud_edition_billing_resource_check('annotation') @cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
...@@ -108,7 +108,7 @@ class AnnotationListApi(Resource): ...@@ -108,7 +108,7 @@ class AnnotationListApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
page = request.args.get('page', default=1, type=int) page = request.args.get('page', default=1, type=int)
...@@ -133,7 +133,7 @@ class AnnotationExportApi(Resource): ...@@ -133,7 +133,7 @@ class AnnotationExportApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -152,7 +152,7 @@ class AnnotationCreateApi(Resource): ...@@ -152,7 +152,7 @@ class AnnotationCreateApi(Resource):
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id): def post(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -172,7 +172,7 @@ class AnnotationUpdateDeleteApi(Resource): ...@@ -172,7 +172,7 @@ class AnnotationUpdateDeleteApi(Resource):
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource): ...@@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id, annotation_id): def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -205,7 +205,7 @@ class AnnotationBatchImportApi(Resource): ...@@ -205,7 +205,7 @@ class AnnotationBatchImportApi(Resource):
@cloud_edition_billing_resource_check('annotation') @cloud_edition_billing_resource_check('annotation')
def post(self, app_id): def post(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
...@@ -230,7 +230,7 @@ class AnnotationBatchImportStatusApi(Resource): ...@@ -230,7 +230,7 @@ class AnnotationBatchImportStatusApi(Resource):
@cloud_edition_billing_resource_check('annotation') @cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id): def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
...@@ -257,7 +257,7 @@ class AnnotationHitHistoryListApi(Resource): ...@@ -257,7 +257,7 @@ class AnnotationHitHistoryListApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id, annotation_id): def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
page = request.args.get('page', default=1, type=int) page = request.args.get('page', default=1, type=int)
......
...@@ -88,7 +88,7 @@ class AppListApi(Resource): ...@@ -88,7 +88,7 @@ class AppListApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
...@@ -237,7 +237,7 @@ class AppApi(Resource): ...@@ -237,7 +237,7 @@ class AppApi(Resource):
"""Delete app""" """Delete app"""
app_id = str(app_id) app_id = str(app_id)
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app = _get_app(app_id, current_user.current_tenant_id) app = _get_app(app_id, current_user.current_tenant_id)
......
...@@ -157,7 +157,7 @@ class MessageAnnotationApi(Resource): ...@@ -157,7 +157,7 @@ class MessageAnnotationApi(Resource):
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id): def post(self, app_id):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
......
...@@ -42,7 +42,7 @@ class AppSite(Resource): ...@@ -42,7 +42,7 @@ class AppSite(Resource):
app_model = _get_app(app_id) app_model = _get_app(app_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
site = db.session.query(Site). \ site = db.session.query(Site). \
...@@ -88,7 +88,7 @@ class AppSiteAccessTokenReset(Resource): ...@@ -88,7 +88,7 @@ class AppSiteAccessTokenReset(Resource):
app_model = _get_app(app_id) app_model = _get_app(app_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first() site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
......
...@@ -30,7 +30,7 @@ def get_oauth_providers(): ...@@ -30,7 +30,7 @@ def get_oauth_providers():
class OAuthDataSource(Resource): class OAuthDataSource(Resource):
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
......
...@@ -20,7 +20,7 @@ class Subscription(Resource): ...@@ -20,7 +20,7 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args() args = parser.parse_args()
BillingService.is_tenant_owner(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args['plan'], return BillingService.get_subscription(args['plan'],
args['interval'], args['interval'],
...@@ -35,8 +35,8 @@ class Invoices(Resource): ...@@ -35,8 +35,8 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
BillingService.is_tenant_owner(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, '/billing/subscription') api.add_resource(Subscription, '/billing/subscription')
......
...@@ -103,7 +103,7 @@ class DatasetListApi(Resource): ...@@ -103,7 +103,7 @@ class DatasetListApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
...@@ -187,7 +187,7 @@ class DatasetApi(Resource): ...@@ -187,7 +187,7 @@ class DatasetApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
dataset = DatasetService.update_dataset( dataset = DatasetService.update_dataset(
...@@ -205,7 +205,7 @@ class DatasetApi(Resource): ...@@ -205,7 +205,7 @@ class DatasetApi(Resource):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if DatasetService.delete_dataset(dataset_id_str, current_user): if DatasetService.delete_dataset(dataset_id_str, current_user):
...@@ -391,7 +391,7 @@ class DatasetApiKeyApi(Resource): ...@@ -391,7 +391,7 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = db.session.query(ApiToken). \ current_key_count = db.session.query(ApiToken). \
...@@ -425,7 +425,7 @@ class DatasetApiDeleteApi(Resource): ...@@ -425,7 +425,7 @@ class DatasetApiDeleteApi(Resource):
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
key = db.session.query(ApiToken). \ key = db.session.query(ApiToken). \
......
...@@ -204,7 +204,7 @@ class DatasetDocumentListApi(Resource): ...@@ -204,7 +204,7 @@ class DatasetDocumentListApi(Resource):
raise NotFound('Dataset not found.') raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
...@@ -256,7 +256,7 @@ class DatasetInitApi(Resource): ...@@ -256,7 +256,7 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_resource_check('vector_space') @cloud_edition_billing_resource_check('vector_space')
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
...@@ -599,7 +599,7 @@ class DocumentProcessingApi(DocumentResource): ...@@ -599,7 +599,7 @@ class DocumentProcessingApi(DocumentResource):
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if action == "pause": if action == "pause":
...@@ -663,7 +663,7 @@ class DocumentMetadataApi(DocumentResource): ...@@ -663,7 +663,7 @@ class DocumentMetadataApi(DocumentResource):
doc_metadata = req_data.get('doc_metadata') doc_metadata = req_data.get('doc_metadata')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if doc_type is None or doc_metadata is None: if doc_type is None or doc_metadata is None:
...@@ -710,7 +710,7 @@ class DocumentStatusApi(DocumentResource): ...@@ -710,7 +710,7 @@ class DocumentStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
indexing_cache_key = 'document_{}_indexing'.format(document.id) indexing_cache_key = 'document_{}_indexing'.format(document.id)
......
...@@ -123,7 +123,7 @@ class DatasetDocumentSegmentApi(Resource): ...@@ -123,7 +123,7 @@ class DatasetDocumentSegmentApi(Resource):
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
...@@ -219,7 +219,7 @@ class DatasetDocumentSegmentAddApi(Resource): ...@@ -219,7 +219,7 @@ class DatasetDocumentSegmentAddApi(Resource):
if not document: if not document:
raise NotFound('Document not found.') raise NotFound('Document not found.')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
# check embedding model setting # check embedding model setting
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
...@@ -298,7 +298,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): ...@@ -298,7 +298,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if not segment: if not segment:
raise NotFound('Segment not found.') raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
...@@ -342,7 +342,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): ...@@ -342,7 +342,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if not segment: if not segment:
raise NotFound('Segment not found.') raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
try: try:
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
......
...@@ -3,10 +3,12 @@ from flask_restful import Resource ...@@ -3,10 +3,12 @@ from flask_restful import Resource
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api from . import api
from .wraps import cloud_utm_record
class FeatureApi(Resource): class FeatureApi(Resource):
@cloud_utm_record
def get(self): def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict() return FeatureService.get_features(current_user.current_tenant_id).dict()
......
from flask_login import current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, only_edition_cloud
from libs.login import login_required
from services.operation_service import OperationService
class TenantUtm(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('utm_source', type=str, required=True)
parser.add_argument('utm_medium', type=str, required=True)
parser.add_argument('utm_campaign', type=str, required=False, default='')
parser.add_argument('utm_content', type=str, required=False, default='')
parser.add_argument('utm_term', type=str, required=False, default='')
args = parser.parse_args()
return OperationService.record_utm(current_user.current_tenant_id, args)
api.add_resource(TenantUtm, '/operation/utm')
...@@ -52,10 +52,12 @@ class MemberInviteEmailApi(Resource): ...@@ -52,10 +52,12 @@ class MemberInviteEmailApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('emails', type=str, required=True, location='json', action='append') parser.add_argument('emails', type=str, required=True, location='json', action='append')
parser.add_argument('role', type=str, required=True, default='admin', location='json') parser.add_argument('role', type=str, required=True, default='admin', location='json')
parser.add_argument('language', type=str, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
invitee_emails = args['emails'] invitee_emails = args['emails']
invitee_role = args['role'] invitee_role = args['role']
interface_language = args['language']
if invitee_role not in ['admin', 'normal']: if invitee_role not in ['admin', 'normal']:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
...@@ -64,8 +66,7 @@ class MemberInviteEmailApi(Resource): ...@@ -64,8 +66,7 @@ class MemberInviteEmailApi(Resource):
console_web_url = current_app.config.get("CONSOLE_WEB_URL") console_web_url = current_app.config.get("CONSOLE_WEB_URL")
for invitee_email in invitee_emails: for invitee_email in invitee_emails:
try: try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
inviter=inviter)
invitation_results.append({ invitation_results.append({
'status': 'success', 'status': 'success',
'email': invitee_email, 'email': invitee_email,
......
...@@ -98,7 +98,7 @@ class ModelProviderApi(Resource): ...@@ -98,7 +98,7 @@ class ModelProviderApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
...@@ -122,7 +122,7 @@ class ModelProviderApi(Resource): ...@@ -122,7 +122,7 @@ class ModelProviderApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
...@@ -159,7 +159,7 @@ class PreferredProviderTypeUpdateApi(Resource): ...@@ -159,7 +159,7 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
...@@ -186,10 +186,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): ...@@ -186,10 +186,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str): def get(self, provider: str):
if provider != 'anthropic': if provider != 'anthropic':
raise ValueError(f'provider name {provider} is invalid') raise ValueError(f'provider name {provider} is invalid')
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(provider_name=provider, data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
account_id=current_user.id) account_id=current_user.id,
prefilled_email=current_user.email)
return data return data
......
...@@ -43,7 +43,7 @@ class ToolBuiltinProviderDeleteApi(Resource): ...@@ -43,7 +43,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = current_user.id
...@@ -60,7 +60,7 @@ class ToolBuiltinProviderUpdateApi(Resource): ...@@ -60,7 +60,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = current_user.id
...@@ -114,7 +114,7 @@ class ToolApiProviderAddApi(Resource): ...@@ -114,7 +114,7 @@ class ToolApiProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = current_user.id
...@@ -183,7 +183,7 @@ class ToolApiProviderUpdateApi(Resource): ...@@ -183,7 +183,7 @@ class ToolApiProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = current_user.id
...@@ -217,7 +217,7 @@ class ToolApiProviderDeleteApi(Resource): ...@@ -217,7 +217,7 @@ class ToolApiProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = current_user.id
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import json
from functools import wraps from functools import wraps
from flask import request
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from flask import abort, current_app from flask import abort, current_app
from flask_login import current_user from flask_login import current_user
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.operation_service import OperationService
def account_initialization_required(view): def account_initialization_required(view):
...@@ -73,3 +76,20 @@ def cloud_edition_billing_resource_check(resource: str, ...@@ -73,3 +76,20 @@ def cloud_edition_billing_resource_check(resource: str,
return decorated return decorated
return interceptor return interceptor
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
try:
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get('utm_info')
if utm_info:
utm_info = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info)
except Exception as e:
pass
return view(*args, **kwargs)
return decorated
...@@ -13,7 +13,7 @@ from core.application_queue_manager import ApplicationQueueManager ...@@ -13,7 +13,7 @@ from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from flask import Response, stream_with_context, request from flask import Response, stream_with_context
from flask_restful import reqparse from flask_restful import reqparse
from libs.helper import uuid_value from libs.helper import uuid_value
from services.completion_service import CompletionService from services.completion_service import CompletionService
...@@ -75,18 +75,22 @@ class CompletionApi(AppApiResource): ...@@ -75,18 +75,22 @@ class CompletionApi(AppApiResource):
class CompletionStopApi(AppApiResource): class CompletionStopApi(AppApiResource):
def post(self, app_model, _, task_id): def post(self, app_model, end_user, task_id):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
parser = reqparse.RequestParser() if end_user is None:
parser.add_argument('user', required=True, nullable=False, type=str, location='json') parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
args = parser.parse_args() user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
end_user_id = args.get('user') ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -146,13 +150,22 @@ class ChatApi(AppApiResource): ...@@ -146,13 +150,22 @@ class ChatApi(AppApiResource):
class ChatStopApi(AppApiResource): class ChatStopApi(AppApiResource):
def post(self, app_model, _, task_id): def post(self, app_model, end_user, task_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
end_user_id = request.get_json().get('user') if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id) ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -75,8 +75,8 @@ def validate_dataset_token(view=None): ...@@ -75,8 +75,8 @@ def validate_dataset_token(view=None):
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
.filter(Tenant.id == api_token.tenant_id) \ .filter(Tenant.id == api_token.tenant_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role.in_(['owner', 'admin'])) \ .filter(TenantAccountJoin.role.in_(['owner'])) \
.one_or_none() .one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first() account = Account.query.filter_by(id=ta.account_id).first()
...@@ -86,9 +86,9 @@ def validate_dataset_token(view=None): ...@@ -86,9 +86,9 @@ def validate_dataset_token(view=None):
current_app.login_manager._update_request_context_with_user(account) current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user()) user_logged_in.send(current_app._get_current_object(), user=_get_user())
else: else:
raise Unauthorized("Tenant owner account is not exist.") raise Unauthorized("Tenant owner account does not exist.")
else: else:
raise Unauthorized("Tenant is not exist.") raise Unauthorized("Tenant does not exist.")
return view(api_token.tenant_id, *args, **kwargs) return view(api_token.tenant_id, *args, **kwargs)
return decorated return decorated
......
...@@ -9,6 +9,7 @@ from pydantic import BaseModel ...@@ -9,6 +9,7 @@ from pydantic import BaseModel
class QuotaUnit(Enum): class QuotaUnit(Enum):
TIMES = 'times' TIMES = 'times'
TOKENS = 'tokens' TOKENS = 'tokens'
CREDITS = 'credits'
class SystemConfigurationStatus(Enum): class SystemConfigurationStatus(Enum):
......
...@@ -20,10 +20,6 @@ class TrialHostingQuota(HostingQuota): ...@@ -20,10 +20,6 @@ class TrialHostingQuota(HostingQuota):
class PaidHostingQuota(HostingQuota): class PaidHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.PAID quota_type: ProviderQuotaType = ProviderQuotaType.PAID
stripe_price_id: str = None
increase_quota: int = 1
min_quantity: int = 20
max_quantity: int = 100
class FreeHostingQuota(HostingQuota): class FreeHostingQuota(HostingQuota):
...@@ -102,7 +98,7 @@ class HostingConfiguration: ...@@ -102,7 +98,7 @@ class HostingConfiguration:
) )
def init_openai(self, app_config: Config) -> HostingProvider: def init_openai(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TIMES quota_unit = QuotaUnit.CREDITS
quotas = [] quotas = []
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
...@@ -114,6 +110,8 @@ class HostingConfiguration: ...@@ -114,6 +110,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="whisper-1", model_type=ModelType.SPEECH2TEXT), RestrictModel(model="whisper-1", model_type=ModelType.SPEECH2TEXT),
] ]
...@@ -122,10 +120,20 @@ class HostingConfiguration: ...@@ -122,10 +120,20 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_quota = PaidHostingQuota( paid_quota = PaidHostingQuota(
stripe_price_id=app_config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), restrict_models=[
increase_quota=int(app_config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")), RestrictModel(model="gpt-4", model_type=ModelType.LLM),
min_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")), RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
max_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1")) RestrictModel(model="gpt-4-32k", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
) )
quotas.append(paid_quota) quotas.append(paid_quota)
...@@ -164,12 +172,7 @@ class HostingConfiguration: ...@@ -164,12 +172,7 @@ class HostingConfiguration:
quotas.append(trial_quota) quotas.append(trial_quota)
if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
paid_quota = PaidHostingQuota( paid_quota = PaidHostingQuota()
stripe_price_id=app_config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
increase_quota=int(app_config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
min_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
max_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
)
quotas.append(paid_quota) quotas.append(paid_quota)
if len(quotas) > 0: if len(quotas) > 0:
......
...@@ -562,7 +562,7 @@ class IndexingRunner: ...@@ -562,7 +562,7 @@ class IndexingRunner:
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"], chunk_size=segmentation["max_tokens"],
chunk_overlap=0, chunk_overlap=segmentation.get('chunk_overlap', 0),
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "。", ".", " ", ""], separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance embedding_model_instance=embedding_model_instance
...@@ -571,7 +571,7 @@ class IndexingRunner: ...@@ -571,7 +571,7 @@ class IndexingRunner:
# Automatic segmentation # Automatic segmentation
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0, chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
separators=["\n\n", "。", ".", " ", ""], separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance embedding_model_instance=embedding_model_instance
) )
......
...@@ -8,9 +8,9 @@ model_properties: ...@@ -8,9 +8,9 @@ model_properties:
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature
- name: topP - name: top_p
use_template: top_p use_template: top_p
- name: topK - name: top_k
label: label:
zh_Hans: 取样数量 zh_Hans: 取样数量
en_US: Top K en_US: Top K
......
...@@ -8,9 +8,9 @@ model_properties: ...@@ -8,9 +8,9 @@ model_properties:
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature
- name: topP - name: top_p
use_template: top_p use_template: top_p
- name: topK - name: top_k
label: label:
zh_Hans: 取样数量 zh_Hans: 取样数量
en_US: Top K en_US: Top K
......
...@@ -8,9 +8,9 @@ model_properties: ...@@ -8,9 +8,9 @@ model_properties:
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature
- name: topP - name: top_p
use_template: top_p use_template: top_p
- name: topK - name: top_k
label: label:
zh_Hans: 取样数量 zh_Hans: 取样数量
en_US: Top K en_US: Top K
......
...@@ -250,9 +250,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ...@@ -250,9 +250,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
invoke = runtime_client.invoke_model invoke = runtime_client.invoke_model
try: try:
body_jsonstr=json.dumps(payload)
response = invoke( response = invoke(
body=json.dumps(payload),
modelId=model, modelId=model,
contentType="application/json",
accept= "*/*",
body=body_jsonstr
) )
except ClientError as ex: except ClientError as ex:
error_code = ex.response['Error']['Code'] error_code = ex.response['Error']['Code']
...@@ -385,7 +388,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ...@@ -385,7 +388,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if not chunk: if not chunk:
exception_name = next(iter(event)) exception_name = next(iter(event))
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
raise self._map_client_to_invoke_error(exception_name, full_ex_msg) raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
payload = json.loads(chunk.get('bytes').decode()) payload = json.loads(chunk.get('bytes').decode())
...@@ -396,7 +398,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ...@@ -396,7 +398,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
finish_reason = payload.get("completion_reason") finish_reason = payload.get("completion_reason")
elif model_prefix == "anthropic": elif model_prefix == "anthropic":
content_delta = payload content_delta = payload.get("completion")
finish_reason = payload.get("stop_reason") finish_reason = payload.get("stop_reason")
elif model_prefix == "cohere": elif model_prefix == "cohere":
...@@ -410,12 +412,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ...@@ -410,12 +412,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
else: else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
index += 1 # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content = content_delta if content_delta else '', content = content_delta if content_delta else '',
) )
index += 1
if not finish_reason: if not finish_reason:
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
......
...@@ -11,7 +11,7 @@ help: ...@@ -11,7 +11,7 @@ help:
en_US: How to integrate with Ollama en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama zh_Hans: 如何集成 Ollama
url: url:
en_US: https://docs.dify.ai/advanced/model-configuration/ollama en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
supported_model_types: supported_model_types:
- llm - llm
- text-embedding - text-embedding
......
- gpt-4 - gpt-4
- gpt-4-turbo-preview
- gpt-4-32k - gpt-4-32k
- gpt-4-1106-preview - gpt-4-1106-preview
- gpt-4-0125-preview
- gpt-4-vision-preview - gpt-4-vision-preview
- gpt-3.5-turbo - gpt-3.5-turbo
- gpt-3.5-turbo-16k - gpt-3.5-turbo-16k
......
model: gpt-4-0125-preview
label:
zh_Hans: gpt-4-0125-preview
en_US: gpt-4-0125-preview
model_type: llm
features:
- multi-tool-call
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: seed
label:
zh_Hans: 种子
en_US: Seed
type: int
help:
zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
响应参数来监视变化。
en_US: If specified, model will make a best effort to sample deterministically,
such that repeated requests with the same seed and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
system_fingerprint response parameter to monitor changes in the backend.
required: false
precision: 2
min: 0
max: 1
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.03'
unit: '0.001'
currency: USD
model: gpt-4-turbo-preview
label:
zh_Hans: gpt-4-turbo-preview
en_US: gpt-4-turbo-preview
model_type: llm
features:
- multi-tool-call
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: seed
label:
zh_Hans: 种子
en_US: Seed
type: int
help:
zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
响应参数来监视变化。
en_US: If specified, model will make a best effort to sample deterministically,
such that repeated requests with the same seed and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
system_fingerprint response parameter to monitor changes in the backend.
required: false
precision: 2
min: 0
max: 1
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.03'
unit: '0.001'
currency: USD
...@@ -26,3 +26,4 @@ pricing: ...@@ -26,3 +26,4 @@ pricing:
output: '0.002' output: '0.002'
unit: '0.001' unit: '0.001'
currency: USD currency: USD
deprecated: true
model: text-embedding-3-large
model_type: text-embedding
model_properties:
context_size: 8191
max_chunks: 32
pricing:
input: '0.00013'
unit: '0.001'
currency: USD
model: text-embedding-3-small
model_type: text-embedding
model_properties:
context_size: 8191
max_chunks: 32
pricing:
input: '0.00002'
unit: '0.001'
currency: USD
...@@ -224,7 +224,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -224,7 +224,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else: else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}") raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
...@@ -343,31 +343,37 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -343,31 +343,37 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
) )
) )
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'): # delimiter for stream response, need unicode_escape
import codecs
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
if chunk: if chunk:
decoded_chunk = chunk.strip().lstrip('data: ').lstrip() decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None chunk_json = None
try: try:
chunk_json = json.loads(decoded_chunk) chunk_json = json.loads(decoded_chunk)
# stream ended # stream ended
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"decoded_chunk error,delimiter={delimiter},decoded_chunk={decoded_chunk}")
yield create_final_llm_result_chunk( yield create_final_llm_result_chunk(
index=chunk_index + 1, index=chunk_index + 1,
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered." finish_reason="Non-JSON encountered."
) )
break break
if not chunk_json or len(chunk_json['choices']) == 0: if not chunk_json or len(chunk_json['choices']) == 0:
continue continue
choice = chunk_json['choices'][0] choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
chunk_index += 1 chunk_index += 1
if 'delta' in choice: if 'delta' in choice:
delta = choice['delta'] delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '': delta_content = delta.get('content')
if delta_content is None or delta_content == '':
continue continue
assistant_message_tool_calls = delta.get('tool_calls', None) assistant_message_tool_calls = delta.get('tool_calls', None)
...@@ -381,30 +387,28 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -381,30 +387,28 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''), content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else [] tool_calls=tool_calls if assistant_message_tool_calls else []
) )
full_assistant_content += delta.get('content', '') full_assistant_content += delta_content
elif 'text' in choice: elif 'text' in choice:
if choice.get('text') is None or choice.get('text') == '': choice_text = choice.get('text', '')
if choice_text == '':
continue continue
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(content=choice_text)
content=choice.get('text', '') full_assistant_content += choice_text
)
full_assistant_content += choice.get('text', '')
else: else:
continue continue
# check payload indicator for completion # check payload indicator for completion
if chunk_json['choices'][0].get('finish_reason') is not None: if finish_reason is not None:
yield create_final_llm_result_chunk( yield create_final_llm_result_chunk(
index=chunk_index, index=chunk_index,
message=assistant_prompt_message, message=assistant_prompt_message,
finish_reason=chunk_json['choices'][0]['finish_reason'] finish_reason=finish_reason
) )
else: else:
yield LLMResultChunk( yield LLMResultChunk(
......
...@@ -75,3 +75,12 @@ model_credential_schema: ...@@ -75,3 +75,12 @@ model_credential_schema:
value: llm value: llm
default: '4096' default: '4096'
type: text-input type: text-input
- variable: stream_mode_delimiter
label:
zh_Hans: 流模式返回结果的分隔符
en_US: Delimiter for streaming results
show_on:
- variable: __model_type
value: llm
default: '\n\n'
type: text-input
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import logging
import posixpath
from pydantic import BaseModel, Extra
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
api_key: str
api_timeout_seconds = 60
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SYNC)
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
if not response['success']:
raise ValueError(
f"Error Code: {response['code']}, Message: {response['msg']} "
)
return response
def sse_invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SSE)
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
return SSEClient(data)
def _build_api_url(self, kwargs, *path):
if kwargs:
if "model" not in kwargs:
raise Exception("model param missed")
model = kwargs.pop("model")
else:
model = "-"
return posixpath.join(self.base_url, model, *path)
def _generate_token(self):
if not self.api_key:
raise Exception(
"api_key not provided, you could provide it."
)
try:
return jwt_token.generate_token(self.api_key)
except Exception:
raise ValueError(
f"Your api_key is invalid, please check it."
)
...@@ -28,15 +28,3 @@ parameter_rules: ...@@ -28,15 +28,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string
...@@ -28,15 +28,3 @@ parameter_rules: ...@@ -28,15 +28,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string
...@@ -30,15 +30,3 @@ parameter_rules: ...@@ -30,15 +30,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string
...@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType ...@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from langchain.schema.language_model import _get_token_ids_default_method from langchain.schema.language_model import _get_token_ids_default_method
...@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): ...@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result :return: embeddings result
""" """
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuModelAPI( client = ZhipuAI(
api_key=credentials_kwargs['api_key'] api_key=credentials_kwargs['api_key']
) )
...@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): ...@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try: try:
# transform credentials to kwargs for model instance # transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuModelAPI( client = ZhipuAI(
api_key=credentials_kwargs['api_key'] api_key=credentials_kwargs['api_key']
) )
...@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): ...@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]: def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
"""Call out to ZhipuAI's embedding endpoint. """Call out to ZhipuAI's embedding endpoint.
Args: Args:
...@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): ...@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
embeddings = [] embeddings = []
for text in texts: embedding_used_tokens = 0
response = client.invoke(model=model, prompt=text)
data = response["data"]
embeddings.append(data.get('embedding'))
embedding_used_tokens = data.get('usage') for text in texts:
response = client.embeddings.create(model=model, input=text)
data = response.data[0]
embeddings.append(data.embedding)
embedding_used_tokens += response.usage.total_tokens
return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0 return [list(map(float, e)) for e in embeddings], embedding_used_tokens
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to ZhipuAI's embedding endpoint. """Call out to ZhipuAI's embedding endpoint.
......
from ._client import ZhipuAI
from .core._errors import (
ZhipuAIError,
APIStatusError,
APIRequestFailedError,
APIAuthenticationError,
APIReachLimitError,
APIInternalError,
APIServerFlowExceedError,
APIResponseError,
APIResponseValidationError,
APITimeoutError,
)
from .__version__ import __version__
__version__ = 'v2.0.1'
\ No newline at end of file
from __future__ import annotations
from typing import Union, Mapping
from typing_extensions import override
from .core import _jwt_token
from .core._errors import ZhipuAIError
from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
from .core._base_type import NotGiven, NOT_GIVEN
from . import api_resource
import os
import httpx
from httpx import Timeout
class ZhipuAI(HttpClient):
chat: api_resource.chat
api_key: str
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None
) -> None:
# if api_key is None:
# api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
self.api_key = api_key
if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
if base_url is None:
base_url = f"https://open.bigmodel.cn/api/paas/v4"
from .__version__ import __version__
super().__init__(
version=__version__,
base_url=base_url,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
@property
@override
def _auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
def __del__(self) -> None:
if (not hasattr(self, "_has_custom_http_client")
or not hasattr(self, "close")
or not hasattr(self, "_client")):
# if the '__init__' method raised an error, self would not have client attr
return
if self._has_custom_http_client:
return
self.close()
from .chat import chat
from .images import Images
from .embeddings import Embeddings
from .files import Files
from .fine_tuning import fine_tuning
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from typing_extensions import Literal
from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
if TYPE_CHECKING:
from ..._client import ZhipuAI
class AsyncCompletions(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, List[str], List[int], List[List[int]], None],
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncTaskStatus:
_cast_type = AsyncTaskStatus
if disable_strict_validation:
_cast_type = object
return self._post(
"/async/chat/completions",
body={
"model": model,
"request_id": request_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
def retrieve_completion_result(
self,
id: str,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Union[AsyncCompletion, AsyncTaskStatus]:
_cast_type = Union[AsyncCompletion,AsyncTaskStatus]
if disable_strict_validation:
_cast_type = object
return self._get(
path=f"/async-result/{id}",
cast_type=_cast_type,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout
)
)
from typing import TYPE_CHECKING
from .completions import Completions
from .async_completions import AsyncCompletions
from ...core._base_api import BaseAPI
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Chat(BaseAPI):
completions: Completions
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.completions = Completions(client)
self.asyncCompletions = AsyncCompletions(client)
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from typing_extensions import Literal
from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...core._sse_client import StreamResponse
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Completions(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, List[str], List[int], object, None],
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | StreamResponse[ChatCompletionChunk]:
_cast_type = Completion
_stream_cls = StreamResponse[ChatCompletionChunk]
if disable_strict_validation:
_cast_type = object
_stream_cls = StreamResponse[object]
return self._post(
"/chat/completions",
body={
"model": model,
"request_id": request_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"stream": stream,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(
extra_headers=extra_headers,
),
cast_type=_cast_type,
enable_stream=stream or False,
stream_cls=_stream_cls,
)
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.embeddings import EmbeddingsResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Embeddings(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
input: Union[str, List[str], List[int], List[List[int]]],
model: Union[str],
encoding_format: str | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EmbeddingsResponded:
_cast_type = EmbeddingsResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/embeddings",
body={
"input": input,
"model": model,
"encoding_format": encoding_format,
"user": user,
"sensitive_word_check": sensitive_word_check,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
from __future__ import annotations
from typing import TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..core._files import is_file_content
from ..core._http_client import (
make_user_request_input,
)
from ..types.file_object import FileObject, ListOfFileObject
if TYPE_CHECKING:
from .._client import ZhipuAI
__all__ = ["Files"]
class Files(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
file: FileTypes,
purpose: str,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileObject:
if not is_file_content(file):
prefix = f"Expected file input `{file!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
) from None
files = [("file", file)]
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body={
"purpose": purpose,
},
files=files,
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FileObject,
)
def list(
self,
*,
purpose: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
after: str | NotGiven = NOT_GIVEN,
order: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFileObject:
return self._get(
"/files",
cast_type=ListOfFileObject,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"purpose": purpose,
"limit": limit,
"after": after,
"order": order,
},
),
)
from typing import TYPE_CHECKING
from .jobs import Jobs
from ...core._base_api import BaseAPI
if TYPE_CHECKING:
from ..._client import ZhipuAI
class FineTuning(BaseAPI):
jobs: Jobs
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.jobs = Jobs(client)
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
import httpx
from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import (
make_user_request_input,
)
from ...types.fine_tuning import (
FineTuningJob,
job_create_params,
ListOfFineTuningJob,
FineTuningJobEvent,
)
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Jobs"]
class Jobs(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
training_file: str,
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
suffix: Optional[str] | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._post(
"/fine_tuning/jobs",
body={
"model": model,
"training_file": training_file,
"hyperparameters": hyperparameters,
"suffix": suffix,
"validation_file": validation_file,
"request_id": request_id,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FineTuningJob,
)
def retrieve(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FineTuningJob,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFineTuningJob:
return self._get(
"/fine_tuning/jobs",
cast_type=ListOfFineTuningJob,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def list_events(
self,
fine_tuning_job_id: str,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJobEvent:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
cast_type=FineTuningJobEvent,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Images(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def generations(
self,
*,
prompt: str,
model: str | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
quality: Optional[str] | NotGiven = NOT_GIVEN,
response_format: Optional[str] | NotGiven = NOT_GIVEN,
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
_cast_type = ImagesResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/images/generations",
body={
"prompt": prompt,
"model": model,
"n": n,
"quality": quality,
"response_format": response_format,
"size": size,
"style": style,
"user": user,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .._client import ZhipuAI
class BaseAPI:
_client: ZhipuAI
def __init__(self, client: ZhipuAI) -> None:
self._client = client
self._delete = client.delete
self._get = client.get
self._post = client.post
self._put = client.put
self._patch = client.patch
from __future__ import annotations
from os import PathLike
from typing import (
TYPE_CHECKING,
Type,
Union,
Mapping,
TypeVar, IO, Tuple, Sequence, Any, List,
)
import pydantic
from typing_extensions import (
Literal,
override,
)
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
if TYPE_CHECKING:
NoneType: Type[None]
else:
NoneType = type(None)
# Sentinel class used until PEP 0661 is accepted
class NotGiven(pydantic.BaseModel):
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit(pydantic.BaseModel):
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
# as the default `Content-Type` header is `application/json` that will be sent
client.post('/upload/files', files={'file': b'my raw file content'})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={'Content-Type': 'multipart/form-data'})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={'Content-Type': Omit()})
```
"""
def __bool__(self) -> Literal[False]:
return False
Headers = Mapping[str, Union[str, Omit]]
ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
)
# for user input files
if TYPE_CHECKING:
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
FileContent, # file content
Tuple[str, FileContent], # (filename, file)
Tuple[str, FileContent, str], # (filename, file , content_type)
Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
# for httpx client supported files
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
FileContent, # file content
Tuple[str, HttpxFileContent], # (filename, file)
Tuple[str, HttpxFileContent, str], # (filename, file , content_type)
Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
from __future__ import annotations
import httpx
__all__ = [
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
]
class ZhipuAIError(Exception):
def __init__(self, message: str, ) -> None:
super().__init__(message)
class APIStatusError(Exception):
response: httpx.Response
status_code: int
def __init__(self, message: str, *, response: httpx.Response) -> None:
super().__init__(message)
self.response = response
self.status_code = response.status_code
class APIRequestFailedError(APIStatusError):
...
class APIAuthenticationError(APIStatusError):
...
class APIReachLimitError(APIStatusError):
...
class APIInternalError(APIStatusError):
...
class APIServerFlowExceedError(APIStatusError):
...
class APIResponseError(Exception):
message: str
request: httpx.Request
json_data: object
def __init__(self, message: str, request: httpx.Request, json_data: object):
self.message = message
self.request = request
self.json_data = json_data
super().__init__(message)
class APIResponseValidationError(APIResponseError):
status_code: int
response: httpx.Response
def __init__(
self,
response: httpx.Response,
json_data: object | None, *,
message: str | None = None
) -> None:
super().__init__(
message=message or "Data returned by API invalid for expected schema.",
request=response.request,
json_data=json_data
)
self.response = response
self.status_code = response.status_code
class APITimeoutError(Exception):
request: httpx.Request
def __init__(self, request: httpx.Request):
self.request = request
super().__init__("Request Timeout")
from __future__ import annotations
import io
import os
from pathlib import Path
from typing import Mapping, Sequence
from ._base_type import (
FileTypes,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
def is_file_content(obj: object) -> bool:
return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = Path(file)
return path.name, path.read_bytes()
else:
return file
if isinstance(file, tuple):
if isinstance(file[1], os.PathLike):
return (file[0], Path(file[1]).read_bytes(), *file[2:])
else:
return (file[0], file[1], *file[2:])
else:
raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if isinstance(files, Mapping):
files = {key: _transform_file(file) for key, file in files.items()}
elif isinstance(files, Sequence):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
return files
# -*- coding:utf-8 -*-
import time
import cachetools.func
import jwt
API_TOKEN_TTL_SECONDS = 3 * 60
CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
try:
api_key, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid api_key", e)
payload = {
"api_key": api_key,
"exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
"timestamp": int(round(time.time() * 1000)),
}
ret = jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
return ret
from __future__ import annotations
from typing import Union, Any, cast
import pydantic.generics
from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import (
Unpack, ClassVar, TypedDict
)
from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
from ._utils import remove_notgiven_indict
class UserRequestInput(TypedDict, total=False):
max_retries: int
timeout: float | Timeout | None
headers: Headers
params: Query | None
class ClientRequestParam():
method: str
url: str
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, NotGiven] = NotGiven()
headers: Union[Headers, NotGiven] = NotGiven()
json_data: Union[Body, None] = None
files: Union[HttpxRequestFiles, None] = None
params: Query = {}
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[UserRequestInput],
) -> ClientRequestParam :
kwargs: dict[str, Any] = {
key: remove_notgiven_indict(value) for key, value in values.items()
}
client = cls()
client.__dict__.update(kwargs)
return client
model_construct = construct
from __future__ import annotations
import datetime
from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING
import httpx
import pydantic
from typing_extensions import ParamSpec, get_origin, get_args
from ._base_type import NoneType
from ._sse_client import StreamResponse
if TYPE_CHECKING:
from ._http_client import HttpClient
P = ParamSpec("P")
R = TypeVar("R")
class HttpResponse(Generic[R]):
_cast_type: type[R]
_client: "HttpClient"
_parsed: R | None
_enable_stream: bool
_stream_cls: type[StreamResponse[Any]]
http_response: httpx.Response
def __init__(
self,
*,
raw_response: httpx.Response,
cast_type: type[R],
client: "HttpClient",
enable_stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed = None
self._stream_cls = stream_cls
self._enable_stream = enable_stream
self.http_response = raw_response
def parse(self) -> R:
self._parsed = self._parse()
return self._parsed
def _parse(self) -> R:
if self._enable_stream:
self._parsed = cast(
R,
self._stream_cls(
cast_type=cast(type, get_args(self._stream_cls)[0]),
response=self.http_response,
client=self._client
)
)
return self._parsed
cast_type = self._cast_type
if cast_type is NoneType:
return cast(R, None)
http_response = self.http_response
if cast_type == str:
return cast(R, http_response.text)
content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
origin = get_origin(cast_type) or cast_type
if content_type != "application/json":
if issubclass(origin, pydantic.BaseModel):
data = http_response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
)
return http_response.text
data = http_response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
)
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
return self.http_response.content
@property
def text(self) -> str:
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
return self.http_response.elapsed
# -*- coding:utf-8 -*-
from __future__ import annotations
import json
from typing import Generic, Iterator, TYPE_CHECKING, Mapping
import httpx
from ._base_type import ResponseT
from ._errors import APIResponseError
_FIELD_SEPARATOR = ":"
if TYPE_CHECKING:
from ._http_client import HttpClient
class StreamResponse(Generic[ResponseT]):
response: httpx.Response
_cast_type: type[ResponseT]
def __init__(
self,
*,
cast_type: type[ResponseT],
response: httpx.Response,
client: HttpClient,
) -> None:
self.response = response
self._cast_type = cast_type
self._data_process_func = client._process_response_data
self._stream_chunks = self.__stream__()
def __next__(self) -> ResponseT:
return self._stream_chunks.__next__()
def __iter__(self) -> Iterator[ResponseT]:
for item in self._stream_chunks:
yield item
def __stream__(self) -> Iterator[ResponseT]:
sse_line_parser = SSELineParser()
iterator = sse_line_parser.iter_lines(self.response.iter_lines())
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
if sse.event is None:
data = sse.json_data()
if isinstance(data, Mapping) and data.get("error"):
raise APIResponseError(
message="An error occurred during streaming",
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
for sse in iterator:
pass
class Event(object):
def __init__(
self,
event: str | None = None,
data: str | None = None,
id: str | None = None,
retry: int | None = None
):
self._event = event
self._data = data
self._id = id
self._retry = retry
def __repr__(self):
data_len = len(self._data) if self._data else 0
return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"
@property
def event(self): return self._event
@property
def data(self): return self._data
def json_data(self): return json.loads(self._data)
@property
def id(self): return self._id
@property
def retry(self): return self._retry
class SSELineParser:
_data: list[str]
_event: str | None
_retry: int | None
_id: str | None
def __init__(self):
self._event = None
self._data = []
self._id = None
self._retry = None
def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
for line in lines:
line = line.rstrip('\n')
if not line:
if self._event is None and \
not self._data and \
self._id is None and \
self._retry is None:
continue
sse_event = Event(
event=self._event,
data='\n'.join(self._data),
id=self._id,
retry=self._retry
)
self._event = None
self._data = []
self._id = None
self._retry = None
yield sse_event
self.decode_line(line)
def decode_line(self, line: str):
if line.startswith(":") or not line:
return
field, _p, value = line.partition(":")
if value.startswith(' '):
value = value[1:]
if field == "data":
self._data.append(value)
elif field == "event":
self._event = value
elif field == "retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
return
from __future__ import annotations
from typing import Mapping, Iterable, TypeVar
from ._base_type import NotGiven
def remove_notgiven_indict(obj):
if obj is None or (not isinstance(obj, Mapping)):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
_T = TypeVar("_T")
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
from typing import List, Optional
from pydantic import BaseModel
from .chat_completion import CompletionChoice, CompletionUsage
__all__ = ["AsyncTaskStatus"]
class AsyncTaskStatus(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: Optional[str] = None
class AsyncCompletion(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: str
choices: List[CompletionChoice]
usage: CompletionUsage
\ No newline at end of file
from typing import List, Optional
from pydantic import BaseModel
__all__ = ["Completion", "CompletionUsage"]
class Function(BaseModel):
arguments: str
name: str
class CompletionMessageToolCall(BaseModel):
id: str
function: Function
type: str
class CompletionMessage(BaseModel):
content: Optional[str] = None
role: str
tool_calls: Optional[List[CompletionMessageToolCall]] = None
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CompletionChoice(BaseModel):
index: int
finish_reason: str
message: CompletionMessage
class Completion(BaseModel):
model: Optional[str] = None
created: Optional[int] = None
choices: List[CompletionChoice]
request_id: Optional[str] = None
id: Optional[str] = None
usage: CompletionUsage
from typing import List, Optional
from pydantic import BaseModel
__all__ = [
"ChatCompletionChunk",
"Choice",
"ChoiceDelta",
"ChoiceDeltaFunctionCall",
"ChoiceDeltaToolCall",
"ChoiceDeltaToolCallFunction",
]
class ChoiceDeltaFunctionCall(BaseModel):
arguments: Optional[str] = None
name: Optional[str] = None
class ChoiceDeltaToolCallFunction(BaseModel):
arguments: Optional[str] = None
name: Optional[str] = None
class ChoiceDeltaToolCall(BaseModel):
index: int
id: Optional[str] = None
function: Optional[ChoiceDeltaToolCallFunction] = None
type: Optional[str] = None
class ChoiceDelta(BaseModel):
content: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class Choice(BaseModel):
delta: ChoiceDelta
finish_reason: Optional[str] = None
index: int
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionChunk(BaseModel):
id: Optional[str] = None
choices: List[Choice]
created: Optional[int] = None
model: Optional[str] = None
usage: Optional[CompletionUsage] = None
from typing import Optional
from typing_extensions import TypedDict
class Reference(TypedDict, total=False):
enable: Optional[bool]
search_query: Optional[str]
from __future__ import annotations
from typing import Optional, List
from pydantic import BaseModel
from .chat.chat_completion import CompletionUsage
__all__ = ["Embedding", "EmbeddingsResponded"]
class Embedding(BaseModel):
object: str
index: Optional[int] = None
embedding: List[float]
class EmbeddingsResponded(BaseModel):
object: str
data: List[Embedding]
model: str
usage: CompletionUsage
from typing import Optional, List
from pydantic import BaseModel
__all__ = ["FileObject"]
class FileObject(BaseModel):
id: Optional[str] = None
bytes: Optional[int] = None
created_at: Optional[int] = None
filename: Optional[str] = None
object: Optional[str] = None
purpose: Optional[str] = None
status: Optional[str] = None
status_details: Optional[str] = None
class ListOfFileObject(BaseModel):
object: Optional[str] = None
data: List[FileObject]
has_more: Optional[bool] = None
from __future__ import annotations
from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
from typing import List, Union, Optional
from typing_extensions import Literal
from pydantic import BaseModel
__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ]
class Error(BaseModel):
code: str
message: str
param: Optional[str] = None
class Hyperparameters(BaseModel):
n_epochs: Union[str, int, None] = None
class FineTuningJob(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
created_at: Optional[int] = None
error: Optional[Error] = None
fine_tuned_model: Optional[str] = None
finished_at: Optional[int] = None
hyperparameters: Optional[Hyperparameters] = None
model: Optional[str] = None
object: Optional[str] = None
result_files: List[str]
status: str
trained_tokens: Optional[int] = None
training_file: str
validation_file: Optional[str] = None
class ListOfFineTuningJob(BaseModel):
object: Optional[str] = None
data: List[FineTuningJob]
has_more: Optional[bool] = None
from typing import List, Union, Optional
from typing_extensions import Literal
from pydantic import BaseModel
__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]
class Metric(BaseModel):
epoch: Optional[Union[str, int, float]] = None
current_steps: Optional[int] = None
total_steps: Optional[int] = None
elapsed_time: Optional[str] = None
remaining_time: Optional[str] = None
trained_tokens: Optional[int] = None
loss: Optional[Union[str, int, float]] = None
eval_loss: Optional[Union[str, int, float]] = None
acc: Optional[Union[str, int, float]] = None
eval_acc: Optional[Union[str, int, float]] = None
learning_rate: Optional[Union[str, int, float]] = None
class JobEvent(BaseModel):
object: Optional[str] = None
id: Optional[str] = None
type: Optional[str] = None
created_at: Optional[int] = None
level: Optional[str] = None
message: Optional[str] = None
data: Optional[Metric] = None
class FineTuningJobEvent(BaseModel):
object: Optional[str] = None
data: List[JobEvent]
has_more: Optional[bool] = None
from __future__ import annotations
from typing import Union
from typing_extensions import Literal, TypedDict
__all__ = ["Hyperparameters"]
class Hyperparameters(TypedDict, total=False):
batch_size: Union[Literal["auto"], int]
learning_rate_multiplier: Union[Literal["auto"], float]
n_epochs: Union[Literal["auto"], int]
from __future__ import annotations
from typing import Optional, List
from pydantic import BaseModel
__all__ = ["GeneratedImage", "ImagesResponded"]
class GeneratedImage(BaseModel):
b64_json: Optional[str] = None
url: Optional[str] = None
revised_prompt: Optional[str] = None
class ImagesResponded(BaseModel):
created: int
data: List[GeneratedImage]
...@@ -26,4 +26,4 @@ class BuiltinToolProviderSort: ...@@ -26,4 +26,4 @@ class BuiltinToolProviderSort:
sorted_providers = sorted(providers, key=sort_compare) sorted_providers = sorted(providers, key=sort_compare)
return sorted_providers return sorted_providers
\ No newline at end of file
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict
class AzureDALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
DallE3Tool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "square",
"n": 1
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: Leslie
name: azuredalle
label:
en_US: AZURE DALL-E
zh_Hans: AZURE DALL-E 绘画
pt_BR: AZURE DALL-E
description:
en_US: AZURE DALL-E art
zh_Hans: AZURE DALL-E 绘画
pt_BR: AZURE DALL-E art
icon: icon.png
credentials_for_provider:
azure_openai_api_key:
type: secret-input
required: true
label:
en_US: API key
zh_Hans: 密钥
pt_BR: API key
help:
en_US: Please input your Azure OpenAI API key
zh_Hans: 请输入你的 Azure OpenAI API key
pt_BR: Please input your Azure OpenAI API key
placeholder:
en_US: Please input your Azure OpenAI API key
zh_Hans: 请输入你的 Azure OpenAI API key
pt_BR: Please input your Azure OpenAI API key
azure_openai_api_model_name:
type: text-input
required: true
label:
en_US: Deployment Name
zh_Hans: 部署名称
pt_BR: Deployment Name
help:
en_US: Please input the name of your Azure Openai DALL-E API deployment
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
pt_BR: Please input the name of your Azure Openai DALL-E API deployment
placeholder:
en_US: Please input the name of your Azure Openai DALL-E API deployment
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
pt_BR: Please input the name of your Azure Openai DALL-E API deployment
azure_openai_base_url:
type: text-input
required: true
label:
en_US: API Endpoint URL
zh_Hans: API 域名
pt_BR: API Endpoint URL
help:
en_US: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
pt_BR: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/
placeholder:
en_US: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
pt_BR: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/
azure_openai_api_version:
type: text-input
required: true
label:
en_US: API Version
zh_Hans: API 版本
pt_BR: API Version
help:
en_US: Please input your Azure OpenAI API Version,eg:2023-12-01-preview
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
pt_BR: Please input your Azure OpenAI API Version,eg:2023-12-01-preview
placeholder:
en_US: Please input your Azure OpenAI API Version,eg:2023-12-01-preview
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
pt_BR: Please input your Azure OpenAI API Version,eg:2023-12-01-preview
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from os.path import join
from openai import AzureOpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
client = AzureOpenAI(
api_version=self.runtime.credentials['azure_openai_api_version'],
azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
api_key=self.runtime.credentials['azure_openai_api_key'],
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
}
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
# get n
n = tool_paramters.get('n', 1)
# get quality
quality = tool_paramters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
# get style
style = tool_paramters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
# call openapi dalle3
model=self.runtime.credentials['azure_openai_api_model_name']
response = client.images.generate(
prompt=prompt,
model=model,
size=size,
n=n,
style=style,
quality=quality,
response_format='b64_json'
)
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result
identity:
name: dalle3
author: Leslie
label:
en_US: DALL-E 3
zh_Hans: DALL-E 3 绘画
pt_BR: DALL-E 3
description:
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
pt_BR: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
description:
human:
en_US: DALL-E is a text to image tool
zh_Hans: DALL-E 是一个文本到图像的工具
pt_BR: DALL-E is a text to image tool
llm: DALL-E is a tool used to generate images from text
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
pt_BR: Prompt
human_description:
en_US: Image prompt, you can check the official documentation of DallE 3
zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档
pt_BR: Image prompt, you can check the official documentation of DallE 3
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
form: llm
- name: size
type: select
required: true
human_description:
en_US: selecting the image size
zh_Hans: 选择图像大小
pt_BR: selecting the image size
label:
en_US: Image size
zh_Hans: 图像大小
pt_BR: Image size
form: form
options:
- value: square
label:
en_US: Squre(1024x1024)
zh_Hans: 方(1024x1024)
pt_BR: Squre(1024x1024)
- value: vertical
label:
en_US: Vertical(1024x1792)
zh_Hans: 竖屏(1024x1792)
pt_BR: Vertical(1024x1792)
- value: horizontal
label:
en_US: Horizontal(1792x1024)
zh_Hans: 横屏(1792x1024)
pt_BR: Horizontal(1792x1024)
default: square
- name: n
type: number
required: true
human_description:
en_US: selecting the number of images
zh_Hans: 选择图像数量
pt_BR: selecting the number of images
label:
en_US: Number of images
zh_Hans: 图像数量
pt_BR: Number of images
form: form
min: 1
max: 1
default: 1
- name: quality
type: select
required: true
human_description:
en_US: selecting the image quality
zh_Hans: 选择图像质量
pt_BR: selecting the image quality
label:
en_US: Image quality
zh_Hans: 图像质量
pt_BR: Image quality
form: form
options:
- value: standard
label:
en_US: Standard
zh_Hans: 标准
pt_BR: Standard
- value: hd
label:
en_US: HD
zh_Hans: 高清
pt_BR: HD
default: standard
- name: style
type: select
required: true
human_description:
en_US: selecting the image style
zh_Hans: 选择图像风格
pt_BR: selecting the image style
label:
en_US: Image style
zh_Hans: 图像风格
pt_BR: Image style
form: form
options:
- value: vivid
label:
en_US: Vivid
zh_Hans: 生动
pt_BR: Vivid
- value: natural
label:
en_US: Natural
zh_Hans: 自然
pt_BR: Natural
default: vivid
import requests
import urllib.parse
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
class GaodeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolProviderCredentialValidationError("Gaode API key is required.")
try:
response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}"
"".format(address=urllib.parse.quote('广东省广州市天河区广州塔'),
apikey=credentials.get('api_key')))
if response.status_code == 200 and (response.json()).get('info') == 'OK':
pass
else:
raise ToolProviderCredentialValidationError((response.json()).get('info'))
except Exception as e:
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: CharlirWei
name: gaode
label:
en_US: GaoDe
zh_Hans: 高德
pt_BR: GaoDe
description:
en_US: Autonavi Open Platform service toolkit.
zh_Hans: 高德开放平台服务工具包。
pt_BR: Kit de ferramentas de serviço Autonavi Open Platform.
icon: icon.png
credentials_for_provider:
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API Key
pt_BR: Fogo a chave
placeholder:
en_US: Please enter your GaoDe API Key
zh_Hans: 请输入你的高德开放平台 API Key
pt_BR: Insira sua chave de API GaoDe
help:
en_US: Get your API Key from GaoDe
zh_Hans: 从高德获取您的 API Key
pt_BR: Obtenha sua chave de API do GaoDe
url: https://console.amap.com/dev/key/app
import json
import requests
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GaodeRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
city = tool_paramters.get('city', '')
if not city:
return self.create_text_message('Please tell me your city')
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
return self.create_text_message("Gaode API key is required.")
try:
s = requests.session()
api_domain = 'https://restapi.amap.com/v3'
city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"},
url="{url}/config/district?keywords={keywords}"
"&subdistrict=0&extensions=base&key={apikey}"
"".format(url=api_domain, keywords=city,
apikey=self.runtime.credentials.get('api_key')))
City_data = city_response.json()
if city_response.status_code == 200 and City_data.get('info') == 'OK':
if len(City_data.get('districts')) > 0:
CityCode = City_data['districts'][0]['adcode']
weatherInfo_response = s.request(method='GET',
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
"".format(url=api_domain, citycode=CityCode,
apikey=self.runtime.credentials.get('api_key')))
weatherInfo_data = weatherInfo_response.json()
if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
contents = list()
if len(weatherInfo_data.get('forecasts')) > 0:
for item in weatherInfo_data['forecasts'][0]['casts']:
content = dict()
content['date'] = item.get('date')
content['week'] = item.get('week')
content['dayweather'] = item.get('dayweather')
content['daytemp_float'] = item.get('daytemp_float')
content['daywind'] = item.get('daywind')
content['nightweather'] = item.get('nightweather')
content['nighttemp_float'] = item.get('nighttemp_float')
contents.append(content)
s.close()
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
s.close()
return self.create_text_message(f'No weather information for {city} was found.')
except Exception as e:
return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))
identity:
name: gaode_weather
author: CharlieWei
label:
en_US: Weather Forecast
zh_Hans: 天气预报
pt_BR: Previsão do tempo
icon: icon.svg
description:
human:
en_US: Weather forecast inquiry
zh_Hans: 天气预报查询。
pt_BR: Inquérito sobre previsão meteorológica.
llm: A tool when you want to ask about the weather or weather-related question.
parameters:
- name: city
type: string
required: true
label:
en_US: city
zh_Hans: 城市
pt_BR: cidade
human_description:
en_US: Target city for weather forecast query.
zh_Hans: 天气预报查询的目标城市。
pt_BR: Cidade de destino para consulta de previsão do tempo.
llm_description: If you don't know you can extract the city name from the question or you can reply:Please tell me your city. You have to extract the Chinese city name from the question.
form: llm
import requests
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
class GihubProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
raise ToolProviderCredentialValidationError("Github API Access Tokens is required.")
if 'api_version' not in credentials or not credentials.get('api_version'):
api_version = '2022-11-28'
else:
api_version = credentials.get('api_version')
try:
headers = {
"Content-Type": "application/vnd.github+json",
"Authorization": f"Bearer {credentials.get('access_tokens')}",
"X-GitHub-Api-Version": api_version
}
response = requests.get(
url="https://api.github.com/search/users?q={account}".format(account='charli117'),
headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e:
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: CharlirWei
name: github
label:
en_US: Github
zh_Hans: Github
pt_BR: Github
description:
en_US: GitHub is an online software source code hosting service.
zh_Hans: GitHub是一个在线软件源代码托管服务平台。
pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software.
icon: icon.png
credentials_for_provider:
access_tokens:
type: secret-input
required: true
label:
en_US: Access Tokens
zh_Hans: Access Tokens
pt_BR: Tokens de acesso
placeholder:
en_US: Please input your Github Access Tokens
zh_Hans: 请输入你的 Github Access Tokens
pt_BR: Insira seus Tokens de Acesso do Github
help:
en_US: Get your Access Tokens from Github
zh_Hans: 从 Github 获取您的 Access Tokens
pt_BR: Obtenha sua chave da API do Google no Google
url: https://github.com/settings/tokens?type=beta
api_version:
type: text-input
required: false
default: '2022-11-28'
label:
en_US: API Version
zh_Hans: API Version
pt_BR: Versão da API
placeholder:
en_US: Please input your Github API Version
zh_Hans: 请输入你的 Github API Version
pt_BR: Insira sua versão da API do Github
help:
en_US: Get your API Version from Github
zh_Hans: 从 Github 获取您的 API Version
pt_BR: Obtenha sua versão da API do Github
url: https://docs.github.com/en/rest/about-the-rest-api/api-versions?apiVersion=2022-11-28
import json
import requests
from datetime import datetime
from urllib.parse import quote
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GihubRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
top_n = tool_paramters.get('top_n', 5)
query = tool_paramters.get('query', '')
if not query:
return self.create_text_message('Please input symbol')
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
return self.create_text_message("Github API Access Tokens is required.")
if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'):
api_version = '2022-11-28'
else:
api_version = self.runtime.credentials.get('api_version')
try:
headers = {
"Content-Type": "application/vnd.github+json",
"Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}",
"X-GitHub-Api-Version": api_version
}
s = requests.session()
api_domain = 'https://api.github.com'
response = s.request(method='GET', headers=headers,
url=f"{api_domain}/search/repositories?"
f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
response_data = response.json()
if response.status_code == 200 and isinstance(response_data.get('items'), list):
contents = list()
if len(response_data.get('items')) > 0:
for item in response_data.get('items'):
content = dict()
updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
content['owner'] = item['owner']['login']
content['name'] = item['name']
content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description']
content['url'] = item['html_url']
content['star'] = item['watchers']
content['forks'] = item['forks']
content['updated'] = updated_at_object.strftime("%Y-%m-%d")
contents.append(content)
s.close()
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
else:
return self.create_text_message(f'No items related to {query} were found.')
else:
return self.create_text_message((response.json()).get('message'))
except Exception as e:
return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))
identity:
name: repositories
author: CharlieWei
label:
en_US: Search Repositories
zh_Hans: 仓库搜索
pt_BR: Pesquisar Repositórios
icon: icon.svg
description:
human:
en_US: Search the Github repository to retrieve the open source projects you need
zh_Hans: 搜索Github仓库,检索你需要的开源项目。
pt_BR: Pesquise o repositório do Github para recuperar os projetos de código aberto necessários.
llm: A tool when you wants to search for popular warehouses or open source projects for any keyword. format query condition like "keywords+language:js", language can be other dev languages.
parameters:
- name: query
type: string
required: true
label:
en_US: query
zh_Hans: 关键字
pt_BR: consulta
human_description:
en_US: You want to find the project development language, keywords, For example. Find 10 Python developed PDF document parsing projects.
zh_Hans: 你想要找的项目开发语言、关键字,如:找10个Python开发的PDF文档解析项目。
pt_BR: Você deseja encontrar a linguagem de desenvolvimento do projeto, palavras-chave, Por exemplo. Encontre 10 projetos de análise de documentos PDF desenvolvidos em Python.
llm_description: The query of you want to search, format query condition like "keywords+language:js", language can be other dev languages, por exemplo. Procuro um projeto de análise de documentos PDF desenvolvido em Python.
form: llm
- name: top_n
type: number
default: 5
required: true
label:
en_US: Top N
zh_Hans: Top N
pt_BR: Topo N
human_description:
en_US: Number of records returned by sorting based on stars. 5 is returned by default.
zh_Hans: 基于stars排序返回的记录数, 默认返回5条。
pt_BR: Número de registros retornados por classificação com base em estrelas. 5 é retornado por padrão.
llm_description: Extract the first N records from the returned result.
form: llm
...@@ -40,7 +40,7 @@ class BuiltinToolProviderController(ToolProviderController): ...@@ -40,7 +40,7 @@ class BuiltinToolProviderController(ToolProviderController):
'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None, 'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None,
}) })
def _get_bulitin_tools(self) -> List[Tool]: def _get_builtin_tools(self) -> List[Tool]:
""" """
returns a list of tools that the provider can provide returns a list of tools that the provider can provide
...@@ -101,7 +101,7 @@ class BuiltinToolProviderController(ToolProviderController): ...@@ -101,7 +101,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: list of tools :return: list of tools
""" """
return self._get_bulitin_tools() return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> Tool: def get_tool(self, tool_name: str) -> Tool:
""" """
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment