SQL Index Guide (Core)

This is a basic guide to LlamaIndex's SQL index capabilities. We first show how to "build" a SQL index by extracting structured city/population statistics from unstructured Wikipedia articles on cities. We then show how to run text-to-SQL queries over these statistics.

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index import SimpleDirectoryReader, WikipediaReader
from IPython.display import Markdown, display

Load Wikipedia Data

We use LlamaIndex's WikipediaReader to load the Wikipedia articles for Toronto, Berlin, and Tokyo.

# install wikipedia python package
!pip install wikipedia
wiki_docs = WikipediaReader().load_data(pages=['Toronto', 'Berlin', 'Tokyo'])

Create Database Schema

We use SQLAlchemy, a popular SQL toolkit for Python, to create an empty city_stats table in an in-memory SQLite database.

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
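Because the index inserts extracted rows into this table, the schema's constraints decide which rows are accepted. The sketch below is not LlamaIndex code: it assumes Python's built-in sqlite3 module as a stand-in for SQLAlchemy, recreates the same schema, and uses a hypothetical try_insert helper to probe the primary key and NOT NULL constraints.

```python
import sqlite3

# Recreate the city_stats schema with the stdlib sqlite3 driver
# (a stand-in for the SQLAlchemy Table definition, not LlamaIndex code).
conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE city_stats (
        city_name VARCHAR(16) PRIMARY KEY,
        population INTEGER,
        country VARCHAR(16) NOT NULL
    )
    """
)


def try_insert(row):
    """Hypothetical helper: insert one row, returning None on success
    or the IntegrityError message if a constraint rejects it."""
    try:
        conn.execute("INSERT INTO city_stats VALUES (?, ?, ?)", row)
        return None
    except sqlite3.IntegrityError as exc:
        return str(exc)


ok = try_insert(("Toronto", 2731571, "Canada"))
duplicate = try_insert(("Toronto", 1, "Canada"))        # same primary key
missing_country = try_insert(("Berlin", 600000, None))  # NULL in a NOT NULL column
too_long = try_insert(("X" * 20, 1, "Nowhere"))         # SQLite ignores the VARCHAR(16) length

print(ok, duplicate, missing_country, too_long, sep="\n")
```

Note that SQLite does not enforce the 16-character limit in String(16), so a longer city name is stored as-is; other databases may reject or truncate it.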

Build Index

Next, we build the SQL index (GPTSQLStructStoreIndex). First, we define a SQLDatabase, LlamaIndex's light wrapper around a SQLAlchemy engine, restricted here to the city_stats table.

from llama_index import GPTSQLStructStoreIndex, SQLDatabase, ServiceContext
from langchain import OpenAI
from llama_index import LLMPredictor
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-002"))
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_database.table_info
"Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))."
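The table_info string is the kind of schema description a text-to-SQL prompt gives the LLM. As a rough illustration of where such a string can come from, the hypothetical describe_table helper below rebuilds the same one-line format from SQLite's PRAGMA table_info, assuming a stdlib sqlite3 connection; it is not LlamaIndex's implementation.

```python
import sqlite3


def describe_table(conn, table_name):
    """Hypothetical helper: summarize a table's columns in the same
    one-line format as sql_database.table_info."""
    # Each PRAGMA row is (cid, name, type, notnull, dflt_value, pk).
    # table_name is interpolated directly, so it must be a trusted identifier.
    columns = conn.execute(f"PRAGMA table_info({table_name})").fetchall()
    col_desc = ", ".join(f"{name} ({col_type})" for _, name, col_type, *_ in columns)
    return f"Table '{table_name}' has columns: {col_desc}."


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
print(describe_table(conn, "city_stats"))
```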
# NOTE: table_name is the table that rows extracted from the
# unstructured documents are inserted into.
index = GPTSQLStructStoreIndex.from_documents(
    wiki_docs, 
    sql_database=sql_database, 
    table_name="city_stats",
    service_context=service_context
)
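Conceptually, building the index means using the LLM to extract values that match the table schema from the documents and inserting them as rows. The sketch below illustrates only the parse-and-insert half, under loud assumptions: the LLM is replaced by hard-coded strings in a made-up "column: value, ..." format (the values are the rows printed below), parse_datapoint is a hypothetical helper, and GPTSQLStructStoreIndex's real prompt and parser may differ.

```python
import sqlite3

# Stand-ins for LLM completions (hypothetical format, one per document).
fake_llm_outputs = [
    "city_name: Toronto, population: 2731571, country: Canada",
    "city_name: Berlin, population: 600000, country: Germany",
    "city_name: Tokyo, population: 13929286, country: Japan",
]

# Python type used to cast each extracted value, mirroring the table schema.
COLUMN_TYPES = {"city_name": str, "population": int, "country": str}


def parse_datapoint(text):
    """Hypothetical parser: turn 'col: value, col: value' into a dict.
    Naive on purpose: a value containing a comma (e.g. '2,731,571') would break it."""
    row = {}
    for pair in text.split(","):
        key, _, value = pair.partition(":")
        key = key.strip()
        if key in COLUMN_TYPES:  # ignore fields that are not table columns
            row[key] = COLUMN_TYPES[key](value.strip())
    return row


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
for text in fake_llm_outputs:
    conn.execute(
        "INSERT INTO city_stats (city_name, population, country) "
        "VALUES (:city_name, :population, :country)",
        parse_datapoint(text),
    )

print(conn.execute("SELECT * FROM city_stats ORDER BY city_name").fetchall())
```

Casting each value by column type before inserting makes a malformed number fail early (int() raises ValueError) instead of relying on SQLite's loose type affinity, which would silently store it as text.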
# view current table
stmt = select(
    city_stats_table.c["city_name", "population", "country"]
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)
[('Toronto', 2731571, 'Canada'), ('Berlin', 600000, 'Germany'), ('Tokyo', 13929286, 'Japan')]

Query Index

We first show how to execute a raw SQL query, which runs directly against the table without calling the LLM.

query_engine = index.as_query_engine(
    query_mode="sql"
)
response = query_engine.query("SELECT city_name from city_stats")
> [query] Total LLM token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))

[('Berlin',), ('Tokyo',), ('Toronto',)]
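The 0-token usage logged above is consistent with query_mode="sql" handing the SQL string straight to the database. A minimal sketch of such a pass-through, assuming a stdlib sqlite3 connection in place of the SQLAlchemy engine (run_raw_sql is a hypothetical name):

```python
import sqlite3


def run_raw_sql(conn, sql):
    """Hypothetical pass-through: execute the SQL text as-is, with no LLM call."""
    return conn.execute(sql).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2731571, "Canada"), ("Berlin", 600000, "Germany"), ("Tokyo", 13929286, "Japan")],
)

# Row order is not guaranteed without ORDER BY, so sort for a stable display.
print(sorted(run_raw_sql(conn, "SELECT city_name from city_stats")))
```

Because the query text is executed verbatim, a mode like this should only receive trusted SQL.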

We then run a natural language query, which is translated into SQL under the hood using our text-to-SQL prompt.

# set logging to DEBUG for more detailed output
query_engine = index.as_query_engine(
    query_mode="nl"
)
response = query_engine.query("Which city has the highest population?")
> Predicted SQL query: SELECT city_name, population
FROM city_stats
ORDER BY population DESC
LIMIT 1
> [query] Total LLM token usage: 144 tokens
> [query] Total embedding token usage: 0 tokens
display(Markdown(f"<b>{response}</b>"))

[('Tokyo', 13929286)]

# you can also fetch the raw result from SQLAlchemy! 
response.extra_info["result"]
[('Tokyo', 13929286)]
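The natural language path adds one step before execution: the LLM turns the question plus a schema description into SQL. The sketch below wires that flow together with a stub in place of the LLM. PROMPT_TEMPLATE, fake_llm, and text_to_sql are hypothetical; the stub simply returns the query predicted in the log above, and LlamaIndex's actual text-to-SQL prompt is not reproduced here.

```python
import sqlite3

# Hypothetical prompt; LlamaIndex's real text-to-SQL prompt is not reproduced here.
PROMPT_TEMPLATE = (
    "Given an input question, write a SQLite query that answers it.\n"
    "Schema: {schema}\n"
    "Question: {question}\n"
    "SQL:"
)


def fake_llm(prompt):
    """Stub LLM: ignores the prompt and returns the SQL predicted in the log."""
    return "SELECT city_name, population\nFROM city_stats\nORDER BY population DESC\nLIMIT 1"


def text_to_sql(conn, schema, question, llm=fake_llm):
    """Format the prompt, ask the (stub) LLM for SQL, then run it.
    Returns the SQL and the raw rows, analogous to response.extra_info["result"]."""
    prompt = PROMPT_TEMPLATE.format(schema=schema, question=question)
    sql = llm(prompt).strip()
    return sql, conn.execute(sql).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2731571, "Canada"), ("Berlin", 600000, "Germany"), ("Tokyo", 13929286, "Japan")],
)
schema = "Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))."
sql, rows = text_to_sql(conn, schema, "Which city has the highest population?")
print(sql)
print(rows)
```

Passing llm as a parameter keeps the pipeline testable offline; swapping in a real model call changes only that argument.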

Using LangChain for Querying

Since our SQLDatabase inherits from LangChain's SQLDatabase, you can also pass it directly to LangChain for querying.

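# reuse the LlamaIndex sql_database defined above; importing LangChain's
# SQLDatabase here would shadow LlamaIndex's class of the same name
from langchain import OpenAI, SQLDatabaseChain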
llm = OpenAI(temperature=0)
# set logging to DEBUG for more detailed output
db_chain = SQLDatabaseChain(llm=llm, database=sql_database)
db_chain.run("Which city has the highest population?")
> Entering new SQLDatabaseChain chain...
Which city has the highest population? 
SQLQuery: SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1;
SQLResult: [('Tokyo',)]
Answer: Tokyo has the highest population.
> Finished chain.
' Tokyo has the highest population.'
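Unlike the LlamaIndex query engine above, which returned the raw rows, the LangChain chain adds a final step that phrases the SQL result as a sentence; the returned string also carries a leading space, so calling .strip() on it is a reasonable habit. The sketch below outlines the three logged stages (SQLQuery, SQLResult, Answer) with plain functions. answer_question, write_sql, and write_answer are hypothetical stand-ins for the LLM calls, not LangChain's internals, and it assumes a stdlib sqlite3 connection.

```python
import sqlite3


def answer_question(conn, question, write_sql, write_answer):
    """Run the three stages shown in the chain's log and return them as a trace."""
    sql = write_sql(question).strip()              # stage 1: SQLQuery
    rows = conn.execute(sql).fetchall()            # stage 2: SQLResult
    answer = write_answer(question, rows).strip()  # stage 3: Answer, stray whitespace removed
    return {"SQLQuery": sql, "SQLResult": rows, "Answer": answer}


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2731571, "Canada"), ("Berlin", 600000, "Germany"), ("Tokyo", 13929286, "Japan")],
)

trace = answer_question(
    conn,
    "Which city has the highest population?",
    # Deterministic stand-ins for the two LLM calls (hypothetical).
    write_sql=lambda q: "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1;",
    write_answer=lambda q, rows: f" {rows[0][0]} has the highest population.",
)
for stage, value in trace.items():
    print(f"{stage}: {value}")
```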