SQL Index with Many Tables
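This demo shows how to query a SQL database that has many tables: the table schema information is stored as context, and only the schema relevant to a question is retrieved at query time.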
import logging
import sys
# Send INFO-level (and above) log records to stdout so they show up in the notebook output.
# NOTE: basicConfig() installs a stdout handler on the root logger and the second line adds
# another one, which is why each log message in the captured outputs further down appears
# twice (once with the "INFO:root:" prefix and once without).
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index import GPTSQLStructStoreIndex, SQLDatabase, SimpleDirectoryReader, WikipediaReader, Document
from llama_index.indices.struct_store import SQLContextContainerBuilder
from IPython.display import Markdown, display
Create Database Schema + Test Data
Here we introduce a toy scenario with 100 dummy tables alongside the city_stats table (too many table schemas to fit into the prompt).
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column
# In-memory SQLite database; every table below is registered on a single MetaData object.
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# The one table that carries real data: per-city statistics.
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
all_table_names = [table_name]

# Register a large number of placeholder tables (three columns each) so the
# combined schema is far bigger than the single table we actually care about.
n = 100
for i in range(n):
    Table(
        f"tmp_table_{i}",
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )
all_table_names.extend(f"tmp_table_{i}" for i in range(n))

# Create every table registered on the metadata object in the database.
metadata_obj.create_all(engine)

# Cell output: the names of all registered tables.
metadata_obj.tables.keys()
We introduce some test data into the city_stats table.
from sqlalchemy import insert
# Seed data for city_stats; the dict keys match the table's column names.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Insert all rows through ONE connection and ONE transaction instead of opening a
# connection and committing once per row. engine.begin() commits when the block
# exits normally and rolls back on error, so no explicit commit() call is needed.
# Passing the list of dicts to execute() runs the INSERT once per row, in list order.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

# Read the table back to confirm that the rows were stored.
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]
Using GPT Index to Store Table Schema Context
from llama_index import GPTSQLStructStoreIndex, SQLDatabase, GPTVectorStoreIndex
from llama_index.indices.struct_store import SQLContextContainerBuilder
# Wrap the SQLAlchemy engine in llama_index's SQLDatabase object.
sql_database = SQLDatabase(engine)
# Bare expression: in a notebook this displays the wrapper's `table_info` value
# (presumably a textual summary of the table schemas; confirm in the llama_index docs).
sql_database.table_info
We dump the table schema information into a vector index. The vector index is stored within the context builder for future use.
# Create the context builder for the database, then have it derive a vector
# index over the table schema information.
context_builder = SQLContextContainerBuilder(sql_database)
table_schema_index = context_builder.derive_index_from_context(GPTVectorStoreIndex)

# NOTE: no unstructured documents are ingested at the moment, hence the empty list.
index = GPTSQLStructStoreIndex.from_documents(
    [], sql_database=sql_database, table_name="city_stats"
)
Query Index
Here we show a natural language query.
We first query for the right table schema. Note that we build a context container during query-time.
Given this context container, we execute the NL query against the db.
# The natural-language question we want answered from the database.
query_str = "Which city has the highest population?"

# Query the schema index for the table schema that fits the question; the
# store_context_str flag asks the builder to keep the retrieved context string.
context_builder.query_index_for_context(
    table_schema_index,
    query_str,
    store_context_str=True,
)

# Build the context container that is handed to the query engine later.
context_container = context_builder.build_context_container()
INFO:root:> [query] Total LLM token usage: 135 tokens
> [query] Total LLM token usage: 135 tokens
INFO:root:> [query] Total embedding token usage: 23 tokens
> [query] Total embedding token usage: 23 tokens
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
# Render the context string held by the container in bold in the notebook.
display(Markdown(f"<b>{context_container.context_str}</b>"))
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
# Create a query engine from the SQL index, passing in the context container
# built for this question.
query_engine = index.as_query_engine(sql_context_container=context_container)

# Run the natural-language question through the query engine.
response = query_engine.query(query_str)
INFO:root:> Table desc str:
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
> Table desc str:
Table 'city_stats':
city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))
INFO:root:> [query] Total LLM token usage: 134 tokens
> [query] Total LLM token usage: 134 tokens
INFO:root:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
We can also use codewords during the NL query!
# Cell output: the response rendered as a string.
str(response)
"[('Tokyo',)]"
# Cell output: the `extra_info` attribute attached to the response.
response.extra_info