Unverified Commit 91ee62d1 authored by Garfield Dai's avatar Garfield Dai Committed by GitHub

fix: huggingface and replicate. (#1888)

parent ede69b46
......@@ -154,6 +154,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
content=chunk.token.text
)
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
......@@ -166,6 +167,16 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
index=index,
message=assistant_prompt_message,
usage=usage,
finish_reason=chunk.details.finish_reason,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
......
......@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
)
for key, value in input_properties:
if key not in ['system_prompt', 'prompt']:
if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
value_type = value.get('type')
if not value_type:
......@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
index = -1
current_completion: str = ""
stop_condition_reached = False
prediction_output_length = 10000
is_prediction_output_finished = False
for output in prediction.output_iterator():
current_completion += output
if not is_prediction_output_finished and prediction.status == 'succeeded':
prediction_output_length = len(prediction.output) - 1
is_prediction_output_finished = True
if stop:
for s in stop:
if s in current_completion:
......@@ -172,6 +180,16 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
content=output if output else ''
)
if index < prediction_output_length:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
......@@ -183,8 +201,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
),
usage=usage
)
)
def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment