Unverified Commit 14a2eeba authored by Chenhe Gu's avatar Chenhe Gu Committed by GitHub

Add bedrock (#2119)

Co-authored-by: takatost <takatost@users.noreply.github.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Charlie.Wei <luowei@cvte.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Benjamin <benjaminx@gmail.com>
parent a18dde9b
<!-- AWS Bedrock provider icon: 16x16 rounded square with a green linear
     gradient background and a white circuit-style glyph. -->
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_16762_59518)">
<path d="M12.6667 0H3.33333C1.49238 0 0 1.49238 0 3.33333V12.6667C0 14.5076 1.49238 16 3.33333 16H12.6667C14.5076 16 16 14.5076 16 12.6667V3.33333C16 1.49238 14.5076 0 12.6667 0Z" fill="url(#paint0_linear_16762_59518)"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.99984 12.093L6.3825 12.6323L5.75184 12.2116L6.4385 11.9823L6.22784 11.3503L5.04917 11.743L4.6665 11.4883V9.66631C4.6665 9.54031 4.59517 9.42497 4.4825 9.3683L3.33317 8.79364V7.20564L4.33317 6.70564L5.33317 7.20564V8.33297C5.33317 8.45964 5.4045 8.57497 5.51717 8.63164L6.8505 9.29831L7.14917 8.70164L5.99984 8.12697V7.20564L7.14917 6.63164C7.26184 6.57497 7.33317 6.45964 7.33317 6.33297V5.33297H6.6665V6.12697L5.6665 6.62697L4.6665 6.12697V4.51164L5.33317 4.06697V5.33297H5.99984V3.62297L6.3825 3.36764L7.99984 3.90697V12.093ZM11.6665 11.333C11.8498 11.333 11.9998 11.4823 11.9998 11.6663C11.9998 11.8503 11.8498 11.9996 11.6665 11.9996C11.4832 11.9996 11.3332 11.8503 11.3332 11.6663C11.3332 11.4823 11.4832 11.333 11.6665 11.333ZM10.9998 3.99964C11.1832 3.99964 11.3332 4.14897 11.3332 4.33297C11.3332 4.51697 11.1832 4.6663 10.9998 4.6663C10.8165 4.6663 10.6665 4.51697 10.6665 4.33297C10.6665 4.14897 10.8165 3.99964 10.9998 3.99964ZM12.3332 7.99964C12.5165 7.99964 12.6665 8.14897 12.6665 8.33297C12.6665 8.51697 12.5165 8.66631 12.3332 8.66631C12.1498 8.66631 11.9998 8.51697 11.9998 8.33297C11.9998 8.14897 12.1498 7.99964 12.3332 7.99964ZM11.3945 8.66631C11.5325 9.05364 11.8992 9.33297 12.3332 9.33297C12.8845 9.33297 13.3332 8.88497 13.3332 8.33297C13.3332 7.78164 12.8845 7.33297 12.3332 7.33297C11.8992 7.33297 11.5325 7.61297 11.3945 7.99964H8.6665V6.66631H10.9998C11.1838 6.66631 11.3332 6.51764 11.3332 6.33297V5.27164C11.7205 5.13364 11.9998 4.76697 11.9998 4.33297C11.9998 3.78164 11.5512 3.33297 10.9998 3.33297C10.4485 3.33297 9.99984 3.78164 9.99984 4.33297C9.99984 4.76697 10.2792 5.13364 10.6665 5.27164V5.99964H8.6665V3.6663C8.6665 3.52297 8.5745 3.39564 8.4385 3.3503L6.4385 2.68364C6.3405 2.65097 6.23384 2.66564 6.1485 2.7223L4.1485 4.05564C4.05584 4.11764 3.99984 4.22164 3.99984 4.33297V6.12697L2.8505 6.70164C2.73784 6.75831 2.6665 6.87364 2.6665 6.99964V8.99964C2.6665 9.12631 2.73784 9.24164 2.8505 9.29831L3.99984 
9.87231V11.6663C3.99984 11.7776 4.05584 11.8823 4.1485 11.9436L6.1485 13.277C6.20384 13.3143 6.26784 13.333 6.33317 13.333C6.3685 13.333 6.40384 13.3276 6.4385 13.3156L8.4385 12.649C8.5745 12.6043 8.6665 12.477 8.6665 12.333V10.6663H10.1952L10.7638 11.2356L10.7725 11.227C10.7072 11.3603 10.6665 11.5083 10.6665 11.6663C10.6665 12.2176 11.1152 12.6663 11.6665 12.6663C12.2178 12.6663 12.6665 12.2176 12.6665 11.6663C12.6665 11.115 12.2178 10.6663 11.6665 10.6663C11.5078 10.6663 11.3598 10.707 11.2272 10.773L11.2358 10.7643L10.5692 10.0976C10.5065 10.035 10.4218 9.99964 10.3332 9.99964H8.6665V8.66631H11.3945Z" fill="white"/>
</g>
<defs>
<linearGradient id="paint0_linear_16762_59518" x1="0" y1="1600" x2="1600" y2="0" gradientUnits="userSpaceOnUse">
<stop stop-color="#055F4E"/>
<stop offset="1" stop-color="#56C0A7"/>
</linearGradient>
<clipPath id="clip0_16762_59518">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class BedrockProvider(ModelProvider):
    """Model provider for AWS Bedrock."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        If validation fails, an exception is raised.

        :param credentials: provider credentials, credentials form defined in
            `provider_credential_schema`.
        :raises CredentialsValidateFailedError: if the credentials are rejected.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Validate against a small, inexpensive Bedrock model.
            # (The original comment mentioned `gemini-pro` — a copy-paste from
            # the Google provider; the model actually used is Titan Text Lite.)
            model_instance.validate_credentials(
                model='amazon.titan-text-lite-v1',
                credentials=credentials
            )
        except CredentialsValidateFailedError:
            # Propagate validation failures unchanged; bare `raise` keeps the
            # original traceback intact.
            raise
        except Exception as ex:
            # Lazy %-formatting: the message is only rendered if the record is
            # emitted; logger.exception also records the traceback.
            logger.exception('%s credentials validate failed',
                             self.get_provider_schema().provider)
            raise ex
# Provider manifest for AWS Bedrock: UI labels, icons, help link, supported
# model types, and the credential form presented to the user.
provider: bedrock
label:
  en_US: AWS
description:
  en_US: AWS Bedrock's models.
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FCFDFF"
help:
  title:
    en_US: Get your Access Key and Secret Access Key from AWS Console
  url:
    en_US: https://console.aws.amazon.com/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    # IAM access key pair plus the region whose Bedrock endpoint is called.
    - variable: aws_access_key_id
      required: true
      label:
        en_US: Access Key
        zh_Hans: Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Access Key
        zh_Hans: 在此输入您的 Access Key
    - variable: aws_secret_access_key
      required: true
      label:
        en_US: Secret Access Key
        zh_Hans: Secret Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Secret Access Key
        zh_Hans: 在此输入您的 Secret Access Key
    - variable: aws_region
      required: true
      label:
        en_US: AWS Region
        zh_Hans: AWS 地区
      type: select
      default: us-east-1
      options:
        - value: us-east-1
          label:
            en_US: US East (N. Virginia)
            zh_Hans: US East (N. Virginia)
        - value: us-west-2
          label:
            en_US: US West (Oregon)
            zh_Hans: US West (Oregon)
        - value: ap-southeast-1
          label:
            en_US: Asia Pacific (Singapore)
            zh_Hans: Asia Pacific (Singapore)
        - value: ap-northeast-1
          label:
            en_US: Asia Pacific (Tokyo)
            zh_Hans: Asia Pacific (Tokyo)
        - value: eu-central-1
          label:
            en_US: Europe (Frankfurt)
            zh_Hans: Europe (Frankfurt)
        - value: us-gov-west-1
          label:
            en_US: AWS GovCloud (US-West)
            zh_Hans: AWS GovCloud (US-West)
# Display order of the predefined Bedrock models in the UI.
- amazon.titan-text-express-v1
- amazon.titan-text-lite-v1
- anthropic.claude-instant-v1
- anthropic.claude-v1
- anthropic.claude-v2
- anthropic.claude-v2:1
- cohere.command-light-text-v14
- cohere.command-text-v14
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
# AI21 Jurassic-2 Mid on Bedrock (completion mode, 8191-token context).
model: ai21.j2-mid-v1
label:
  en_US: J2 Mid V1
model_type: llm
model_properties:
  mode: completion
  context_size: 8191
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: maxTokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
  - name: count_penalty
    label:
      en_US: Count Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 1
  - name: presence_penalty
    label:
      en_US: Presence Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 5
  - name: frequency_penalty
    label:
      en_US: Frequency Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 500
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
# AI21 Jurassic-2 Ultra on Bedrock (completion mode, 8191-token context).
model: ai21.j2-ultra-v1
label:
  en_US: J2 Ultra V1
model_type: llm
model_properties:
  mode: completion
  context_size: 8191
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: maxTokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
  - name: count_penalty
    label:
      en_US: Count Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 1
  - name: presence_penalty
    label:
      en_US: Presence Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 5
  - name: frequency_penalty
    label:
      en_US: Frequency Penalty
    required: false
    type: float
    default: 0
    min: 0
    max: 500
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
# Amazon Titan Text G1 - Express (chat mode, 8192-token context).
model: amazon.titan-text-express-v1
label:
  en_US: Titan Text G1 - Express
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: maxTokenCount
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 8000
pricing:
  input: '0.0008'
  output: '0.0016'
  unit: '0.001'
  currency: USD
# Amazon Titan Text G1 - Lite (chat mode, 4096-token context).
model: amazon.titan-text-lite-v1
label:
  en_US: Titan Text G1 - Lite
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: maxTokenCount
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
pricing:
  input: '0.0003'
  output: '0.0004'
  unit: '0.001'
  currency: USD
# Anthropic Claude Instant v1 (chat mode, 100k-token context).
model: anthropic.claude-instant-v1
label:
  en_US: Claude Instant V1
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top K
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 250
    min: 0
    max: 500
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.0008'
  output: '0.0024'
  unit: '0.001'
  currency: USD
# Anthropic Claude v1 (chat mode, 100k-token context).
model: anthropic.claude-v1
label:
  en_US: Claude V1
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top K
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 250
    min: 0
    max: 500
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
# Anthropic Claude v2 (chat mode, 100k-token context).
model: anthropic.claude-v2
label:
  en_US: Claude V2
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top K
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 250
    min: 0
    max: 500
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
# Anthropic Claude v2.1 (chat mode, 200k-token context).
model: anthropic.claude-v2:1
label:
  en_US: Claude V2.1
model_type: llm
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: topP
    use_template: top_p
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top K
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 250
    min: 0
    max: 500
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
# Cohere Command Light on Bedrock (completion mode, 4096-token context).
model: cohere.command-light-text-v14
label:
  en_US: Command Light Text V14
model_type: llm
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  # Bedrock's Cohere Command API names nucleus/top-k sampling `p` and `k`.
  - name: p
    use_template: top_p
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    min: 0
    max: 500
    default: 0
  # Fixed: Cohere's max-token parameter is `max_tokens`; the original used the
  # Anthropic-style `max_tokens_to_sample`, which Bedrock rejects for Cohere.
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.0003'
  output: '0.0006'
  unit: '0.001'
  currency: USD
# Cohere Command on Bedrock (completion mode, 4096-token context).
model: cohere.command-text-v14
label:
  en_US: Command Text V14
model_type: llm
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  # Fixed for consistency with cohere.command-light-text-v14 and Bedrock's
  # Cohere Command API: parameters are `p`, `k` and `max_tokens`, not the
  # `top_p`/`top_k`/`max_tokens_to_sample` names the original used.
  - name: p
    use_template: top_p
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    min: 0
    max: 500
    default: 0
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
pricing:
  input: '0.0015'
  output: '0.0020'
  unit: '0.001'
  currency: USD
This diff is collapsed.
# Meta Llama 2 Chat 13B on Bedrock (chat mode, 4096-token context).
model: meta.llama2-13b-chat-v1
label:
  en_US: Llama 2 Chat 13B
model_type: llm
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_gen_len
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
pricing:
  input: '0.00075'
  output: '0.00100'
  unit: '0.001'
  currency: USD
# Meta Llama 2 Chat 70B on Bedrock (chat mode, 4096-token context).
model: meta.llama2-70b-chat-v1
label:
  en_US: Llama 2 Chat 70B
model_type: llm
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_gen_len
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 2048
pricing:
  input: '0.00195'
  output: '0.00256'
  unit: '0.001'
  currency: USD
import os
from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
def test_validate_credentials():
    """Bogus AWS credentials must be rejected; real env credentials must pass."""
    model = BedrockLargeLanguageModel()

    # Invalid credentials must raise.
    # Fixed: the original passed {'anthropic_api_key': 'invalid_key'} here — a
    # copy-paste from the Anthropic provider tests; Bedrock reads aws_*
    # credentials, so exercise the rejection path with bogus AWS values.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='meta.llama2-13b-chat-v1',
            credentials={
                "aws_region": "us-east-1",
                "aws_access_key": "invalid_key",
                "aws_secret_access_key": "invalid_secret"
            }
        )

    # Valid credentials sourced from the environment must validate cleanly.
    # NOTE(review): the key names here (aws_access_key) differ from the provider
    # schema's aws_access_key_id — confirm against what llm.py actually reads.
    model.validate_credentials(
        model='meta.llama2-13b-chat-v1',
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        }
    )
def test_invoke_model():
    """Blocking invocation returns a complete LLMResult with non-empty content."""
    llm = BedrockLargeLanguageModel()

    # Credentials come from the environment of the machine running the test.
    aws_credentials = {
        "aws_region": os.getenv("AWS_REGION"),
        "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
    }
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!')
    ]

    result = llm.invoke(
        model='meta.llama2-13b-chat-v1',
        credentials=aws_credentials,
        prompt_messages=conversation,
        model_parameters={
            'temperature': 0.0,
            'top_p': 1.0,
            'max_tokens_to_sample': 10
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    # Non-streaming: one fully materialized result with some generated text.
    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """Streaming invocation yields well-formed LLMResultChunk objects."""
    llm = BedrockLargeLanguageModel()

    stream = llm.invoke(
        model='meta.llama2-13b-chat-v1',
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        },
        prompt_messages=[
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content='Hello World!')
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens_to_sample': 100
        },
        stream=True,
        user="abc-123"
    )

    assert isinstance(stream, Generator)

    for chunk in stream:
        print(chunk)
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Every chunk before the final one must carry some content; the final
        # chunk (finish_reason set) is allowed to be empty.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
def test_get_num_tokens():
    """Token counting for a simple two-message conversation is deterministic."""
    llm = BedrockLargeLanguageModel()

    token_count = llm.get_num_tokens(
        model='meta.llama2-13b-chat-v1',
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        },
        messages=[
            SystemPromptMessage(content='You are a helpful AI assistant.'),
            UserPromptMessage(content='Hello World!')
        ]
    )

    # Expected count for this exact prompt pair with the llama2 tokenizer.
    assert token_count == 18
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
def test_validate_provider_credentials():
    """Empty credentials must be rejected; real env credentials must validate."""
    provider = BedrockProvider()

    # No credentials at all -> validation failure.
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    # Credentials from the environment -> must not raise.
    provider.validate_provider_credentials(
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        }
    )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment