Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
E
elleai
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
E
ellehuis-group
backend
elleai
Commits
32b01ebc
Commit
32b01ebc
authored
Oct 22, 2024
by
yangyw
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feature: 添加langchain4j支持
parent
cb8d2a58
Changes
21
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
21 changed files
with
1051 additions
and
1 deletion
+1051
-1
.gitignore
.gitignore
+13
-1
pom.xml
pom.xml
+28
-0
EmbeddingItem.java
...reeze/elleai/application/dto/langchain/EmbeddingItem.java
+23
-0
EmbeddingRequest.java
...ze/elleai/application/dto/langchain/EmbeddingRequest.java
+20
-0
EmbeddingResponse.java
...e/elleai/application/dto/langchain/EmbeddingResponse.java
+19
-0
RagSearchRequest.java
...ze/elleai/application/dto/langchain/RagSearchRequest.java
+26
-0
RerankItem.java
...n/breeze/elleai/application/dto/langchain/RerankItem.java
+20
-0
RerankRequest.java
...reeze/elleai/application/dto/langchain/RerankRequest.java
+28
-0
RerankResponse.java
...eeze/elleai/application/dto/langchain/RerankResponse.java
+21
-0
Target.java
...va/cn/breeze/elleai/application/dto/langchain/Target.java
+23
-0
Usage.java
...ava/cn/breeze/elleai/application/dto/langchain/Usage.java
+23
-0
VectorSearchRequest.java
...elleai/application/dto/langchain/VectorSearchRequest.java
+29
-0
VectorSegment.java
...reeze/elleai/application/dto/langchain/VectorSegment.java
+30
-0
AIService.java
.../java/cn/breeze/elleai/application/service/AIService.java
+138
-0
EmbeddingService.java
src/main/java/cn/breeze/elleai/facade/EmbeddingService.java
+34
-0
MilvusVectorStoreFacade.java
...java/cn/breeze/elleai/facade/MilvusVectorStoreFacade.java
+55
-0
OpenAIEmbeddingFacade.java
...n/java/cn/breeze/elleai/facade/OpenAIEmbeddingFacade.java
+62
-0
RerankFacade.java
src/main/java/cn/breeze/elleai/facade/RerankFacade.java
+92
-0
VectorStoreService.java
...main/java/cn/breeze/elleai/facade/VectorStoreService.java
+195
-0
NoBillCohereScoringModel.java
...ev/langchain4j/model/cohere/NoBillCohereScoringModel.java
+86
-0
AIServiceTestUnit.java
src/test/java/cn/breeze/elleai/test/AIServiceTestUnit.java
+86
-0
No files found.
.gitignore
View file @
32b01ebc
/target/
HELP.md
target/
!.mvn/wrapper/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
### IntelliJ IDEA ###
.idea
*.iws
*.iml
*.ipr
*.mvn
\ No newline at end of file
pom.xml
View file @
32b01ebc
...
...
@@ -28,6 +28,7 @@
</scm>
<properties>
<java.version>
17
</java.version>
<langchain4j.version>
0.35.0
</langchain4j.version>
</properties>
<dependencies>
<dependency>
...
...
@@ -149,6 +150,33 @@
<version>
2.5.1
</version>
</dependency>
<dependency>
<groupId>
dev.langchain4j
</groupId>
<artifactId>
langchain4j-open-ai
</artifactId>
<version>
${langchain4j.version}
</version>
</dependency>
<dependency>
<groupId>
dev.langchain4j
</groupId>
<artifactId>
langchain4j
</artifactId>
<version>
${langchain4j.version}
</version>
</dependency>
<dependency>
<groupId>
dev.langchain4j
</groupId>
<artifactId>
langchain4j-core
</artifactId>
<version>
${langchain4j.version}
</version>
</dependency>
<dependency>
<groupId>
dev.langchain4j
</groupId>
<artifactId>
langchain4j-cohere
</artifactId>
<version>
${langchain4j.version}
</version>
</dependency>
<dependency>
<groupId>
dev.langchain4j
</groupId>
<artifactId>
langchain4j-milvus
</artifactId>
<version>
${langchain4j.version}
</version>
</dependency>
</dependencies>
<build>
...
...
src/main/java/cn/breeze/elleai/application/dto/langchain/EmbeddingItem.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
/**
 * A single embedding result: an index, the embedding vector and the original content.
 *
 * @author yangyw
 */
