Overview
StateGraph provides advanced features for complex workflows:
- 🔀 Send API - Dynamic parallel execution (orchestrator-worker pattern)
- ⚡ Parallel Nodes - Automatic concurrent execution
- 🎯 Task Decorators - Durable functions with retry and cache
- 📊 Schemas - Type-safe input/output contracts
Send API
The Send API enables dynamic parallelization where a node can spawn multiple worker instances that execute concurrently.
Basic Send Pattern
from typing import Annotated, List
from typing_extensions import TypedDict
import operator
from upsonic.graphv2 import StateGraph, START, END, Send
class State(TypedDict):
items: List[str]
results: Annotated[List[str], operator.add]
def orchestrator(state: State) -> dict:
"""Prepare items for processing."""
return {"items": state["items"]}
def fan_out(state: State) -> List[Send]:
"""Create a worker for each item."""
return [
Send("worker", {"item": item})
for item in state["items"]
]
def worker(state: dict) -> dict:
"""Process a single item."""
item = state["item"]
result = f"processed_{item}"
return {"results": [result]}
# Build graph
builder = StateGraph(State)
builder.add_node("orchestrator", orchestrator)
builder.add_node("worker", worker)
builder.add_edge(START, "orchestrator")
builder.add_conditional_edges("orchestrator", fan_out, ["worker"])
builder.add_edge("worker", END)
graph = builder.compile()
# Execute
result = graph.invoke({
"items": ["a", "b", "c"],
"results": []
})
print(result["results"]) # ['processed_a', 'processed_b', 'processed_c']
Workers execute in parallel and their results are automatically merged using reducers.
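Conceptually, the reducer declared on results (operator.add over lists) concatenates each worker's partial update into the shared state:

import operator

# Each worker returns {"results": [...]}; the reducer concatenates those updates
merged = operator.add(["processed_a"], ["processed_b"])
print(merged)  # ['processed_a', 'processed_b']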
Send with Different State
Each Send can pass different state to workers:
def distribute_work(state: State) -> List[Send]:
"""Create workers with different configurations."""
return [
Send("process", {"value": val, "multiplier": 2})
for val in state["values"]
]
def process(state: dict) -> dict:
"""Process with custom config."""
result = state["value"] * state["multiplier"]
return {"results": [result]}
Multiple Worker Types
Route to different workers based on logic:
class TaskState(TypedDict):
tasks: List[dict]
fast_results: Annotated[List[str], operator.add]
slow_results: Annotated[List[str], operator.add]
def route_tasks(state: TaskState) -> List[Send]:
"""Route to different workers based on task type."""
sends = []
for task in state["tasks"]:
if task["priority"] == "high":
sends.append(Send("fast_worker", {"task": task}))
else:
sends.append(Send("slow_worker", {"task": task}))
return sends
def fast_worker(state: dict) -> dict:
return {"fast_results": [f"fast_{state['task']['id']}"]}
def slow_worker(state: dict) -> dict:
return {"slow_results": [f"slow_{state['task']['id']}"]}
builder = StateGraph(TaskState)
builder.add_node("route", lambda s: {"tasks": s["tasks"]})
builder.add_node("fast_worker", fast_worker)
builder.add_node("slow_worker", slow_worker)
builder.add_edge(START, "route")
builder.add_conditional_edges("route", route_tasks, ["fast_worker", "slow_worker"])
builder.add_edge("fast_worker", END)
builder.add_edge("slow_worker", END)
graph = builder.compile()
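A sample invocation (the task dicts assume "id" and "priority" keys, as read by the workers above):

result = graph.invoke({
    "tasks": [
        {"id": 1, "priority": "high"},
        {"id": 2, "priority": "low"},
    ],
    "fast_results": [],
    "slow_results": [],
})
print(result["fast_results"])  # ['fast_1']
print(result["slow_results"])  # ['slow_2']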
Map-Reduce Pattern
Classic map-reduce with Send:
class MapReduceState(TypedDict):
data: List[int]
mapped: Annotated[List[int], operator.add]
reduced: int
def map_phase(state: MapReduceState) -> List[Send]:
"""Map: send each item to a worker."""
return [
Send("mapper", {"value": val})
for val in state["data"]
]
def mapper(state: dict) -> dict:
"""Map function: square the value."""
squared = state["value"] ** 2
return {"mapped": [squared]}
def reduce_phase(state: MapReduceState) -> dict:
"""Reduce: sum all mapped values."""
total = sum(state["mapped"])
return {"reduced": total}
builder = StateGraph(MapReduceState)
builder.add_node("start", lambda s: {"data": s["data"]})
builder.add_node("mapper", mapper)
builder.add_node("reduce", reduce_phase)
builder.add_edge(START, "start")
builder.add_conditional_edges("start", map_phase, ["mapper"])
builder.add_edge("mapper", "reduce")
builder.add_edge("reduce", END)
graph = builder.compile()
result = graph.invoke({
"data": [1, 2, 3, 4, 5],
"mapped": [],
"reduced": 0
})
print(f"Sum of squares: {result['reduced']}") # 1+4+9+16+25 = 55
Nested Send
Send can be nested for hierarchical parallelization:
def level1_orchestrator(state: State) -> List[Send]:
"""First level of parallelization."""
return [
Send("level2_orchestrator", {"batch": batch})
for batch in state["batches"]
]
def level2_orchestrator(state: dict) -> List[Send]:
"""Second level of parallelization."""
return [
Send("worker", {"item": item})
for item in state["batch"]["items"]
]
def worker(state: dict) -> dict:
"""Leaf worker."""
return {"results": [f"processed_{state['item']}"]}
Parallel Node Execution
StateGraph automatically executes nodes in parallel when they have no dependencies:
from typing import Annotated, List
from typing_extensions import TypedDict
import operator
class ParallelState(TypedDict):
input: str
results_a: Annotated[List[str], operator.add]
results_b: Annotated[List[str], operator.add]
final: str
def setup(state: ParallelState) -> dict:
"""Setup node."""
return {"input": state["input"]}
def process_a(state: ParallelState) -> dict:
"""Process A - runs in parallel with B."""
result = f"A processed: {state['input']}"
return {"results_a": [result]}
def process_b(state: ParallelState) -> dict:
"""Process B - runs in parallel with A."""
result = f"B processed: {state['input']}"
return {"results_b": [result]}
def merge(state: ParallelState) -> dict:
"""Merge results from parallel nodes."""
combined = state['results_a'] + state['results_b']
return {"final": ", ".join(combined)}
# Build graph with parallel execution
builder = StateGraph(ParallelState)
builder.add_node("setup", setup)
builder.add_node("process_a", process_a)
builder.add_node("process_b", process_b)
builder.add_node("merge", merge)
# Both process_a and process_b will execute in parallel
builder.add_edge(START, "setup")
builder.add_edge("setup", "process_a")
builder.add_edge("setup", "process_b")
builder.add_edge("process_a", "merge")
builder.add_edge("process_b", "merge")
builder.add_edge("merge", END)
graph = builder.compile()
result = graph.invoke({
"input": "test data",
"results_a": [],
"results_b": [],
"final": ""
})
print(result["final"])
Automatic Parallelization: When multiple nodes have the same parent and don’t depend on each other, they execute concurrently.
Task Decorator
The @task decorator creates durable functions with built-in retry and caching:
Basic Task
from upsonic.graphv2 import task
@task
def expensive_computation(x: int) -> int:
"""A function decorated as a task."""
# Expensive operation
import time
time.sleep(1)
return x * 2
# Use in a node
def my_node(state: MyState) -> dict:
# Call the task - returns TaskResult
result = expensive_computation(state["value"]).result()
return {"output": result}
Task with Retry Policy
import requests
from upsonic.graphv2 import task, RetryPolicy
@task(retry_policy=RetryPolicy(
max_attempts=3,
initial_interval=1.0,
backoff_factor=2.0
))
def unreliable_api_call(endpoint: str) -> dict:
"""API call that might fail."""
response = requests.post(endpoint)
response.raise_for_status()
return response.json()
# Use it
def api_node(state: State) -> dict:
result = unreliable_api_call(state["endpoint"]).result()
return {"api_response": result}
Tasks automatically retry on failure with exponential backoff. Transient failures are logged and retried rather than crashing your workflow.
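For the policy above, the delay between attempts grows geometrically. An illustration, assuming the usual exponential-backoff formula interval = initial_interval * backoff_factor ** (attempt - 1):

# Waits implied by initial_interval=1.0, backoff_factor=2.0, max_attempts=3
delays = [1.0 * (2.0 ** n) for n in range(2)]  # before attempt 2 and attempt 3
print(delays)  # [1.0, 2.0]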
Task with Cache Policy
from upsonic.graphv2 import task, CachePolicy, InMemoryCache
@task(cache_policy=CachePolicy(ttl=300)) # Cache for 5 minutes
def fetch_data(user_id: str) -> dict:
"""Expensive data fetch - results are cached."""
data = database.query(user_id) # Slow operation
return data
# Use with cache
def node(state: State) -> dict:
# First call - executes and caches
result1 = fetch_data("user123").result()
# Second call - uses cache
result2 = fetch_data("user123").result()
return {"data": result1}
# Provide cache to graph
cache = InMemoryCache()
graph = builder.compile(cache=cache)
Combined Retry and Cache
@task(
retry_policy=RetryPolicy(max_attempts=3),
cache_policy=CachePolicy(ttl=600)
)
def robust_expensive_call(param: str) -> str:
"""Retries on failure, caches on success."""
result = external_api.call(param)
return result
Custom Retry Logic
import requests
from upsonic.graphv2 import task, RetryPolicy
@task(retry_policy=RetryPolicy(
max_attempts=5,
initial_interval=0.5,
backoff_factor=2.0,
max_interval=30.0,
jitter=True,
retry_on=ConnectionError # Only retry on specific errors
))
def selective_retry(url: str) -> str:
"""Only retries on connection errors."""
return requests.get(url).text
Async Tasks
Tasks work with async functions too:
import asyncio
@task(retry_policy=RetryPolicy(max_attempts=3))
async def async_task(data: str) -> str:
"""Async task with retry."""
await asyncio.sleep(1)
return f"Processed: {data}"
# Use in async node
async def async_node(state: State) -> dict:
result = await async_task(state["data"]).aresult()
return {"output": result}
Input and Output Schemas
Define strict contracts for your graphs:
from typing_extensions import TypedDict
class InputState(TypedDict):
user_id: str
query: str
class OutputState(TypedDict):
response: str
timestamp: str
class InternalState(InputState, OutputState):
# Internal fields not exposed
processing_steps: int
intermediate_data: dict
cache_hit: bool
# Build with schemas
builder = StateGraph(
InternalState,
input_schema=InputState, # Validates input
output_schema=OutputState # Filters output
)
builder.add_node("process", process_node)
builder.add_edge(START, "process")
builder.add_edge("process", END)
graph = builder.compile()
# Input validation
result = graph.invoke({
"user_id": "123",
"query": "search term"
# Extra fields are allowed but not required
})
# Output filtering
print(result.keys()) # Only ['response', 'timestamp']
# Internal fields are filtered out
Input Schema: Validates that required fields are present. Missing required fields raise GraphValidationError.
Output Schema: Filters the final state to include only the specified fields.
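If a required input field is missing, the invocation fails fast. A minimal sketch (catching broadly here, since the import path for GraphValidationError is not shown above):

try:
    graph.invoke({"user_id": "123"})  # "query" from InputState is missing
except Exception as err:  # expected: GraphValidationError
    print(f"Input validation failed: {err}")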
Runtime Configuration
Pass runtime context to nodes:
class State(TypedDict):
input: str
output: str
def configurable_node(state: State, config: dict) -> dict:
"""Node that uses runtime configuration."""
context = config.get("context", {})
model_name = context.get("model", "default")
temperature = context.get("temperature", 0.7)
max_tokens = context.get("max_tokens", 100)
# Use configuration
model = infer_model(model_name)
result = model.invoke(
state["input"],
temperature=temperature,
max_tokens=max_tokens
)
return {"output": result}
# Pass context at runtime
result = graph.invoke(
{"input": "test"},
context={
"model": "openai/gpt-4o",
"temperature": 0.9,
"max_tokens": 500
}
)
Use context for runtime configuration like model selection, API keys, feature flags, or user preferences.
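For example, a node can branch on a feature flag carried in the same context (the flag name and the fast_path/slow_path helpers are illustrative):

def flagged_node(state: State, config: dict) -> dict:
    """Toggle behavior from a runtime feature flag (hypothetical flag)."""
    flags = config.get("context", {}).get("feature_flags", {})
    if flags.get("use_fast_path", False):
        return {"output": fast_path(state["input"])}  # fast_path() is illustrative
    return {"output": slow_path(state["input"])}      # slow_path() is illustrative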
Recursion Control
Prevent infinite loops with recursion limits:
from upsonic.graphv2 import Command, END
def loop_node(state: State) -> Command:
"""Node that loops back to itself."""
count = state["count"] + 1
if count >= 100:
return Command(update={"count": count}, goto=END)
return Command(update={"count": count}, goto="loop_node")
builder.add_node("loop_node", loop_node)
builder.add_edge(START, "loop_node")
graph = builder.compile()
# Set recursion limit
result = graph.invoke(
{"count": 0},
config={"recursion_limit": 50} # Max 50 steps
)
Default recursion limit is 100. Always set appropriate limits when using loops to prevent runaway execution.
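If the limit is reached before the loop terminates, the run is aborted. A minimal guard, assuming invoke() raises when the step budget is exhausted (the exact exception type is not shown above):

try:
    result = graph.invoke({"count": 0}, config={"recursion_limit": 10})
except Exception as err:  # raised when the step budget runs out
    print(f"Hit the recursion limit: {err}")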
Advanced Patterns
Dynamic Graph Construction
Build graphs based on runtime data:
def build_dynamic_graph(node_count: int) -> CompiledStateGraph:
"""Build a graph with variable number of nodes."""
builder = StateGraph(MyState)
# Add dynamic number of nodes
for i in range(node_count):
builder.add_node(f"node_{i}", lambda s: {"count": s["count"] + 1})
# Connect them
builder.add_edge(START, "node_0")
for i in range(node_count - 1):
builder.add_edge(f"node_{i}", f"node_{i+1}")
builder.add_edge(f"node_{node_count-1}", END)
return builder.compile()
# Create different sized graphs
small_graph = build_dynamic_graph(3)
large_graph = build_dynamic_graph(10)
Conditional Parallel Execution
Decide at runtime whether to parallelize:
from typing import List, Union
from upsonic.graphv2 import Send
def conditional_parallel(state: State) -> Union[List[Send], str]:
"""Parallelize only if needed."""
if state["parallel_mode"]:
# Parallel execution
return [
Send("worker", {"item": item})
for item in state["items"]
]
else:
# Sequential execution
return "sequential_processor"
builder.add_conditional_edges(
"decision",
conditional_parallel,
["worker", "sequential_processor"]
)
Hierarchical State Updates
Use nested state structures:
class HierarchicalState(TypedDict):
config: dict
user_data: dict
results: dict
def node(state: HierarchicalState) -> dict:
"""Update nested structures."""
new_config = {**state["config"], "updated": True}
new_user_data = {**state["user_data"], "processed": True}
return {
"config": new_config,
"user_data": new_user_data,
"results": {"status": "complete"}
}
Cache Hot Paths
# Cache frequently used computations
@task(cache_policy=CachePolicy(ttl=3600))
def expensive_feature_extraction(data: str) -> dict:
# Expensive ML inference
return model.predict(data)
Parallel Data Processing
def process_large_dataset(state: State) -> List[Send]:
"""Process data in parallel batches."""
batch_size = 100
batches = [
state["data"][i:i+batch_size]
for i in range(0, len(state["data"]), batch_size)
]
return [
Send("batch_processor", {"batch": batch})
for batch in batches
]
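A matching worker for this fan-out might look like the sketch below; the results reducer field and the transform() helper are assumptions:

def batch_processor(state: dict) -> dict:
    """Process one batch; partial results merge via a `results` reducer."""
    processed = [transform(item) for item in state["batch"]]  # transform() is illustrative
    return {"results": processed}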
Async Durability
# Use async durability for better throughput
graph = builder.compile(
checkpointer=checkpointer,
durability="async" # Background persistence
)
Best Practices
1. Use Send for True Parallelism
# ✅ Good - Send for dynamic parallelization
def fan_out(state: State) -> List[Send]:
return [Send("worker", {"item": i}) for i in state["items"]]
# ⚠️ Alternative - Parallel nodes for static parallelism
# Use when you know the exact nodes at graph build time
2. Task Decorators for Side Effects
# ✅ Good - @task for external calls
@task(retry_policy=RetryPolicy(max_attempts=3))
def call_external_api(url: str) -> dict:
return requests.post(url).json()
# ❌ Bad - No retry for critical calls
def call_external_api(url: str) -> dict:
return requests.post(url).json()
3. Reasonable Cache TTLs
# ✅ Good - appropriate TTLs
@task(cache_policy=CachePolicy(ttl=300)) # 5 min for dynamic data
def get_stock_price(symbol: str) -> float:
...
@task(cache_policy=CachePolicy(ttl=86400)) # 24 hrs for static data
def get_company_info(symbol: str) -> dict:
...
4. Monitor Parallel Execution
def orchestrator(state: State) -> List[Send]:
print(f"Spawning {len(state['items'])} workers")
return [Send("worker", {"item": i}) for i in state["items"]]
def worker(state: dict) -> dict:
print(f"Processing item: {state['item']}")
return {"results": [process(state["item"])]}
Next Steps