DAGs for LLM interactions

Diagraph

A graph for representing LLM interactions.

This document is aspirational; not all functionality has been implemented yet. Diagraph is very much in alpha development.

Installation

pip install diagraph

Quick Start

import openai

from diagraph import Diagraph, Depends, prompt

openai.api_key = 'sk-<OPENAI_TOKEN>'

@prompt
def tell_me_a_joke():
  return 'Computer! Tell me a joke about tomatoes.'

def explanation(joke: Depends(tell_me_a_joke)) -> str:
  return openai.completion(f'Explain why the joke "{joke}" is funny.')

print(Diagraph(explanation).run().result) # 'The joke is a play on words and concepts. There are two main ideas that make it humorous...

Usage

Start with a function:

def tell_me_a_joke():
  return openai.completion([{ "role": "user", "content": 'Computer! Tell me a joke about tomatoes.' }])

print(tell_me_a_joke()) # Why did the tomato turn red? Because it saw the salad dressing!

Diagraph takes functions as arguments.

from diagraph import Diagraph

def tell_me_a_joke():
  return openai.completion('Computer! Tell me a joke about tomatoes.')

diagraph = Diagraph(tell_me_a_joke)

The argument should be the final function in your chain.

You can index into a diagraph with the function:

diagraph[tell_me_a_joke]

> def tell_me_a_joke():
>   return openai.completion('Computer! Tell me a joke about tomatoes.')

We'll discuss indexing and slicing later.

@prompt

Decorating a function with @prompt gives it superpowers.

You can use the @prompt decorator as an alternative to calling your LLM directly (as done above with openai.completion). When decorated with @prompt, your function sprouts additional metadata.

.prompt

You can get back information on the prompt for a function:

from diagraph import Diagraph, prompt

@prompt
def tell_me_a_joke(input: str) -> str:
  return 'Computer! Tell me a joke about {input}.'

diagraph = Diagraph(tell_me_a_joke)
print(diagraph[tell_me_a_joke].prompt()) 
# 'Computer! Tell me a joke about {input}.'

You can pass arguments to prompt() to get a parsed prompt:

print(diagraph[tell_me_a_joke].prompt('tomatoes')) 
# 'Computer! Tell me a joke about tomatoes.'

You can pass keyword arguments too, in case you have multiple string arguments and don't want to remember the defined order:

print(diagraph[tell_me_a_joke].prompt(input='tomatoes')) 
# 'Computer! Tell me a joke about tomatoes.'

.tokens

You can also get token information for the prompt. If you call it without arguments, it returns the token count for the raw template, before string replacements:

print(diagraph[tell_me_a_joke].tokens()) # 7

You can provide arguments to get tokens for the parsed string:

print(diagraph[tell_me_a_joke].tokens(input='tomatoes')) # 8

Building a Graph

Diagraphs are intended for multi-step LLM interactions. Let's demonstrate by defining a graph of dependent interactions.

This example asks for a joke, explains the joke, then tries to make it funnier:

@prompt
def tell_me_a_joke():
  return 'Computer! Tell me a joke about tomatoes.'

@prompt
def explanation(joke: Depends(tell_me_a_joke)) -> str:
  return f'Explain why the given joke is funny: {joke}'

@prompt
def improvement(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> str:
  return f'The given joke is "{joke}". An explanation of the joke is "{explanation}". What would make the joke funnier?'

diagraph = Diagraph(improvement)

A Diagraph always receives the final function (referred to as the terminal node) as the argument. (You can pass multiple functions, each a terminal node.)

You define your dependency as a type with Depends. Depends indicates to Diagraph that a function has a dependency on another function. Diagraph automatically crawls your function dependencies and builds a graph:

<IMAGE_OF_GRAPH>

You can run a diagraph with .run():

traversal = diagraph.run()
print(traversal.result) # Why did the tomato turn red? Because it saw the salad dressing!

A Diagraph is a static representation of the interaction template. Think of it as a recipe.

A Traversal is an execution of that template. Think of it as a cooked meal. A Diagraph is a relationship between functions (which can be LLM interactions, tool use, or plain Python functions), whereas a Traversal is that relationship executed; it contains the composed prompts as well as the results.
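
To make the distinction concrete, here is a minimal sketch reusing the improvement graph from above: the same Diagraph can be run repeatedly, and each run yields an independent Traversal.

diagraph = Diagraph(improvement)  # the static recipe

first = diagraph.run()   # one cooked meal
second = diagraph.run()  # another, independent of the first

print(first.result)
print(second.result)  # may differ: each traversal has its own prompts and results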

Indices and Slicing

Indexing

You can index into a Diagraph:

diagraph[0] # (tell_me_a_joke, )

Using an integer index returns a tuple of all nodes at a specific depth (or None for an out-of-range integer). Diagraph automatically hoists nodes to the top of the graph and executes them greedily and in parallel.
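
To illustrate, here is a sketch with two hypothetical functions, pun and limerick, that have no dependencies and therefore both sit at depth 0:

@prompt
def pun():
  return 'Tell me a pun about tomatoes.'

@prompt
def limerick():
  return 'Tell me a limerick about tomatoes.'

@prompt
def judge(pun: Depends(pun), limerick: Depends(limerick)) -> str:
  return f'Which is funnier: "{pun}" or "{limerick}"?'

diagraph = Diagraph(judge)
diagraph[0]  # (pun, limerick) - both roots, executed in parallel
diagraph[1]  # (judge, )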

You can also index with negative numbers:

diagraph[-1] # returns a tuple of terminal nodes

You can also index with a function:

diagraph[tell_me_a_joke] # tell_me_a_joke

Using a function index always returns a DiagraphNode representing the function.

A DiagraphNode decorates the function with additional properties, including .prompt and .tokens (if the function has been decorated with @prompt) as well as .result (if within an executed Traversal).
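
For example (a sketch using the joke graph from above):

node = diagraph[tell_me_a_joke]
print(node.prompt())  # available because tell_me_a_joke is decorated with @prompt
print(node.tokens())

traversal = diagraph.run()
print(traversal[tell_me_a_joke].result)  # available within an executed traversal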

The same indexing strategies work for a Traversal:

traversal[tell_me_a_joke] # tell_me_a_joke

In a traversal, you can get the result for a specific function:

print(traversal[tell_me_a_joke].result) # Why did the tomato turn red? Because it saw the salad dressing!

Slicing

You can also slice a diagraph or traversal:

traversal[1:] # Give me the graph without the first layer of nodes

You can also use functions for slicing:

traversal[tell_me_a_joke:-1]

Slicing with functions and slicing with integers work slightly differently.

If you slice with two functions, you get a subgraph for the graph connecting the two nodes.

<EXAMPLE_IMAGE>

If you slice with two integers, you get all the layers for those depths:

<EXAMPLE_IMAGE>

You can also mix a function and an integer in a single slice. For example:

<EXAMPLE_IMAGE>
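
To sketch the three styles against the joke graph from above:

diagraph[tell_me_a_joke:improvement]  # subgraph connecting the two nodes
diagraph[0:2]                         # all layers at depths 0 and 1
diagraph[tell_me_a_joke:-1]           # from a node through to the final layer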

length

You can also call len:

len(traversal) # 3

len returns the total number of layers, not the total number of nodes. To count nodes instead, use:

len(traversal.nodes)

Return types

You can do automatic return type validation:

from pydantic import BaseModel

class ImprovementReturnType(BaseModel):
  message: str

@prompt
def improvement(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> ImprovementReturnType:
  return f'The given joke is "{joke}". An explanation of the joke is "{explanation}". What would make the joke funnier?'

Responses are automatically parsed in the following manner:

  • str - returns the LLM response as plain text
  • PydanticModel - returns a structured response, likely from a function call (any additional content is ignored)
  • tuple(Optional[str], Optional[PydanticModel]) - returns both the content and the function call, if present (see the sketch below)
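
A minimal sketch of the tuple form (written with standard typing syntax, which is an assumption on our part):

from typing import Optional

@prompt
def improvement(joke: Depends(tell_me_a_joke)) -> tuple[Optional[str], Optional[ImprovementReturnType]]:
  return f'The given joke is "{joke}". What would make it funnier?'

# content is the plain-text response, model is the parsed structured response;
# either may be None
content, model = Diagraph(improvement).run().result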

Return type annotations can also be functions, which act as validators:

def is_it_text(joke: str):
  if isinstance(joke, str):
    return joke

  raise Exception('Invalid input!')

@prompt
def improvement(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> is_it_text:
  return f'The given joke is "{joke}". An explanation of the joke is "{explanation}". What would make the joke funnier?'

Return types can themselves be calls to an LLM, and will be automatically rolled up into the Diagraph:

def raise_if_not_funny(is_funny: str):
  if is_funny not in ['yes', 'no']:
    raise Exception('Invalid Response')

  if is_funny == 'no':
    raise Exception('Not funny')

@prompt
def is_it_funny(joke: str) -> raise_if_not_funny:
  return 'Is this joke funny? Respond with yes or no: {joke}'

@prompt
def joke() -> is_it_funny:
  return 'Tell me a joke'

You can use Depends in a return type; Diagraph will build the graph and raise an error if the graph cannot be solved.

Errors

You can handle errors with a decorator:

def error_handler(error):
  pass

@prompt(error=error_handler)
def improvement(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> ImprovementReturnType:
  return f'The given joke is "{joke}". An explanation of the joke is "{explanation}". What would make the joke funnier?'

An error can be of the following types:

  • ReturnTypeValidationFailure - the return type could not be satisfied
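
A sketch of a handler that branches on the error type (assuming ReturnTypeValidationFailure can be imported from diagraph, which is an assumption):

def error_handler(error):
  # assumption: ReturnTypeValidationFailure is exported by diagraph
  if isinstance(error, ReturnTypeValidationFailure):
    print('Return type could not be satisfied:', error)
  else:
    raise error  # re-raise anything we do not know how to handle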

Logs

You can handle streaming logs as an argument to @prompt:

@prompt(log=lambda event, chunk, fn: print(chunk))
def improvement(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> ImprovementReturnType:
  return f'The given joke is "{joke}". An explanation of the joke is "{explanation}". What would make the joke funnier?'

  • event is an enum that can be start, data, or end.
  • chunk is the bit of text returned from the LLM; it will be None for start and end.
  • fn is the node being operated upon.
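
For example, a minimal sketch that accumulates the streamed text (comparing event to plain strings is an assumption about the enum's values):

chunks = []

def log_handler(event, chunk, fn):
  if event == 'data':
    chunks.append(chunk)  # chunk is None for start and end
  elif event == 'end':
    print('Full response:', ''.join(chunks))

@prompt(log=log_handler)
def improvement(joke: Depends(tell_me_a_joke)) -> str:
  return f'The given joke is "{joke}". What would make it funnier?'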

Time Travel & Replay

A traversal contains the whole history of your interaction.

This means you can rerun graphs from an intermediary point:

traversal.run(explanation)

# _or_

traversal[explanation].run()

When re-running part of a graph, all upstream nodes remain cached, whereas all downstream nodes are re-executed. Alternatively, you can run a single node by passing the only flag:

traversal.run(explanation, only=True)

# _or_

traversal[explanation].run(only=True)

In this case, only the explanation function will be executed; however, all downstream node caches will be cleared.
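
A sketch of the difference:

traversal2 = traversal[explanation].run(only=True)

traversal2[tell_me_a_joke].result  # upstream: still cached from the original run
traversal2[explanation].result     # freshly re-executed
# downstream nodes (improvement) had their caches cleared and were not re-run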

run() returns a new instance of a traversal:

traversal2 = traversal[explanation].run()

You can also specify indices:

traversal[-2].run()

This will re-execute all nodes for a given depth.

Control Flows

Diagraphs are meant for statically analyzable interactions. They're not intended to support more complicated control flows like if/else statements, while loops, etc.

However, since everything is just a Python function, you can implement such behavior yourself:

def the_joke_maker():
  def tell_me_a_joke():
    return openai.completion('Computer! Tell me a joke about tomatoes.')

  def explanation(joke: Depends(tell_me_a_joke)) -> str:
    return openai.completion(f'Explain why the given joke is funny: {joke}')

  def is_it_funny(joke: Depends(tell_me_a_joke), explanation: Depends(explanation)) -> str:
    return openai.completion(f'The given joke is "{joke}". An explanation of the joke is "{explanation}". Is it funny to a kindergartner? Respond with yes or no.')

  diagraph = Diagraph(is_it_funny)
  is_funny = False
  traversal = None
  while is_funny is False:
    if traversal is None:
      traversal = diagraph.run()
    else:
      traversal = traversal.run(tell_me_a_joke)  # run() returns a new traversal
    is_funny = traversal.result == 'yes'

  return traversal[tell_me_a_joke].result

You can in turn wrap this function into a Diagraph:

def shakespearify(joke: Depends(the_joke_maker)) -> str:
  return openai.completion(f'rewrite this as shakespeare: {joke}')

Diagraph(shakespearify)
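
As with any diagraph, running it executes the whole chain, nested control flow included:

print(Diagraph(shakespearify).run().result)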

Changing LLM model

You can swap in a different model on a per-function basis by passing llm to @prompt:

gpt4 = OpenAI(model='gpt-4', temperature=0.2)

@prompt(llm=gpt4)
def fn():
  return 'tell me a joke'
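
For example (a sketch reusing the OpenAI wrapper shown above; the model names are illustrative), different nodes in the same graph can use different models:

gpt3 = OpenAI(model='gpt-3.5-turbo')
gpt4 = OpenAI(model='gpt-4', temperature=0.2)

@prompt(llm=gpt3)
def draft():
  return 'Tell me a joke about tomatoes.'

@prompt(llm=gpt4)
def punch_up(joke: Depends(draft)) -> str:
  return f'Make this joke funnier: {joke}'

print(Diagraph(punch_up).run().result)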

Integrations

  • Weights & Biases
  • Mermaidjs
  • Magnetic
  • Marvin
  • LLM (Simon Willison's library)
  • Langchain
  • LlamaIndex