@Schema(description = "向量化结果")
@Data
public class EmbeddingItem implements Serializable {

    /** Index of this item. */
    @Schema(description = "索引")
    private Integer index;

    /** The embedding vector. */
    @Schema(description = "向量")
    private float[] embedding;

    /** The original content the vector was produced from. */
    @Schema(description = "原始内容")
    private String content;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/EmbeddingRequest.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
lombok.EqualsAndHashCode
;
import
java.util.List
;
/**
 * Embedding request: the texts to embed, plus the target-platform settings
 * inherited from {@link Target}.
 *
 * @author yangyw
 */
@Schema(description = "Embedding请求")
@Data
@EqualsAndHashCode(callSuper = true)
public class EmbeddingRequest extends Target {

    /** List of texts to embed. */
    private List<String> texts;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/EmbeddingResponse.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
import
java.util.List
;
/**
 * Embedding response: token usage plus the list of embedding items.
 *
 * @author yangyw
 */
@Schema(description = "Embedding响应")
@Data
public class EmbeddingResponse implements Serializable {

    /** Token usage reported for the request. */
    private Usage usage;

    /** The embedding items. */
    private List<EmbeddingItem> data;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/RagSearchRequest.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
lombok.EqualsAndHashCode
;
import
lombok.NoArgsConstructor
;
/**
 * RAG search request: a {@link VectorSearchRequest} extended with rerank options.
 *
 * @author yangyw
 */
@Schema(description = "rag搜索请求")
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagSearchRequest extends VectorSearchRequest {

    /** Whether rerank is enabled. */
    @Schema(description = "是否启用rerank")
    private Boolean enableRerank;

    /** Number of results to keep after rerank. */
    @Schema(description = "rerank结果数量")
    private Integer topKRerank;

    /** Minimum similarity a result must reach after rerank. */
    @Schema(description = "rerank最小相似度")
    private Double minScoreRerank;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/RerankItem.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
/**
 * One rerank result: the index of the document, its content and its relevance score.
 *
 * @author yangyw
 */
@Schema(description = "rerank结果")
@Data
public class RerankItem implements Serializable {

    /** Index of the document this result refers to. */
    private Integer index;

    /** Content of the document. */
    private String content;

    /** Relevance score assigned to the document. */
    private Double relevanceScore;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/RerankRequest.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
lombok.EqualsAndHashCode
;
import
java.util.List
;
/**
 * Rerank request: the query, the candidate documents and the result limits, plus the
 * target-platform settings inherited from {@link Target}.
 *
 * @author yangyw
 */
@Schema(description = "rerank请求")
@Data
@EqualsAndHashCode(callSuper = true)
public class RerankRequest extends Target {

    /** Number of results to return. */
    @Schema(description = "返回数量")
    private Integer topN;

    /** Score threshold. */
    @Schema(description = "阈值")
    private Double scoreThreshold;

    /** The query content. */
    @Schema(description = "查询内容")
    private String query;

    /** The candidate documents (as text) to rerank. */
    @Schema(description = "待rerank的向量")
    private List<String> documents;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/RerankResponse.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
import
java.util.List
;
/**
 * Rerank response: token usage plus the rerank results.
 *
 * @author yangyw
 */
@Schema(description = "rerank响应")
@Data
public class RerankResponse implements Serializable {

    /** Token usage reported for the rerank call. */
    @Schema(description = "Usage")
    private Usage usage;

    /** The rerank results. */
    @Schema(description = "rerank结果")
    private List<RerankItem> results;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/Target.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
/**
 * Target platform settings (API address, API key and model name). All three fields
 * are marked hidden in the generated schema.
 *
 * @author yangyw
 */
@Schema(description = "目标平台")
@Data
public class Target implements Serializable {

    /** API address. */
    @Schema(description = "API地址", hidden = true)
    private String apiBaseUrl;

    /** API key. */
    @Schema(description = "API Key", hidden = true)
    private String apiKey;

    /** Model name. */
    @Schema(description = "模型名称", hidden = true)
    private String model;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/Usage.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
/**
 * Token usage counters.
 *
 * @author yangyw
 */
@Schema(description = "Usage")
@Data
public class Usage implements Serializable {

    /** Total tokens. */
    @Schema(description = "总Tokens")
    private Integer totalTokens;

    /** Prompt tokens. */
    @Schema(description = "提示Tokens")
    private Integer promptTokens;

    /** Reply (completion) tokens. */
    @Schema(description = "回复Tokens")
    private Integer completionTokens;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/VectorSearchRequest.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.*
;
import
java.util.Map
;
/**
 * Vector search request: query text, result limits and metadata, plus the
 * target-platform settings inherited from {@link Target}.
 *
 * <p>NOTE(review): Lombok's {@code @Builder} presumably covers only the fields declared
 * here, not the inherited {@link Target} fields — confirm whether {@code @SuperBuilder}
 * is wanted.
 *
 * @author yangyw
 */
@Schema(description = "向量搜索请求")
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class VectorSearchRequest extends Target {

    /** The query text. */
    @Schema(description = "查询文本")
    private String query;

    /** Number of results to return. */
    @Schema(description = "返回结果数量")
    private Integer topK;

    /** Minimum similarity. */
    @Schema(description = "最小相似度")
    private Double minScore;

    /** Metadata handed to the vector store search together with the query. */
    private Map<String, ?> metadata;
}
src/main/java/cn/breeze/elleai/application/dto/langchain/VectorSegment.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
dto
.
langchain
;
import
io.swagger.v3.oas.annotations.media.Schema
;
import
lombok.Data
;
import
java.io.Serializable
;
import
java.util.Map
;
/**
 * A text segment as stored in / returned from the vector store.
 *
 * @author yangyw
 */
@Schema(description = "向量存储的文本片段")
@Data
public class VectorSegment implements Serializable {

    /** Text content. */
    @Schema(description = "文本内容")
    private String content;

    /** Id of the stored segment. */
    @Schema(description = "存储片段id")
    private String id;

    /** Metadata attached to the segment. */
    @Schema(description = "元数据")
    private Map<String, ?> metadata;

    /** Similarity score. */
    @Schema(description = "相似度")
    private Double score;

    /** Similarity score after rerank. */
    @Schema(description = "重排后的相似度")
    private Double relevanceScore;
}
src/main/java/cn/breeze/elleai/application/service/AIService.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
application
.
service
;
import
cn.breeze.elleai.application.dto.langchain.*
;
import
cn.breeze.elleai.facade.EmbeddingService
;
import
cn.breeze.elleai.facade.RerankFacade
;
import
cn.breeze.elleai.facade.VectorStoreService
;
import
cn.hutool.core.collection.CollUtil
;
import
cn.hutool.core.util.ObjectUtil
;
import
dev.langchain4j.data.document.Metadata
;
import
dev.langchain4j.data.embedding.Embedding
;
import
dev.langchain4j.data.segment.TextSegment
;
import
dev.langchain4j.store.embedding.EmbeddingMatch
;
import
dev.langchain4j.store.embedding.EmbeddingSearchResult
;
import
lombok.RequiredArgsConstructor
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.stereotype.Component
;
import
java.util.ArrayList
;
import
java.util.List
;
import
java.util.Map
;
/**
* AI服务
*/
@Component
@Slf4j
@RequiredArgsConstructor
public
class
AIService
{
private
final
VectorStoreService
vectorStoreService
;
private
final
EmbeddingService
embeddingService
;
private
final
RerankFacade
rerankFacade
;
/**
* 将知识项存储到向量数据库,返回向量id列表
* @param segments
* @return
*/
public
List
<
String
>
addVectorSegments
(
List
<
VectorSegment
>
segments
)
{
List
<
String
>
texts
=
CollUtil
.
map
(
segments
,
VectorSegment:
:
getContent
,
true
);
List
<
Embedding
>
embeddings
=
embeddingService
.
embed
(
texts
);
List
<
TextSegment
>
textSegments
=
CollUtil
.
newArrayList
();
for
(
VectorSegment
segment
:
segments
)
{
TextSegment
textSegment
=
TextSegment
.
from
(
segment
.
getContent
(),
Metadata
.
from
(
ObjectUtil
.
defaultIfNull
(
segment
.
getMetadata
(),
Map
.
of
())));
textSegments
.
add
(
textSegment
);
}
return
vectorStoreService
.
addSegments
(
embeddings
,
textSegments
);
}
/**
* 根据向量ID批量删除向量
* @param segmentIds
*/
public
void
removeSegments
(
List
<
String
>
segmentIds
)
{
log
.
warn
(
"批量删除向量:{}"
,
segmentIds
);
vectorStoreService
.
removeSegments
(
segmentIds
);
}
/**
* 根据元数据匹配删除向量(仅支持eq 和 in)
* @param metadata
*/
public
void
removeAll
(
Map
<
String
,
?>
metadata
)
{
vectorStoreService
.
removeAll
(
metadata
);
}
/**
* 向量数据库搜索
* @param request
* @return
*/
public
List
<
VectorSegment
>
search
(
VectorSearchRequest
request
)
{
EmbeddingRequest
embeddingRequest
=
new
EmbeddingRequest
();
embeddingRequest
.
setTexts
(
List
.
of
(
request
.
getQuery
()));
embeddingRequest
.
setModel
(
request
.
getModel
());
embeddingRequest
.
setApiKey
(
request
.
getApiKey
());
embeddingRequest
.
setApiBaseUrl
(
request
.
getApiBaseUrl
());
List
<
Embedding
>
embeddings
=
embeddingService
.
embed
(
embeddingRequest
);
Embedding
embedding
=
CollUtil
.
getFirst
(
embeddings
);
if
(
ObjectUtil
.
isNotNull
(
embedding
))
{
EmbeddingSearchResult
<
TextSegment
>
result
=
vectorStoreService
.
search
(
embedding
,
ObjectUtil
.
defaultIfNull
(
request
.
getTopK
(),
10
),
ObjectUtil
.
defaultIfNull
(
request
.
getMinScore
(),
0.0d
),
ObjectUtil
.
defaultIfNull
(
request
.
getMetadata
(),
Map
.
of
()));
if
(
ObjectUtil
.
isNotNull
(
result
)
&&
CollUtil
.
isNotEmpty
(
result
.
matches
()))
{
List
<
VectorSegment
>
segments
=
new
ArrayList
<>();
for
(
EmbeddingMatch
<
TextSegment
>
match
:
result
.
matches
())
{
VectorSegment
segment
=
new
VectorSegment
();
segment
.
setContent
(
match
.
embedded
().
text
());
segment
.
setId
(
match
.
embeddingId
());
segment
.
setMetadata
(
match
.
embedded
().
metadata
().
toMap
());
segment
.
setScore
(
match
.
score
());
segment
.
setRelevanceScore
(
match
.
score
());
segments
.
add
(
segment
);
}
return
segments
;
}
}
return
List
.
of
();
}
/**
* 向量检索并支持重排
* @param request
* @return
*/
public
List
<
VectorSegment
>
searchWithRerank
(
RagSearchRequest
request
)
{
List
<
VectorSegment
>
segments
=
this
.
search
(
request
);
if
(
ObjectUtil
.
equals
(
request
.
getEnableRerank
(),
true
))
{
//对向量查询结果继续rerank
if
(
ObjectUtil
.
isNotEmpty
(
segments
))
{
RerankRequest
rerankRequest
=
new
RerankRequest
();
rerankRequest
.
setQuery
(
request
.
getQuery
());
rerankRequest
.
setTopN
(
ObjectUtil
.
defaultIfNull
(
request
.
getTopKRerank
(),
5
));
if
(
rerankRequest
.
getTopN
()
>
segments
.
size
())
{
rerankRequest
.
setTopN
(
segments
.
size
());
}
rerankRequest
.
setScoreThreshold
(
ObjectUtil
.
defaultIfNull
(
request
.
getMinScoreRerank
(),
0.0
));
rerankRequest
.
setDocuments
(
CollUtil
.
map
(
segments
,
VectorSegment:
:
getContent
,
true
));
RerankResponse
rerankResponse
=
rerankFacade
.
rerank
(
rerankRequest
);
if
(
ObjectUtil
.
isNotNull
(
rerankResponse
)
&&
ObjectUtil
.
isNotEmpty
(
rerankResponse
.
getResults
()))
{
List
<
VectorSegment
>
results
=
new
ArrayList
<>();
for
(
RerankItem
result
:
rerankResponse
.
getResults
())
{
VectorSegment
segment
=
segments
.
get
(
result
.
getIndex
());
segment
.
setRelevanceScore
(
result
.
getRelevanceScore
());
results
.
add
(
segment
);
}
return
results
;
}
else
{
return
List
.
of
();
}
}
}
return
segments
;
}
}
src/main/java/cn/breeze/elleai/facade/EmbeddingService.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
facade
;
import
cn.breeze.elleai.application.dto.langchain.EmbeddingRequest
;
import
dev.langchain4j.data.embedding.Embedding
;
import
java.util.List
;
/**
 * Produces embedding vectors for text.
 *
 * @author yangyw
 */
public interface EmbeddingService {

    /**
     * Gets the embedding vector for a single text.
     *
     * @param text the text to embed
     * @return the embedding of the text
     */
    Embedding embed(String text);

    /**
     * Gets embedding vectors for a batch of texts.
     *
     * @param texts the texts to embed
     * @return the embeddings of the texts
     */
    List<Embedding> embed(List<String> texts);

    /**
     * Gets embedding vectors for the texts carried by a request object.
     *
     * @param request the embedding request
     * @return the embeddings of the request's texts
     */
    List<Embedding> embed(EmbeddingRequest request);
}
src/main/java/cn/breeze/elleai/facade/MilvusVectorStoreFacade.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
facade
;
import
dev.langchain4j.data.segment.TextSegment
;
import
dev.langchain4j.store.embedding.EmbeddingStore
;
import
dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore
;
import
jakarta.annotation.PostConstruct
;
import
lombok.RequiredArgsConstructor
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.beans.factory.annotation.Value
;
import
org.springframework.stereotype.Component
;
/**
 * {@link VectorStoreService} backed by a Milvus embedding store. The store is built
 * once after construction from the {@code milvus.*} properties.
 *
 * @author yangyw
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class MilvusVectorStoreFacade implements VectorStoreService {

    /** Milvus host; falls back to the default in the placeholder when not configured. */
    @Value("${milvus.host:172.18.5.186}")
    private String host;

    /** Milvus port. */
    @Value("${milvus.port:31530}")
    private Integer port;

    /** Milvus database name. */
    @Value("${milvus.database:falcon}")
    private String database;

    /** Dimension passed to the store builder. */
    @Value("${milvus.dimension:1024}")
    private Integer dimension;

    /** Milvus collection name. */
    @Value("${milvus.collection:embedding_store}")
    private String collection;

    /** The store created in {@link #init()}. */
    private EmbeddingStore<TextSegment> embeddingStore;

    /**
     * Builds the Milvus embedding store from the configured properties, logging before
     * and after the build.
     */
    @PostConstruct
    protected void init() {
        log.info("开始链接milvus向量数据库:{}:{}", host, port);
        MilvusEmbeddingStore store = MilvusEmbeddingStore.builder()
                .host(host)
                .port(port)
                .databaseName(database)
                .collectionName(collection)
                .dimension(dimension)
                .autoFlushOnInsert(true)
                .build();
        this.embeddingStore = store;
        log.info("milvus链接成功");
    }

    /**
     * @return the embedding store built in {@link #init()}
     */
    @Override
    public EmbeddingStore<TextSegment> getEmbeddingStore() {
        return this.embeddingStore;
    }
}
src/main/java/cn/breeze/elleai/facade/OpenAIEmbeddingFacade.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
facade
;
import
cn.breeze.elleai.application.dto.langchain.EmbeddingRequest
;
import
cn.hutool.core.collection.CollUtil
;
import
cn.hutool.core.util.StrUtil
;
import
dev.langchain4j.data.embedding.Embedding
;
import
dev.langchain4j.data.segment.TextSegment
;
import
dev.langchain4j.model.openai.OpenAiEmbeddingModel
;
import
lombok.RequiredArgsConstructor
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.beans.factory.annotation.Value
;
import
org.springframework.stereotype.Component
;
import
java.util.List
;
/**
* @author yangyw
*/
@Component
@Slf4j
@RequiredArgsConstructor
public
class
OpenAIEmbeddingFacade
implements
EmbeddingService
{
@Value
(
"${embedding.api-base-url:https://elle.e-tools.cn/v1}"
)
private
String
apiBaseUrl
;
@Value
(
"${embedding.api-key:smartbreeze}"
)
private
String
apiKey
;
@Value
(
"${embedding.model-name:bge-m3}"
)
private
String
modelName
;
/**
* 获取OpenAiEmbeddingModel
* @param request
* @return
*/
private
OpenAiEmbeddingModel
getModel
(
EmbeddingRequest
request
)
{
return
OpenAiEmbeddingModel
.
builder
()
.
modelName
(
StrUtil
.
blankToDefault
(
request
.
getModel
(),
modelName
))
.
baseUrl
(
StrUtil
.
blankToDefault
(
request
.
getApiBaseUrl
(),
apiBaseUrl
))
.
apiKey
(
StrUtil
.
blankToDefault
(
request
.
getApiKey
(),
apiKey
))
.
build
();
}
@Override
public
List
<
Embedding
>
embed
(
List
<
String
>
texts
)
{
EmbeddingRequest
embeddingRequest
=
new
EmbeddingRequest
();
embeddingRequest
.
setTexts
(
texts
);
return
embed
(
embeddingRequest
);
}
@Override
public
Embedding
embed
(
String
text
)
{
return
CollUtil
.
getFirst
(
embed
(
List
.
of
(
text
)));
}
@Override
public
List
<
Embedding
>
embed
(
EmbeddingRequest
request
)
{
return
getModel
(
request
).
embedAll
(
CollUtil
.
map
(
request
.
getTexts
(),
TextSegment:
:
from
,
true
)).
content
();
}
}
src/main/java/cn/breeze/elleai/facade/RerankFacade.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
facade
;
import
cn.breeze.elleai.application.dto.langchain.RerankItem
;
import
cn.breeze.elleai.application.dto.langchain.RerankRequest
;
import
cn.breeze.elleai.application.dto.langchain.RerankResponse
;
import
cn.breeze.elleai.application.dto.langchain.Usage
;
import
cn.hutool.core.collection.CollUtil
;
import
cn.hutool.core.util.ObjectUtil
;
import
cn.hutool.core.util.StrUtil
;
import
dev.langchain4j.data.segment.TextSegment
;
import
dev.langchain4j.model.cohere.NoBillCohereScoringModel
;
import
dev.langchain4j.model.output.Response
;
import
dev.langchain4j.model.scoring.ScoringModel
;
import
lombok.extern.slf4j.Slf4j
;
import
org.springframework.beans.factory.annotation.Value
;
import
org.springframework.stereotype.Component
;
import
java.util.List
;
/**
* 知识重排
*/
@Component
@Slf4j
public
class
RerankFacade
{
@Value
(
"${rerank.api-base-url:https://elle.e-tools.cn/v1}"
)
private
String
apiBaseUrl
;
@Value
(
"${rerank.api-key:smartbreeze}"
)
private
String
apiKey
;
@Value
(
"${rerank.model-name:bge-reranker-v2-m3}"
)
private
String
modelName
;
@Value
(
"${rerank.top-n:5}"
)
private
Integer
topN
;
private
ScoringModel
getScoringModel
(
RerankRequest
request
)
{
return
NoBillCohereScoringModel
.
builder
()
.
modelName
(
StrUtil
.
blankToDefault
(
request
.
getModel
(),
modelName
))
.
apiKey
(
StrUtil
.
blankToDefault
(
request
.
getApiKey
(),
apiKey
))
.
baseUrl
(
StrUtil
.
blankToDefault
(
request
.
getApiBaseUrl
(),
apiBaseUrl
))
.
build
();
}
/**
* 重排
* @param request
* @return
*/
public
RerankResponse
rerank
(
RerankRequest
request
)
{
ScoringModel
scoringModel
=
getScoringModel
(
request
);
long
start
=
System
.
currentTimeMillis
();
Response
<
List
<
Double
>>
response
=
scoringModel
.
scoreAll
(
CollUtil
.
map
(
request
.
getDocuments
(),
TextSegment:
:
from
,
true
),
request
.
getQuery
());
RerankResponse
rerankResponse
=
new
RerankResponse
();
if
(
ObjectUtil
.
isNotNull
(
response
)
&&
CollUtil
.
isNotEmpty
(
response
.
content
()))
{
//判断是否有Usage
if
(
ObjectUtil
.
isNotNull
(
response
.
tokenUsage
()))
{
Usage
usage
=
new
Usage
();
usage
.
setTotalTokens
(
ObjectUtil
.
defaultIfNull
(
response
.
tokenUsage
().
totalTokenCount
(),
0
));
usage
.
setPromptTokens
(
ObjectUtil
.
defaultIfNull
(
response
.
tokenUsage
().
inputTokenCount
(),
0
));
usage
.
setCompletionTokens
(
ObjectUtil
.
defaultIfNull
(
response
.
tokenUsage
().
outputTokenCount
(),
0
));
rerankResponse
.
setUsage
(
usage
);
}
List
<
RerankItem
>
results
=
CollUtil
.
newArrayList
();
for
(
int
i
=
0
;
i
<
response
.
content
().
size
();
i
++)
{
Double
score
=
response
.
content
().
get
(
i
);
if
(
ObjectUtil
.
isNotNull
(
request
.
getScoreThreshold
())
&&
score
<
request
.
getScoreThreshold
())
{
continue
;
}
RerankItem
item
=
new
RerankItem
();
item
.
setIndex
(
i
);
item
.
setContent
(
request
.
getDocuments
().
get
(
i
));
item
.
setRelevanceScore
(
score
);
results
.
add
(
item
);
}
results
=
CollUtil
.
sort
(
results
,
((
o1
,
o2
)
->
o2
.
getRelevanceScore
().
compareTo
(
o1
.
getRelevanceScore
())));
Integer
topN
=
ObjectUtil
.
defaultIfNull
(
request
.
getTopN
(),
this
.
topN
);
if
(
topN
<
results
.
size
())
{
// 截取topN
rerankResponse
.
setResults
(
results
.
subList
(
0
,
topN
));
}
else
{
rerankResponse
.
setResults
(
results
);
}
}
log
.
info
(
"查询:{}, 重排耗时:{} ms"
,
request
.
getQuery
(),
System
.
currentTimeMillis
()
-
start
);
return
rerankResponse
;
}
}
src/main/java/cn/breeze/elleai/facade/VectorStoreService.java
0 → 100644
View file @
32b01ebc
This diff is collapsed.
Click to expand it.
src/main/java/dev/langchain4j/model/cohere/NoBillCohereScoringModel.java
0 → 100644
View file @
32b01ebc
package
dev
.
langchain4j
.
model
.
cohere
;
import
cn.hutool.core.util.ObjectUtil
;
import
dev.langchain4j.data.segment.TextSegment
;
import
dev.langchain4j.model.output.Response
;
import
dev.langchain4j.model.output.TokenUsage
;
import
dev.langchain4j.model.scoring.ScoringModel
;
import
lombok.Builder
;
import
java.net.Proxy
;
import
java.time.Duration
;
import
java.util.List
;
import
static
dev
.
langchain4j
.
internal
.
RetryUtils
.
withRetry
;
import
static
dev
.
langchain4j
.
internal
.
Utils
.
getOrDefault
;
import
static
dev
.
langchain4j
.
internal
.
ValidationUtils
.
ensureNotBlank
;
import
static
java
.
time
.
Duration
.
ofSeconds
;
import
static
java
.
util
.
Comparator
.
comparingInt
;
import
static
java
.
util
.
stream
.
Collectors
.
toList
;
/**
* @author yangyw
*/
public
class
NoBillCohereScoringModel
implements
ScoringModel
{
private
static
final
String
DEFAULT_BASE_URL
=
"https://api.cohere.ai/v1/"
;
private
final
CohereClient
client
;
private
final
String
modelName
;
private
final
Integer
maxRetries
;
@Builder
public
NoBillCohereScoringModel
(
String
baseUrl
,
String
apiKey
,
String
modelName
,
Duration
timeout
,
Integer
maxRetries
,
Proxy
proxy
,
Boolean
logRequests
,
Boolean
logResponses
)
{
this
.
client
=
CohereClient
.
builder
()
.
baseUrl
(
getOrDefault
(
baseUrl
,
DEFAULT_BASE_URL
))
.
apiKey
(
ensureNotBlank
(
apiKey
,
"apiKey"
))
.
timeout
(
getOrDefault
(
timeout
,
ofSeconds
(
60
)))
.
proxy
(
proxy
)
.
logRequests
(
getOrDefault
(
logRequests
,
false
))
.
logResponses
(
getOrDefault
(
logResponses
,
false
))
.
build
();
this
.
modelName
=
modelName
;
this
.
maxRetries
=
getOrDefault
(
maxRetries
,
3
);
}
/**
* @deprecated use {@code builder()} instead and explicitly set the model name and, if required, other parameters.
*/
@Deprecated
public
static
NoBillCohereScoringModel
withApiKey
(
String
apiKey
)
{
return
builder
().
apiKey
(
apiKey
).
build
();
}
@Override
public
Response
<
List
<
Double
>>
scoreAll
(
List
<
TextSegment
>
segments
,
String
query
)
{
RerankRequest
request
=
RerankRequest
.
builder
()
.
model
(
modelName
)
.
query
(
query
)
.
documents
(
segments
.
stream
()
.
map
(
TextSegment:
:
text
)
.
collect
(
toList
()))
.
build
();
RerankResponse
response
=
withRetry
(()
->
client
.
rerank
(
request
),
maxRetries
);
List
<
Double
>
scores
=
response
.
getResults
().
stream
()
.
sorted
(
comparingInt
(
Result:
:
getIndex
))
.
map
(
Result:
:
getRelevanceScore
)
.
collect
(
toList
());
TokenUsage
usage
=
new
TokenUsage
(
0
,
0
,
0
);
if
(
ObjectUtil
.
isNotNull
(
response
.
getMeta
())
&&
ObjectUtil
.
isNotNull
(
response
.
getMeta
().
getBilledUnits
()))
{
usage
=
new
TokenUsage
(
ObjectUtil
.
defaultIfNull
(
response
.
getMeta
().
getBilledUnits
().
getSearchUnits
(),
0
));
}
return
Response
.
from
(
scores
,
usage
);
}
}
src/test/java/cn/breeze/elleai/test/AIServiceTestUnit.java
0 → 100644
View file @
32b01ebc
package
cn
.
breeze
.
elleai
.
test
;
import
cn.breeze.elleai.application.dto.langchain.VectorSearchRequest
;
import
cn.breeze.elleai.application.dto.langchain.VectorSegment
;
import
cn.breeze.elleai.application.service.AIService
;
import
cn.breeze.elleai.infra.entity.KbEntity
;
import
org.junit.jupiter.api.Test
;
import
org.springframework.beans.factory.annotation.Autowired
;
import
org.springframework.boot.test.context.SpringBootTest
;
import
org.springframework.context.annotation.Profile
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.Map
;
// Manual integration checks for AIService; the tests call the service but assert nothing.
// NOTE(review): @Profile is normally a bean-registration condition; if the intent is to
// run these tests with the "dev" profile active, @ActiveProfiles("dev") is presumably
// what is wanted — confirm.
@SpringBootTest
@Profile("dev")
public class AIServiceTestUnit {

    @Autowired
    private AIService aiService;

    /**
     * Syncs a knowledge-base entry to the vector database. The sample is a single
     * entry and needs to be adapted to the actual business scenario.
     */
    @Test
    public void testSyncKBtoVectorDb() {
        KbEntity kb = new KbEntity();
        kb.setId(1L);
        kb.setQuestion("问题");
        kb.setAnswer("答案");
        kb.setTagId(2L);

        Map<String, Object> meta = new HashMap<>();
        meta.put("kb_id", kb.getId());
        meta.put("tag_id", kb.getTagId());
        meta.put("question", kb.getQuestion());
        meta.put("answer", kb.getAnswer());

        VectorSegment segment = new VectorSegment();
        segment.setContent(kb.getQuestion() + "\n" + kb.getAnswer());
        segment.setMetadata(meta);

        aiService.addVectorSegments(List.of(segment));
    }

    /**
     * Deletes vector database entries by tagId.
     */
    @Test
    public void testRemoveByTagId() {
        Map<String, Object> filter = new HashMap<>();
        filter.put("tag_id", 2L);
        aiService.removeAll(filter);
    }

    /**
     * Deletes vector database entries by segmentId.
     */
    @Test
    public void testRemoveBySegmentId() {
        List<String> ids = List.of("1");
        aiService.removeSegments(ids);
    }

    /**
     * Searches with a metadata filter. Metadata entries are used as filter conditions
     * and support "equals" or "in list".
     */
    @Test
    public void testSearch() {
        Map<String, Object> filter = new HashMap<>();
        // Filter on a single category...
        filter.put("tag_id", 2L);
        // ...or on several categories; this second put replaces the entry above.
        filter.put("tag_id", List.of(2L, 3L, 4L));

        VectorSearchRequest request = VectorSearchRequest.builder()
                .query("问题")
                .topK(10)
                .minScore(0.5)
                .metadata(filter)
                .build();

        aiService.search(request);
    }
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment