Commit a311f88c authored by StyleZhang's avatar StyleZhang

compute node position

parent e92bc252
...@@ -8,7 +8,7 @@ const initialNodes = [ ...@@ -8,7 +8,7 @@ const initialNodes = [
{ {
id: '1', id: '1',
type: 'custom', type: 'custom',
position: { x: 130, y: 130 }, // position: { x: 130, y: 130 },
data: { type: 'start' }, data: { type: 'start' },
}, },
{ {
...@@ -21,20 +21,20 @@ const initialNodes = [ ...@@ -21,20 +21,20 @@ const initialNodes = [
id: '3', id: '3',
type: 'custom', type: 'custom',
position: { x: 738, y: 130 }, position: { x: 738, y: 130 },
data: { type: 'llm' }, data: { type: 'llm', sortIndexInBranches: 0 },
}, },
{ {
id: '4', id: '4',
type: 'custom', type: 'custom',
position: { x: 738, y: 330 }, position: { x: 738, y: 330 },
data: { type: 'llm' }, data: { type: 'llm', sortIndexInBranches: 1 },
},
{
id: '5',
type: 'custom',
position: { x: 1100, y: 130 },
data: { type: 'llm' },
}, },
// {
// id: '5',
// type: 'custom',
// position: { x: 1100, y: 130 },
// data: { type: 'llm' },
// },
] ]
const initialEdges = [ const initialEdges = [
...@@ -44,6 +44,7 @@ const initialEdges = [ ...@@ -44,6 +44,7 @@ const initialEdges = [
source: '1', source: '1',
sourceHandle: 'source', sourceHandle: 'source',
target: '2', target: '2',
targetHandle: 'target',
}, },
{ {
id: '1', id: '1',
...@@ -51,6 +52,7 @@ const initialEdges = [ ...@@ -51,6 +52,7 @@ const initialEdges = [
source: '2', source: '2',
sourceHandle: 'condition1', sourceHandle: 'condition1',
target: '3', target: '3',
targetHandle: 'target',
}, },
{ {
id: '2', id: '2',
...@@ -58,6 +60,7 @@ const initialEdges = [ ...@@ -58,6 +60,7 @@ const initialEdges = [
source: '2', source: '2',
sourceHandle: 'condition2', sourceHandle: 'condition2',
target: '4', target: '4',
targetHandle: 'target',
}, },
] ]
......
...@@ -59,6 +59,10 @@ const NodeSelector: FC<NodeSelectorProps> = ({ ...@@ -59,6 +59,10 @@ const NodeSelector: FC<NodeSelectorProps> = ({
e.stopPropagation() e.stopPropagation()
handleOpenChange(!open) handleOpenChange(!open)
}, [open, handleOpenChange]) }, [open, handleOpenChange])
const handleSelect = useCallback((type: BlockEnum) => {
handleOpenChange(false)
onSelect(type)
}, [handleOpenChange, onSelect])
return ( return (
<PortalToFollowElem <PortalToFollowElem
...@@ -99,7 +103,7 @@ const NodeSelector: FC<NodeSelectorProps> = ({ ...@@ -99,7 +103,7 @@ const NodeSelector: FC<NodeSelectorProps> = ({
/> />
</div> </div>
</div> </div>
<Tabs onSelect={onSelect} /> <Tabs onSelect={handleSelect} />
</div> </div>
</PortalToFollowElemContent> </PortalToFollowElemContent>
</PortalToFollowElem> </PortalToFollowElem>
......
...@@ -34,6 +34,16 @@ export const NodeInitialData = { ...@@ -34,6 +34,16 @@ export const NodeInitialData = {
retrieval_mode: 'single', retrieval_mode: 'single',
}, },
[BlockEnum.IfElse]: { [BlockEnum.IfElse]: {
branches: [
{
id: 'if-true',
name: 'IS TRUE',
},
{
id: 'if-false',
name: 'IS FALSE',
},
],
type: BlockEnum.IfElse, type: BlockEnum.IfElse,
title: '', title: '',
desc: '', desc: '',
......
...@@ -14,6 +14,7 @@ import type { ...@@ -14,6 +14,7 @@ import type {
} from './types' } from './types'
import { NodeInitialData } from './constants' import { NodeInitialData } from './constants'
import { useStore } from './store' import { useStore } from './store'
import { initialNodesPosition } from './utils'
export const useWorkflow = () => { export const useWorkflow = () => {
const store = useStoreApi() const store = useStoreApi()
...@@ -43,6 +44,7 @@ export const useWorkflow = () => { ...@@ -43,6 +44,7 @@ export const useWorkflow = () => {
}) })
setEdges(newEdges) setEdges(newEdges)
}, [store]) }, [store])
const handleLeaveNode = useCallback<NodeMouseHandler>((_, node) => { const handleLeaveNode = useCallback<NodeMouseHandler>((_, node) => {
const { const {
getNodes, getNodes,
...@@ -67,6 +69,7 @@ export const useWorkflow = () => { ...@@ -67,6 +69,7 @@ export const useWorkflow = () => {
}) })
setEdges(newEdges) setEdges(newEdges)
}, [store]) }, [store])
const handleEnterEdge = useCallback<EdgeMouseHandler>((_, edge) => { const handleEnterEdge = useCallback<EdgeMouseHandler>((_, edge) => {
const { const {
edges, edges,
...@@ -79,6 +82,7 @@ export const useWorkflow = () => { ...@@ -79,6 +82,7 @@ export const useWorkflow = () => {
}) })
setEdges(newEdges) setEdges(newEdges)
}, [store]) }, [store])
const handleLeaveEdge = useCallback<EdgeMouseHandler>((_, edge) => { const handleLeaveEdge = useCallback<EdgeMouseHandler>((_, edge) => {
const { const {
edges, edges,
...@@ -91,6 +95,7 @@ export const useWorkflow = () => { ...@@ -91,6 +95,7 @@ export const useWorkflow = () => {
}) })
setEdges(newEdges) setEdges(newEdges)
}, [store]) }, [store])
const handleSelectNode = useCallback((selectNode: SelectedNode, cancelSelection?: boolean) => { const handleSelectNode = useCallback((selectNode: SelectedNode, cancelSelection?: boolean) => {
const { const {
getNodes, getNodes,
...@@ -99,20 +104,18 @@ export const useWorkflow = () => { ...@@ -99,20 +104,18 @@ export const useWorkflow = () => {
if (cancelSelection) { if (cancelSelection) {
setSelectedNode(null) setSelectedNode(null)
const newNodes = produce(getNodes(), (draft) => { const newNodes = produce(getNodes(), (draft) => {
const currentNode = draft.find(n => n.id === selectNode.id) draft.forEach((item) => {
item.data = { ...item.data, selected: false }
if (currentNode) })
currentNode.data = { ...currentNode.data, selected: false }
}) })
setNodes(newNodes) setNodes(newNodes)
} }
else { else {
setSelectedNode(selectNode) setSelectedNode(selectNode)
const newNodes = produce(getNodes(), (draft) => { const newNodes = produce(getNodes(), (draft) => {
const currentNode = draft.find(n => n.id === selectNode.id) draft.forEach((item) => {
item.data = { ...item.data, selected: item.id === selectNode.id }
if (currentNode) })
currentNode.data = { ...currentNode.data, selected: true }
}) })
setNodes(newNodes) setNodes(newNodes)
} }
...@@ -130,7 +133,8 @@ export const useWorkflow = () => { ...@@ -130,7 +133,8 @@ export const useWorkflow = () => {
setNodes(newNodes) setNodes(newNodes)
setSelectedNode({ id, data }) setSelectedNode({ id, data })
}, [store, setSelectedNode]) }, [store, setSelectedNode])
const handleAddNextNode = useCallback((currentNodeId: string, nodeType: BlockEnum) => {
const handleAddNextNode = useCallback((currentNodeId: string, nodeType: BlockEnum, branchId?: string) => {
const { const {
getNodes, getNodes,
setNodes, setNodes,
...@@ -141,24 +145,47 @@ export const useWorkflow = () => { ...@@ -141,24 +145,47 @@ export const useWorkflow = () => {
const currentNode = nodes.find(node => node.id === currentNodeId)! const currentNode = nodes.find(node => node.id === currentNodeId)!
const nextNode = { const nextNode = {
id: `${Date.now()}`, id: `${Date.now()}`,
data: NodeInitialData[nodeType], type: 'custom',
data: { ...NodeInitialData[nodeType], selected: true },
position: { position: {
x: currentNode.position.x + 304, x: currentNode.position.x + 304,
y: currentNode.position.y, y: currentNode.position.y,
}, },
} }
const newNodes = produce(nodes, (draft) => { const newNodes = produce(nodes, (draft) => {
draft.forEach((item) => {
item.data = { ...item.data, selected: false }
})
draft.push(nextNode) draft.push(nextNode)
}) })
setNodes(newNodes) setNodes(newNodes)
const newEdges = produce(edges, (draft) => { const newEdges = produce(edges, (draft) => {
draft.push({ draft.push({
id: `${currentNode.id}-${nextNode.id}`, id: `${currentNode.id}-${nextNode.id}`,
type: 'custom',
source: currentNode.id, source: currentNode.id,
sourceHandle: branchId || 'source',
target: nextNode.id, target: nextNode.id,
targetHandle: 'target',
}) })
}) })
setEdges(newEdges) setEdges(newEdges)
setSelectedNode(nextNode)
}, [store, setSelectedNode])
const handleInitialLayoutNodes = useCallback(() => {
const {
getNodes,
setNodes,
edges,
setEdges,
} = store.getState()
setNodes(initialNodesPosition(getNodes(), edges))
setEdges(produce(edges, (draft) => {
draft.forEach((edge) => {
edge.hidden = false
})
}))
}, [store]) }, [store])
return { return {
...@@ -169,5 +196,6 @@ export const useWorkflow = () => { ...@@ -169,5 +196,6 @@ export const useWorkflow = () => {
handleSelectNode, handleSelectNode,
handleUpdateNodeData, handleUpdateNodeData,
handleAddNextNode, handleAddNextNode,
handleInitialLayoutNodes,
} }
} }
import type { FC } from 'react' import type { FC } from 'react'
import { memo, useEffect } from 'react' import {
memo,
useEffect,
useMemo,
} from 'react'
import produce from 'immer'
import type { Edge } from 'reactflow' import type { Edge } from 'reactflow'
import ReactFlow, { import ReactFlow, {
Background, Background,
ReactFlowProvider, ReactFlowProvider,
useEdgesState, useEdgesState,
useNodesInitialized,
useNodesState, useNodesState,
} from 'reactflow' } from 'reactflow'
import 'reactflow/dist/style.css' import 'reactflow/dist/style.css'
...@@ -15,7 +21,7 @@ import ZoomInOut from './zoom-in-out' ...@@ -15,7 +21,7 @@ import ZoomInOut from './zoom-in-out'
import CustomEdge from './custom-edge' import CustomEdge from './custom-edge'
import CustomConnectionLine from './custom-connection-line' import CustomConnectionLine from './custom-connection-line'
import Panel from './panel' import Panel from './panel'
import type { Node } from './types' import { BlockEnum, type Node } from './types'
const nodeTypes = { const nodeTypes = {
custom: CustomNode, custom: CustomNode,
...@@ -34,8 +40,41 @@ const Workflow: FC<WorkflowProps> = memo(({ ...@@ -34,8 +40,41 @@ const Workflow: FC<WorkflowProps> = memo(({
edges: initialEdges, edges: initialEdges,
selectedNodeId: initialSelectedNodeId, selectedNodeId: initialSelectedNodeId,
}) => { }) => {
const [nodes] = useNodesState(initialNodes) const initialData: {
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges) nodes: Node[]
edges: Edge[]
needUpdatePosition: boolean
} = useMemo(() => {
const start = initialNodes.find(node => node.data.type === BlockEnum.Start)
if (start?.position) {
return {
nodes: initialNodes,
edges: initialEdges,
needUpdatePosition: false,
}
}
return {
nodes: produce(initialNodes, (draft) => {
draft.forEach((node) => {
node.position = { x: 0, y: 0 }
node.data = { ...node.data, hidden: true }
})
}),
edges: produce(initialEdges, (draft) => {
draft.forEach((edge) => {
edge.hidden = true
})
}),
needUpdatePosition: true,
}
}, [initialNodes, initialEdges])
const nodesInitialized = useNodesInitialized({
includeHiddenNodes: true,
})
const [nodes] = useNodesState(initialData.nodes)
const [edges, setEdges, onEdgesChange] = useEdgesState(initialData.edges)
const { const {
handleEnterNode, handleEnterNode,
...@@ -43,8 +82,14 @@ const Workflow: FC<WorkflowProps> = memo(({ ...@@ -43,8 +82,14 @@ const Workflow: FC<WorkflowProps> = memo(({
handleEnterEdge, handleEnterEdge,
handleLeaveEdge, handleLeaveEdge,
handleSelectNode, handleSelectNode,
handleInitialLayoutNodes,
} = useWorkflow() } = useWorkflow()
useEffect(() => {
if (nodesInitialized && initialData.needUpdatePosition)
handleInitialLayoutNodes()
}, [nodesInitialized])
useEffect(() => { useEffect(() => {
if (initialSelectedNodeId) { if (initialSelectedNodeId) {
const initialSelectedNode = nodes.find(n => n.id === initialSelectedNodeId) const initialSelectedNode = nodes.find(n => n.id === initialSelectedNodeId)
......
...@@ -7,12 +7,12 @@ import { ...@@ -7,12 +7,12 @@ import {
Handle, Handle,
Position, Position,
getConnectedEdges, getConnectedEdges,
getIncomers,
useStoreApi, useStoreApi,
} from 'reactflow' } from 'reactflow'
import { BlockEnum } from '../../../types' import { BlockEnum } from '../../../types'
import type { Node } from '../../../types' import type { Node } from '../../../types'
import BlockSelector from '../../../block-selector' import BlockSelector from '../../../block-selector'
import { useWorkflow } from '../../../hooks'
type NodeHandleProps = { type NodeHandleProps = {
handleId?: string handleId?: string
...@@ -29,12 +29,13 @@ export const NodeTargetHandle = ({ ...@@ -29,12 +29,13 @@ export const NodeTargetHandle = ({
}: NodeHandleProps) => { }: NodeHandleProps) => {
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const store = useStoreApi() const store = useStoreApi()
const incomers = getIncomers({ id } as Node, store.getState().getNodes(), store.getState().edges) const connectedEdges = getConnectedEdges([{ id } as Node], store.getState().edges)
const connected = connectedEdges.find(edge => edge.targetHandle === handleId && edge.target === id)
const handleOpenChange = useCallback((v: boolean) => { const handleOpenChange = useCallback((v: boolean) => {
setOpen(v) setOpen(v)
}, []) }, [])
const handleHandleClick = () => { const handleHandleClick = () => {
if (incomers.length === 0 && data.type !== BlockEnum.Start) if (!connected)
handleOpenChange(!open) handleOpenChange(!open)
} }
...@@ -47,7 +48,7 @@ export const NodeTargetHandle = ({ ...@@ -47,7 +48,7 @@ export const NodeTargetHandle = ({
className={` className={`
!w-4 !h-4 !bg-transparent !rounded-none !outline-none !border-none !translate-y-0 z-[1] !w-4 !h-4 !bg-transparent !rounded-none !outline-none !border-none !translate-y-0 z-[1]
after:absolute after:w-0.5 after:h-2 after:left-1.5 after:top-1 after:bg-primary-500 after:absolute after:w-0.5 after:h-2 after:left-1.5 after:top-1 after:bg-primary-500
${!incomers.length && 'after:opacity-0'} ${!connected && 'after:opacity-0'}
${data.type === BlockEnum.Start && 'opacity-0'} ${data.type === BlockEnum.Start && 'opacity-0'}
${handleClassName} ${handleClassName}
`} `}
...@@ -55,7 +56,7 @@ export const NodeTargetHandle = ({ ...@@ -55,7 +56,7 @@ export const NodeTargetHandle = ({
onClick={handleHandleClick} onClick={handleHandleClick}
> >
{ {
incomers.length === 0 && data.type !== BlockEnum.Start && ( !connected && data.type !== BlockEnum.Start && (
<BlockSelector <BlockSelector
open={open} open={open}
onOpenChange={handleOpenChange} onOpenChange={handleOpenChange}
...@@ -84,9 +85,10 @@ export const NodeSourceHandle = ({ ...@@ -84,9 +85,10 @@ export const NodeSourceHandle = ({
nodeSelectorClassName, nodeSelectorClassName,
}: NodeHandleProps) => { }: NodeHandleProps) => {
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const { handleAddNextNode } = useWorkflow()
const store = useStoreApi() const store = useStoreApi()
const connectedEdges = getConnectedEdges([{ id } as Node], store.getState().edges) const connectedEdges = getConnectedEdges([{ id } as Node], store.getState().edges)
const connected = connectedEdges.find(edge => edge.sourceHandle === handleId) const connected = connectedEdges.find(edge => edge.sourceHandle === handleId && edge.source === id)
const handleOpenChange = useCallback((v: boolean) => { const handleOpenChange = useCallback((v: boolean) => {
setOpen(v) setOpen(v)
}, []) }, [])
...@@ -94,6 +96,9 @@ export const NodeSourceHandle = ({ ...@@ -94,6 +96,9 @@ export const NodeSourceHandle = ({
if (!connected) if (!connected)
handleOpenChange(!open) handleOpenChange(!open)
} }
const handleSelect = useCallback((type: BlockEnum) => {
handleAddNextNode(id, type)
}, [handleAddNextNode, id])
return ( return (
<> <>
...@@ -114,7 +119,7 @@ export const NodeSourceHandle = ({ ...@@ -114,7 +119,7 @@ export const NodeSourceHandle = ({
<BlockSelector <BlockSelector
open={open} open={open}
onOpenChange={handleOpenChange} onOpenChange={handleOpenChange}
onSelect={() => {}} onSelect={handleSelect}
asChild asChild
triggerClassName={open => ` triggerClassName={open => `
hidden absolute top-0 left-0 pointer-events-none hidden absolute top-0 left-0 pointer-events-none
......
...@@ -18,7 +18,6 @@ type BaseNodeProps = { ...@@ -18,7 +18,6 @@ type BaseNodeProps = {
const BaseNode: FC<BaseNodeProps> = ({ const BaseNode: FC<BaseNodeProps> = ({
id: nodeId, id: nodeId,
data, data,
selected,
children, children,
}) => { }) => {
const { handleSelectNode } = useWorkflow() const { handleSelectNode } = useWorkflow()
...@@ -28,7 +27,8 @@ const BaseNode: FC<BaseNodeProps> = ({ ...@@ -28,7 +27,8 @@ const BaseNode: FC<BaseNodeProps> = ({
className={` className={`
group relative w-[240px] bg-[#fcfdff] rounded-2xl shadow-xs group relative w-[240px] bg-[#fcfdff] rounded-2xl shadow-xs
hover:shadow-lg hover:shadow-lg
${(data.selected && selected) ? 'border-[2px] border-primary-600' : 'border border-white'} ${data.hidden && 'opacity-0'}
${data.selected ? 'border-[2px] border-primary-600' : 'border border-white'}
`} `}
onClick={() => handleSelectNode({ id: nodeId, data })} onClick={() => handleSelectNode({ id: nodeId, data })}
> >
......
import type { Node as ReactFlowNode } from 'reactflow' import type {
Edge as ReactFlowEdge,
Node as ReactFlowNode,
} from 'reactflow'
export enum BlockEnum { export enum BlockEnum {
Start = 'start', Start = 'start',
...@@ -15,15 +18,29 @@ export enum BlockEnum { ...@@ -15,15 +18,29 @@ export enum BlockEnum {
Tool = 'tool', Tool = 'tool',
} }
export type Branch = {
id: string
name: string
}
export type CommonNodeType = { export type CommonNodeType = {
hidden?: boolean
position?: {
x: number
y: number
}
sortIndexInBranches?: number
selected?: boolean
hovering?: boolean
branches?: Branch[]
title: string title: string
desc: string desc: string
type: BlockEnum type: BlockEnum
selected?: boolean
} }
export type Node = ReactFlowNode<CommonNodeType> export type Node = ReactFlowNode<CommonNodeType>
export type SelectedNode = Pick<Node, 'id' | 'data'> export type SelectedNode = Pick<Node, 'id' | 'data'>
export type Edge = ReactFlowEdge
export type ValueSelector = string[] // [nodeId, key | obj key path] export type ValueSelector = string[] // [nodeId, key | obj key path]
......
import {
getOutgoers,
} from 'reactflow'
import { cloneDeep } from 'lodash-es'
import type {
Edge,
Node,
} from './types'
import { BlockEnum } from './types'
export const initialNodesPosition = (oldNodes: Node[], edges: Edge[]) => {
const nodes = cloneDeep(oldNodes)
const start = nodes.find(node => node.data.type === BlockEnum.Start)!
start.data.hidden = false
start.position.x = 0
start.position.y = 0
start.data.position = {
x: 0,
y: 0,
}
const queue = [start]
let depth = 0
let breadth = 0
let baseHeight = 0
while (queue.length) {
const node = queue.shift()!
if (node.data.position?.x !== depth) {
breadth = 0
baseHeight = 0
}
depth = node.data.position?.x || 0
const outgoers = getOutgoers(node, nodes, edges).sort((a, b) => (a.data.sortIndexInBranches || 0) - (b.data.sortIndexInBranches || 0))
if (outgoers.length) {
queue.push(...outgoers.map((outgoer) => {
outgoer.data.hidden = false
outgoer.data.position = {
x: depth + 1,
y: breadth,
}
outgoer.position.x = (depth + 1) * (220 + 64)
outgoer.position.y = baseHeight
baseHeight += ((outgoer.height || 0) + 39)
breadth += 1
return outgoer
}))
}
}
return nodes
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment