Automate SQL Query Generation with Mistral-7B
SQL queries are essential for extracting meaningful insights from databases, but writing them can be complex and time-consuming. In this post, we’ll explore how to use the Mistral-7B language model to automatically generate SQL queries from natural language questions. This approach is particularly useful for non-technical users who may not be familiar with SQL.
Why Mistral-7B?
Mistral-7B is a powerful large language model designed to handle complex tasks, including text generation and natural language understanding. With fine-tuning and prompt engineering, it can also be used to generate SQL queries from user inputs.
Key features of its architecture include:
- Transformer-based Architecture: Like most modern LLMs, Mistral-7B is built on the transformer architecture, which is highly effective for sequential data such as text; it adds grouped-query and sliding-window attention for faster inference.
- Instruction Fine-tuning: The Instruct variant of Mistral-7B is tuned to follow natural language instructions, making it well-suited for tasks where a clear question or directive is provided, such as SQL generation or conversational agents.
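As a concrete illustration of that instruction format: Mistral-7B-Instruct expects user turns wrapped in [INST] ... [/INST] tags, which the tokenizer's chat template can apply for you (a minimal sketch; the prompt text is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
messages = [{"role": "user", "content": "Write a SQL query that counts the rows in a table."}]
# Renders to something like: "<s>[INST] Write a SQL query that counts the rows in a table. [/INST]"
print(tokenizer.apply_chat_template(messages, tokenize=False))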
Step-by-Step Guide
1. Setting Up the Environment
To get started, you need to install the necessary libraries and load the Mistral-7B model. You can use Hugging Face's transformers library along with LangChain for prompt management.
!pip install git+https://github.com/huggingface/transformers.git
!pip install deepspeed --upgrade
!pip install accelerate
!pip install langchain
!pip install torch
!pip install bitsandbytes
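Before loading a 7-billion-parameter model, it's worth a quick sanity check that a GPU is visible:

import torch

# Mistral-7B needs roughly 14-15 GB of GPU memory in fp16
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))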
2. Loading the Mistral-7B Model
Once the dependencies are installed, we load the base Mistral-7B model and tokenizer. We use the model to handle the language generation tasks required for converting natural language to SQL.
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFacePipeline

# Load the model and tokenizer
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    device_map="auto",  # spread the weights across available GPUs/CPU
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

# Create a text-generation pipeline
pipe = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_new_tokens=256,     # cap the completion; max_length would also count the prompt
    do_sample=True,         # required for temperature/top_p to take effect
    temperature=0.3,        # low temperature keeps the generated SQL conservative
    top_p=0.95,
    repetition_penalty=1.2,
)
local_llm = HuggingFacePipeline(pipeline=pipe)
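Note that bitsandbytes is installed above but not used in this snippet. If the full-precision model doesn't fit on your GPU, loading a 4-bit quantized version is one option (a minimal sketch using transformers' BitsAndBytesConfig; adjust to your hardware):

import torch
from transformers import BitsAndBytesConfig

# 4-bit weights cut memory from ~14 GB (fp16) to roughly 4-5 GB
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # compute in fp16 while weights stay 4-bit
)
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    device_map="auto",
    quantization_config=quant_config,
)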
3. Extracting Database Schema
To generate SQL queries, the model needs to know the structure of the database. This structure, known as the schema, can be extracted from the DDL (CREATE TABLE) statements.
import re

def parse_ddl(ddl):
    """Parse CREATE TABLE statements into a {table_name: [column, ...]} mapping."""
    tables = {}
    current_table = None
    for line in ddl.splitlines():
        line = line.strip()
        if line.startswith("CREATE TABLE"):
            match = re.search(r'CREATE TABLE (\w+)', line)
            if match:
                current_table = match.group(1)
                tables[current_table] = []
        elif current_table:
            if line.endswith(");"):
                current_table = None
            else:
                # The first token on a column-definition line is the column name
                match = re.match(r'(\w+)', line)
                if match:
                    tables[current_table].append(match.group(1))
    return tables

# Example DDL (replace with your own schema)
ddl = """
CREATE TABLE sales (
    sale_id INT PRIMARY KEY,
    sale_date DATE,
    sale_amount DECIMAL(10, 2)
);
"""
schema = parse_ddl(ddl)
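With the example DDL above, the parser yields a simple dict mapping table names to column lists:

print(schema)
# {'sales': ['sale_id', 'sale_date', 'sale_amount']}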
4. Generating SQL Queries with Mistral-7B
Now, we use Mistral-7B to generate SQL queries from natural language questions. We employ LangChain to help structure the interaction between the schema and the language model.
def query_generator(schema, question):
    # Flatten the schema dict into "table: col1, col2" lines for the prompt
    schema_str = "\n".join([f"{table}: {', '.join(cols)}" for table, cols in schema.items()])
    template = """Generate a SQL query using the following tables and columns:
{schema}
to answer the following question:
{question}
Output Query:
"""
    prompt = PromptTemplate(template=template, input_variables=["schema", "question"])
    llm_chain = LLMChain(prompt=prompt, llm=local_llm)
    return llm_chain.run({"schema": schema_str, "question": question})
# Example: Generating a query
question = "What are my sales in 2013?"
query = query_generator(schema, question)
print(query)
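Depending on the sampling settings, the raw completion can include commentary around the query itself. A simple heuristic (an illustrative helper, not part of LangChain) is to slice out the first SELECT ... ; statement:

import re

def extract_sql(response):
    # Grab the first SELECT statement up to its closing semicolon
    match = re.search(r'(SELECT.*?;)', response, re.IGNORECASE | re.DOTALL)
    return match.group(1).strip() if match else response.strip()

print(extract_sql(query))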
5. Matching User Questions with Column Names
Sometimes, user queries may not exactly match the column names in the schema. We use difflib's SequenceMatcher to align the user's question with the database schema, even when the wording is slightly off (for example, a misspelled column name).
from difflib import SequenceMatcher
import re

def find_columns_match(question, input_dict):
    """Replace near-matches in the question with actual column names from input_dict."""
    try:
        question_list = re.split(r'\s|,|\.', question)
        for index, word in enumerate(question_list):
            for column in input_dict:
                # Fuzzy-match each word of the question against each column name
                score = SequenceMatcher(None, column.lower(), word.lower()).ratio() * 100
                if score > 91:
                    question_list[index] = column + ","
        return " ".join(question_list)
    except Exception:
        # Fall back to the untouched question if anything goes wrong
        return question
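In practice, you would run the user's question through the matcher before prompting the model. A quick illustration, reusing the schema parsed earlier (note how the misspelled sale_amont snaps to the real column name):

# Flatten the schema into a single list of column names
all_columns = [col for cols in schema.values() for col in cols]

question = "What is the total sale_amont in 2013?"
normalized = find_columns_match(question, all_columns)
print(normalized)  # "What is the total sale_amount, in 2013?"
query = query_generator(schema, normalized)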
Results
When the model processes a question like “What are my sales in 2013?”, it generates an SQL query using the available schema, such as:
SELECT sale_id, sale_date, sale_amount
FROM sales
WHERE sale_date BETWEEN '2013-01-01' AND '2013-12-31';
Mistral-7B can be a game-changer for anyone who needs to query large datasets but doesn't write SQL every day. It enables seamless natural-language interaction with databases, letting non-technical users extract the insights they need without writing a single line of SQL. With tools like LangChain, this setup becomes even more powerful, flexible, and accessible.