SQL Index with Many Tables

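Demo where the table schema information is stored as context in an index, so that only the relevant tables are retrieved at query time.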

import logging
import sys

from IPython.display import Markdown, display
from llama_index import GPTSQLStructStoreIndex, SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder

# send LlamaIndex logs (e.g. token usage) to stdout
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

Create Database Schema + Test Data

Here we introduce a toy scenario with 100 dummy tables alongside city_stats. Together, their schemas are too big to fit into the prompt.

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
all_table_names = ["city_stats"]
# create a ton of dummy tables
n = 100
for i in range(n):
    tmp_table_name = f"tmp_table_{i}"
    tmp_table = Table(
        tmp_table_name,
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )
    all_table_names.append(tmp_table_name)

metadata_obj.create_all(engine)
# print tables
metadata_obj.tables.keys()
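To get a rough feel for why these schemas will not fit into one prompt, here is a stdlib-only sketch that recreates the same 101-table layout with sqlite3 and renders each table as a one-line description, approximating the schema strings that appear in the query logs further down. The helper `describe_table` is hypothetical (not part of LlamaIndex or SQLAlchemy), and the characters-to-tokens ratio is only a rough heuristic.

```python
import sqlite3

# Recreate the same 101-table layout in an in-memory SQLite database
# (stdlib sqlite3 here; the notebook itself uses SQLAlchemy).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name VARCHAR(16) PRIMARY KEY, "
    "population INTEGER, country VARCHAR(16) NOT NULL)"
)
for i in range(100):
    conn.execute(
        f"CREATE TABLE tmp_table_{i} (tmp_field_{i}_1 VARCHAR(16) PRIMARY KEY, "
        f"tmp_field_{i}_2 INTEGER, tmp_field_{i}_3 VARCHAR(16) NOT NULL)"
    )


def describe_table(conn, table_name):
    """Render a table as "Table 'name': col (TYPE), ..." (hypothetical helper)."""
    # each PRAGMA table_info row is (cid, name, type, notnull, dflt_value, pk)
    cols = conn.execute(f"PRAGMA table_info('{table_name}')").fetchall()
    col_str = ", ".join(f"{name} ({col_type})" for _, name, col_type, *_ in cols)
    return f"Table '{table_name}': {col_str}"


table_names = ["city_stats"] + [f"tmp_table_{i}" for i in range(100)]
descriptions = [describe_table(conn, name) for name in table_names]
full_schema = "\n".join(descriptions)

n_tables = conn.execute(
    "SELECT count(*) FROM sqlite_master WHERE type = 'table'"
).fetchone()[0]
print(n_tables)  # 101
print(descriptions[0])
# ~4 characters per token is only a rough heuristic, not a real tokenizer
print(len(full_schema), "characters, roughly", len(full_schema) // 4, "tokens")
```

Every dummy table adds roughly a hundred characters, so the full schema grows linearly with the number of tables even though only city_stats is relevant to the question asked later.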

We insert some test data into the city_stats table.

from sqlalchemy import insert
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# insert all rows in a single transaction; engine.begin() commits on exit
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(city_stats_table).values(**row))
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]

Using LlamaIndex to Store Table Schema Context

from llama_index import GPTSQLStructStoreIndex, SQLDatabase, GPTVectorStoreIndex
from llama_index.indices.struct_store import SQLContextContainerBuilder
sql_database = SQLDatabase(engine)
sql_database.table_info

We dump the table schema information into a vector index. The vector index is stored within the context builder for future use.
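The idea behind the schema index can be sketched without an embedding model: turn each table description into a vector, score it against the query, and keep only the best-matching tables. This sketch swaps in a bag-of-words cosine similarity as a stand-in for real embeddings; `embed`, `cosine`, `retrieve_tables`, and `top_k` are hypothetical names, not LlamaIndex APIs.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy stand-in for an embedding model: lowercase bag-of-words counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve_tables(query, table_descriptions, top_k=1):
    """Return the top_k table descriptions most similar to the query."""
    query_vec = embed(query)
    ranked = sorted(
        table_descriptions,
        key=lambda desc: cosine(query_vec, embed(desc)),
        reverse=True,
    )
    return ranked[:top_k]


# One description per table, in the same style as the schema strings in the logs
table_descriptions = [
    "Table 'city_stats': city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))"
] + [
    f"Table 'tmp_table_{i}': tmp_field_{i}_1 (VARCHAR(16)), "
    f"tmp_field_{i}_2 (INTEGER), tmp_field_{i}_3 (VARCHAR(16))"
    for i in range(100)
]

context = retrieve_tables("Which city has the highest population?", table_descriptions)
print(context[0])
```

Only city_stats shares words ("city", "population") with the question, so it is the single description passed on; a real embedding model would also match on meaning rather than exact words.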

# build a vector index from the table schema information
context_builder = SQLContextContainerBuilder(sql_database)
table_schema_index = context_builder.derive_index_from_context(
    GPTVectorStoreIndex,
)
# NOTE: no unstructured documents are ingested for now
index = GPTSQLStructStoreIndex.from_documents(
    [],
    sql_database=sql_database,
    table_name="city_stats",
)

Query Index

Here we show a natural language query.

  1. We first query the table schema index for the relevant table schema. Note that the context container is built at query time.

  2. Given this context container, we execute the NL query against the database.
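Stripped of the LLM, the two steps above reduce to: put only the retrieved schema into a text-to-SQL prompt, get back a SQL string, and run it. In this sketch the SQL is written by hand as a stand-in for what the LLM would generate, and `build_text_to_sql_prompt` with its wording is hypothetical, not the prompt LlamaIndex actually uses.

```python
import sqlite3


def build_text_to_sql_prompt(context_str, query_str):
    """Hypothetical prompt template combining retrieved schema and question."""
    return (
        "Given the following table schema, write a SQLite query "
        "that answers the question.\n"
        f"Schema:\n{context_str}\n"
        f"Question: {query_str}\n"
        "SQL:"
    )


# Same data as the city_stats table above
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name VARCHAR(16) PRIMARY KEY, "
    "population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Chicago", 2679000, "United States"),
        ("Seoul", 9776000, "South Korea"),
    ],
)

# Step 1 output: only the relevant schema, not all 101 tables
context_str = (
    "Table 'city_stats': city_name (VARCHAR(16)), "
    "population (INTEGER), country (VARCHAR(16))"
)
prompt = build_text_to_sql_prompt(context_str, "Which city has the highest population?")

# Step 2: an LLM would complete `prompt` with SQL; here it is written by hand
generated_sql = "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1"
result = conn.execute(generated_sql).fetchall()
print(result)  # [('Tokyo',)]
```

Keeping the prompt limited to the retrieved schema is what lets this approach scale to many tables: the prompt size depends on how many tables are retrieved, not on how many exist.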

query_str = "Which city has the highest population?"
context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
context_container = context_builder.build_context_container()
INFO:root:> [query] Total LLM token usage: 135 tokens
> [query] Total LLM token usage: 135 tokens
INFO:root:> [query] Total embedding token usage: 23 tokens
> [query] Total embedding token usage: 23 tokens

Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
display(Markdown(f"<b>{context_container.context_str}</b>"))
Table 'city_stats': city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
query_engine = index.as_query_engine(
    sql_context_container=context_container
)
response = query_engine.query(query_str)
INFO:root:> Table desc str: 
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
> Table desc str: 
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
INFO:root:> [query] Total LLM token usage: 134 tokens
> [query] Total LLM token usage: 134 tokens
INFO:root:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens

Finally, we inspect the response text and the extra metadata attached to it.

str(response)
"[('Tokyo',)]"
response.extra_info