
Comments (19)

PiranavS commented on July 30, 2024

I'm sorry, I didn't make any inference scripts @mellahysf.

lihaoyang-ruc commented on July 30, 2024

Hello everyone,

We are delighted to announce that we have made our text2sql demo publicly available at https://github.com/RUCKBReasoning/text2sql-demo. We believe this resource could be an excellent solution to address the needs highlighted in this issue.

lihaoyang-ruc commented on July 30, 2024

@PiranavS Yes, but I really wouldn't recommend just running T5. The main component (or contribution) of RESDSQL is the cross-encoder, which inherently improves performance. Therefore, it is better to run both stages during inference.

lihaoyang-ruc commented on July 30, 2024

Yes, in order to run inference on your data, you should prepare your data in the form of the Spider dataset (https://yale-lily.github.io/spider), which requires the following files:

  • tables.json: a JSON file describing the schema of your databases;
  • database: a folder containing your databases (the DB engine should be SQLite);
  • dev.json: a JSON file containing the queries from users.

In addition, if you already have gold SQLs and want to evaluate the performance of the model, a dev_gold.sql file is also needed to run the evaluation script.
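
For illustration, here is a rough sketch of what minimal versions of the two JSON files could look like for a toy single-table database (the field names follow the Spider format; the db_id, table, and column values are only placeholders):

```python
import json

# A hypothetical one-table schema in Spider's tables.json format.
tables = [{
    "db_id": "concert_singer",
    "table_names_original": ["singer"],
    "table_names": ["singer"],
    "column_names_original": [[-1, "*"], [0, "singer_id"], [0, "name"]],
    "column_names": [[-1, "*"], [0, "singer id"], [0, "name"]],
    "column_types": ["text", "number", "text"],
    "primary_keys": [1],
    "foreign_keys": []
}]

# One user query in dev.json format; "query" can stay empty at inference time.
dev = [{
    "db_id": "concert_singer",
    "question": "How many singers do we have?",
    "question_toks": ["How", "many", "singers", "do", "we", "have", "?"],
    "query": "",
    "query_toks": [],
    "query_toks_no_value": [],
    "sql": {}
}]

with open("tables.json", "w") as f:
    json.dump(tables, f, indent=2)
with open("dev.json", "w") as f:
    json.dump(dev, f, indent=2)

# The SQLite files themselves go under database/<db_id>/<db_id>.sqlite, following the Spider layout.
```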

The point I want to emphasize is that RESDSQL generates SQL based on the questions (i.e., users' queries) and database schema. If you only have the questions, RESDSQL will not work.

I hope my answer can help you.

inesriahi commented on July 30, 2024

Thank you for your response.

I want to give it the user's query and the table schema as strings, without providing the database itself. Is that possible?

lihaoyang-ruc commented on July 30, 2024

Unfortunately, we don't provide such a script in this repo, because our scripts combine data preparation, SQL decoding, and result evaluation.

However, it is possible, and I will show you how. Before you start, I don't recommend using the NatSQL version of RESDSQL, because NatSQL is an intermediate representation of SQL: you would have to translate the generated NatSQL back into SQL, which requires some special post-processing.

First, prepare the input sequences in the following format: [Question][Tables + Columns][Foreign keys]. Here is an example input:

How many singers do we have? | singer : singer.singer_id , singer.name , singer.country , singer.age , singer.song_name | stadium : stadium.location , stadium.name , stadium.capacity , stadium.highest , stadium.lowest | concert : concert.theme , concert.year , concert.concert_id , concert.concert_name , concert.stadium_id | singer_in_concert : singer_in_concert.concert_id , singer_in_concert.singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id

Then, write a script to perform seq2seq inference with T5. Of course, you should load the tokenizer and model weights from our released T5 checkpoints (for more details on how to run inference, please refer to the T5 documentation on Hugging Face: https://huggingface.co/docs/transformers/model_doc/t5).

Finally, the output sequence of T5 consists of two parts, the SQL skeleton and the actual SQL query, separated by the delimiter |. Thus, the string after | is the predicted SQL. Here is an example output:

select count ( _ ) from _ | select count ( * ) from singer
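
For reference, a minimal sketch of that seq2seq step (assuming the released checkpoint has been downloaded to a local folder; the ./text2sql_checkpoint path here is a placeholder):

```python
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

# Placeholder path: point this at the released RESDSQL T5 checkpoint directory.
checkpoint = "./text2sql_checkpoint"

tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint).eval()

# Input in the [Question][Tables + Columns][Foreign keys] format shown above (shortened here).
input_seq = (
    "How many singers do we have? | singer : singer.singer_id , singer.name , singer.country "
    "| concert : concert.concert_id , concert.stadium_id "
    "| concert.stadium_id = stadium.stadium_id"
)

inputs = tokenizer(input_seq, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=256, num_beams=8)

decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# The output is "<SQL skeleton> | <SQL>", so the predicted SQL is the part after "|".
pred_sql = decoded.split("|", 1)[-1].strip()
print(pred_sql)
```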

PoseidomWong commented on July 30, 2024

I also need this inference script; I strongly recommend adding it to the source code.

PiranavS commented on July 30, 2024

Hi @lihaoyang-ruc, regarding running inference without the scripts, as you mentioned in the previous comment: if I just run seq2seq inference with your released T5 checkpoints, I will not be using the cross-encoder to rank schema items and pass in only the relevant schema, as described in the paper. Is that the case?

PiranavS commented on July 30, 2024

@lihaoyang-ruc Thanks for the reply. Will do.

PiranavS commented on July 30, 2024

Hi @lihaoyang-ruc, another small doubt. As you mentioned earlier, inference for RESDSQL+NatSQL remains the same, except that the output will be in NatSQL format and post-processing must be done to convert it into executable SQL. I just wanted to confirm whether the input will be in the same format mentioned earlier for plain RESDSQL: [Question][Tables + Columns][Foreign keys]

lihaoyang-ruc commented on July 30, 2024

@PiranavS Hi, the NatSQL version does not add foreign keys in the input sequence because the JOIN ON clauses are removed in NatSQL. Here is an input example:

How many singers do we have? | singer : singer.singer_id , singer.name , singer.country , singer.age , singer.song_name , singer.* | singer_in_concert : singer_in_concert.singer_id , singer_in_concert.concert_id , singer_in_concert.* | stadium : stadium.location , stadium.name , stadium.capacity , stadium.highest , stadium.lowest , stadium.* | concert : concert.theme , concert.year , concert.concert_id , concert.concert_name , concert.stadium_id , concert.*
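
For clarity, a small hypothetical helper that builds this foreign-key-free input from a plain {table: [columns]} dict (the function name and the dict layout are mine; only the output format follows the example above):

```python
def build_natsql_input(question: str, schema: dict) -> str:
    """schema maps a table name to its list of column names."""
    parts = [question]
    for table, columns in schema.items():
        cols = " , ".join(f"{table}.{col}" for col in list(columns) + ["*"])
        parts.append(f"{table} : {cols}")
    return " | ".join(parts)

print(build_natsql_input(
    "How many singers do we have?",
    {"singer": ["singer_id", "name", "country", "age", "song_name"]},
))
# -> How many singers do we have? | singer : singer.singer_id , ... , singer.*
```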

PiranavS commented on July 30, 2024

@lihaoyang-ruc Thank you !

mellahysf commented on July 30, 2024

@PiranavS @inesriahi Hey!
Could you please share (if you have done it, of course) the corresponding inference script that takes the user question and database and predicts the SQL query (using RESDSQL + NatSQL and/or plain RESDSQL)?
Thank you so much!

vaib26 commented on July 30, 2024

Can I expose it as a REST API service? I want to run both stages (the cross-encoder ranker and the text2sql model) together. If yes, then how?

prattsoroco commented on July 30, 2024

If someone has already written the script, could they link it here, please? It would save a lot of people effort, since we all seem to be looking for the same thing.

CyanMystery commented on July 30, 2024

Does it only support SQLite databases? Doesn't it work with something like MySQL? Also, it seems that MySQL time types such as datetime can't be trained on either.

lihaoyang-ruc commented on July 30, 2024

@jiangbossB Hi, your question seems unrelated to the topic of the current issue; please open a new issue to discuss it.

CyanMystery commented on July 30, 2024

Hello everyone, I have written an API endpoint with FastAPI that lets you send a query directly from the frontend and get the generated SQL. The code is not well written, but it works; if you need it, you can adapt it to your own business scenario:

```python
# -*- coding: gbk -*-

import sys
import os
import pydantic
import numpy as np
import copy, argparse, json, pickle
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from tokenizers import AddedToken
import torch
from torch.utils.data import DataLoader
from transformers import T5TokenizerFast, MT5ForConditionalGeneration, T5ForConditionalGeneration
from utils.load_dataset import Text2SQLDataset
ROOT_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, ROOT_PATH)
from typing import List
from fastapi import FastAPI, Body
from pydantic import BaseModel
from utils.text2sql_decoding_utils import decode_natsqls
from NatSQL.table_transform import correct_primary_keys, label_disjoint_tables, re_identify_boolean_type, unifie_words, remove_start_table, analyse_same_column, seperate_col_name, table_transform, build_index_range, expand_foreign_key, build_table_network, reversed_link_back_col, bridge_table_for_many2many_relationship, enlarge_network
from NatSQL.natsql2sql.preprocess.TokenString import get_spacy_tokenizer
from NatSQL.natsql2sql.preprocess.Schema_Token import Schema_Token
from NatSQL.natsql2sql.preprocess.stemmer import MyStemmer

# preprocessing

from preprocessing import get_db_schemas, normalization, extract_skeleton, get_db_contents
from schema_item_classifier import _test
from tqdm import tqdm
from text2sql_data_generator import generate_eval_ranked_dataset
from text2sql import _test as text2sql_test

from starlette.templating import Jinja2Templates

# Instantiate a template engine object and specify the directory containing the templates

templates = Jinja2Templates(directory="templates")

# Import the Request context object, used to pass parameters between the frontend and the backend

from starlette.requests import Request

class ChatMessage(BaseModel):
question: str = pydantic.Field(..., description="Question text")
response: dict = pydantic.Field(..., description="Response text")

class Config:
    schema_extra = {
        "example": {
            "question": "¹¤É˱£ÏÕÈçºÎ°ìÀí£¿",
            "response": "Óɵ¥Î»»òÕß¹¤ÉËÖ°¹¤È¥ÉêÇ빤ÉËÈ϶¨£»È»ºóÁìÈ¡¹¤ÉËÈ϶¨¾ö¶¨Ê飬׼±¸ºÃÓªÒµÖ´ÕÕ¡¢Ö°¹¤µÄÉí·ÝÖ¤¡¢Ò½Áƽɷѵ¥¾ÝµÈ²Ä£¬ÔÙÈ¥Ïòµ±µØµÄÉç±£¾­°ì»ú¹¹Ìá³öÉêÇ룻×îºóÓÉÉç±£¾­°ì»ú¹¹°´¹æ¶¨½øÐÐÅâ³¥¡£",
        }
    }

async def chat(
question: str = Body(..., description="Question", example="What is work injury insurance?"),
):
print('=====', question)
input_dataset = [{
"db_id": "cyerp",
"query_toks_no_value": [],
"question_toks": [i for i in question.strip()],
"query_toks": [],
"question": question.strip(),
"sql": {},
"query": ""
}]
# # test-set file dev.json
# input_dataset_path = "data/cyerp/cyerp_dev.json"
input_dataset_path = ""
# in_file
tables_file = 'data/cyerp/cyerp_tables.json'

# output files
# tables_for_natlsql = './data/cyerp/preprocessed_data/test_tables_for_cyerp_natsql.json'
tables_for_natlsql = ''
# output_filepath = "data/cyerp/preprocessed_data/test_with_probs_natsql.json"
output_filepath = ""

# out_file = "data/cyerp/preprocessed_data/resdsql_dev_natsql.json"
out_file = ""
# tables_for_natsql = "data/cyerp/preprocessed_dev_cyerp_natsql.json"
tables_for_natsql = ""


#  NatSQL/table_transform.py
table_transform_args = get_table_transform_args(tables_file, tables_for_natlsql)
# returns out_file
# tables_for_natsql_json
table_str = run_table_transform(table_transform_args)

### preprocessing.py
opt = get_preprocessing_args(tables_file, input_dataset_path)
# returns output_dataset_path  data/cyerp/preprocessed_data/preprocessed_dev_cyerp_natsql.json
preprocessed_dataset = run_preprocessing(opt, input_dataset)

#### schema_item_classifier.py
opt = get_schema_item_classifier_args(output_filepath)
opt.save_path = cross_encoder_ckpt
test_with_probs_natsql_json = run_schema_item_classifier(opt, preprocessed_dataset)


##### text2sql_data_generator.py
opt = get_text2sql_data_generator_args(output_filepath, out_file)
# generate the test dataset
resdsql_test_natsql = run_test2sql_data_generator(opt, test_with_probs_natsql_json)
print(resdsql_test_natsql)
tables = json.loads(table_str)
table_dict = dict()
for t in tables:
    table_dict[t["db_id"]] = t
dev_dataset = Text2SQLDataset(
    dir_ = resdsql_test_natsql,
    mode = "eval"
)

dev_dataloder = DataLoader(
    dev_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: x,
    drop_last=False
)
# run prediction
predict_sqls = {}
for batch in tqdm(dev_dataloder):

    batch_inputs = [data[0] for data in batch]
    print("batch_inputs: ", batch_inputs)
    """['²éѯ¹¤³§±àºÅ | vwbase_eng_pcbinfo : vwbase_eng_pcbinfo.id , vwbase_eng_pcbinfo.prono , vwbase_eng_pcbinfo.floornum , vwbase_eng_pcbinfo.productnum , vwbase_eng_pcbinfo.customerno , vwbase_eng_pcbinfo.* | tbleng_pcbno : tbleng_pcbno.id , tbleng_pcbno.oldid , tbleng_pcbno.customerno , tbleng_pcbno.sampleno , tbleng_pcbno.businessid , tbleng_pcbno.* | tblbase_eng_producttype : tblbase_eng_producttype.id , tblbase_eng_producttype.typename , tblbase_eng_producttype.enter_user , tblbase_eng_producttype.enter_time , tblbase_eng_producttype.sh_status , tblbase_eng_producttype.*']"""
    batch_db_ids = [data[1] for data in batch]
    # batch_db_ids = ['cyerp']
    batch_tc_original = [data[2] for data in batch]
    print("batch_tc_original: ", batch_tc_original)

    tokenized_inputs = tokenizer(
        batch_inputs,
        return_tensors="pt",
        padding="max_length",
        max_length=512,
        truncation=True
    )

    encoder_input_ids = tokenized_inputs["input_ids"]
    encoder_input_attention_mask = tokenized_inputs["attention_mask"]
    if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()

    with torch.no_grad():
        model_outputs = model.generate(
            input_ids=encoder_input_ids,
            attention_mask=encoder_input_attention_mask,
            max_length=1024,
            decoder_start_token_id=model.config.decoder_start_token_id,
            num_beams=8,
            num_return_sequences=8
        )
        model_outputs = model_outputs.view(len(batch_inputs), 8, model_outputs.shape[1])
        _sqls = decode_natsqls(
            "",
            model_outputs,
            batch_db_ids,
            batch_inputs,
            tokenizer,
            batch_tc_original,
            table_dict
        )
        predict_sqls[batch_inputs[0].split('|')[0].strip()] = _sqls
return ChatMessage(
    question=question,
    response=predict_sqls,
)

def get_table_transform_args(tables_file: str, tables_for_natsql: str):
parser = argparse.ArgumentParser()
# ÐèÒªÐ޸ĵÄ
parser.add_argument('--in_file', default=tables_file, type=str)
parser.add_argument("--out_file", default=tables_for_natsql, type=str,
help="output table.json")
parser.add_argument('--correct_col_type', action='store_true', default=True)
parser.add_argument('--remove_start_table', action='store_true', default=True)
parser.add_argument('--analyse_same_column', action='store_true', default=True)
parser.add_argument('--table_transform', action='store_true', default=True)
parser.add_argument('--correct_primary_keys', action='store_true', default=True)
parser.add_argument('--use_extra_col_types', action='store_true', default=True)
parser.add_argument('--db_path', default='./database', type=str)

# keep the defaults for now
parser.add_argument("--star_type", default='others', type=str,
                    help="what type of * in column type?")
parser.add_argument('--seperate_col_name', action='store_true', default=False)
parser.add_argument('--modify_column_names', action='store_true', default=False)
parser.add_argument('--use_table_name_for_star_in_col_name', action='store_true', default=False)
parser.add_argument('--force_modify_star_column_names', action='store_true', default=False)
parser.add_argument('--remove_star_from_col', action='store_true',
                    default=False)  # Only for training. If True, put star to table name.
parser.add_argument('--add_alpha_to_table', action='store_true',
                    default=False)  # For syntaxSQL. Here is True, modify_column_names is True
parser.add_argument('--recover_previous_column_content', action='store_true', default=False)
parser.add_argument('--add_star_on_first_col', action='store_true', default=False)
parser.add_argument('--add_debug_col', action='store_true', default=False)
parser.add_argument('--keepOriginal', action='store_true',
                    default=False)  # keep the original data as same as tables.json
table_transform_args = parser.parse_args()
return table_transform_args

def run_table_transform(args):
# 1. Hyper parameters
database_path = args.db_path

# 2. Prepare data
# TODO 1: load the tables.json file used for prediction
tables = json.load(open(args.in_file, 'r'))
all_words = pickle.load(open(os.path.join("./NatSQL/data/20k.pkl"), 'rb'))
new_tables = []

lstem = MyStemmer()
_tokenizer = get_spacy_tokenizer()
schemas = []
for table in tables:
    schemas.append(Schema_Token(_tokenizer, lstem, table, None))

if args.add_debug_col:
    for t in tables:
        i = 0
        t['column_names_order'] = copy.deepcopy(t['column_names'])
        for c in t['column_names_order']:
            c[0] = c[0] * 1000 + i
            i += 1
else:
    if 'column_names_order' in tables[0].keys():
        for t in tables:
            t.pop('column_names_order')

if args.correct_primary_keys:
    correct_primary_keys(tables, schemas, database_path)
    label_disjoint_tables(tables, database_path)

if args.correct_col_type:
    tables = re_identify_boolean_type(tables, args.use_extra_col_types, database_path)
    tables = unifie_words(tables)

if args.remove_start_table:
    tables = remove_start_table(tables, schemas)

if args.analyse_same_column:
    tables = analyse_same_column(tables, schemas, database_path)

if args.seperate_col_name:
    tables = seperate_col_name(tables, all_words, schemas)

if args.table_transform:
    for table, schema in zip(tables, schemas):
        nt = table_transform(table, args, schema)
        nt = build_index_range(nt)
        nt = expand_foreign_key(nt)
        nt = build_table_network(nt)
        # enlarge_network does not work well with NatSQL version >= 1.3
        nt["network"] = enlarge_network(nt["network"], nt)

        nt["tc_fast"] = []
        for tctc in nt["table_column_names_original"]:
            nt["tc_fast"].append(tctc[1].lower())
        nt["table_orig_low"] = []
        for table_orig_low in nt["table_names_original"]:
            nt["table_orig_low"].append(table_orig_low.lower())
        if 'unique_fk' in nt:
            for i in range(len(nt['unique_fk'])):
                nt['unique_fk'][i] = reversed_link_back_col(nt['unique_fk'][i], nt)
        new_tables.append(nt)

    if args.correct_primary_keys:
        bridge_table_for_many2many_relationship(new_tables)

    if "old_column_names" in new_tables[0] and args.recover_previous_column_content:
        pass
    elif "old_column_names" in new_tables[0]:
        for i, table in enumerate(new_tables):
            new_tables[i].pop("old_column_names")

    if args.add_star_on_first_col:
        for i, table in enumerate(new_tables):
            new_tables[i]["column_names"].insert(0, [-1, "*"])
            new_tables[i]["column_names_original"].insert(0, [-1, "*"])
    # TODO »áÐγÉÒ»¸öеÄÎļþ Ö±½Ó·µ»Ø£¿
    # json.dump(new_tables, open(args.out_file, 'w'), indent=2)
    table_str = json.dumps(new_tables, ensure_ascii=False)
else:
    for table in tables:
        if args.keepOriginal:
            if "same_col_idxs" in table:
                table.pop("same_col_idxs")
            if "old_column_names" in table:
                table.pop("old_column_names")
        else:
            table['table_column_names_original'] = []
            table['link_back'] = []
            for it, item in enumerate(table['column_names_original']):
                table['link_back'].append([it, it])
                if item[0] >= 0:
                    table['table_column_names_original'].append(
                        [item[0], table['table_names_original'][item[0]] + "." + item[1]])
                else:
                    table['table_column_names_original'].append(item)
    # json.dump(tables, open(args.out_file, 'w'), indent=2)
    table_str = json.dumps(tables, ensure_ascii=False)
return table_str

def get_preprocessing_args(tables_file: str, input_dataset_path: str):
preprocessing_parser = argparse.ArgumentParser()
# arguments that need to be modified
preprocessing_parser.add_argument('--mode', type=str, default="test")
preprocessing_parser.add_argument('--table_path', type=str, default=tables_file)
preprocessing_parser.add_argument('--input_dataset_path', type=str, default=input_dataset_path,
help='''
options:
./data/spider/train_spider.json
./data/spider/dev.json
''')
preprocessing_parser.add_argument('--natsql_dataset_path', type=str,
default="./NatSQL/NatSQLv1_6/train_spider-natsql.json",
help='''
options:
./NatSQL/NatSQLv1_6/train_spider-natsql.json
./NatSQL/NatSQLv1_6/dev-natsql.json
''')
# TODO: format of the test file
preprocessing_parser.add_argument('--output_dataset_path', type=str,
default="data/cyerp/preprocessed_data/preprocessed_dev_cyerp_natsql.json",
help="the filepath of preprocessed dataset.")
preprocessing_parser.add_argument('--db_path', type=str, default="./data/spider/database",
help="the filepath of database.")
preprocessing_parser.add_argument("--target_type", type=str, default="natsql",
help="sql or natsql.")

opt = preprocessing_parser.parse_args()

return opt

def main(opt, input_dataset=None):
if input_dataset is not None:
dataset = input_dataset
else:
dataset = json.load(open(opt.input_dataset_path, encoding='gbk'))
all_db_infos = json.load(open(opt.table_path))

assert opt.mode in ["train", "eval", "test"]

if opt.mode in ["train", "eval"] and opt.target_type == "natsql":
    # only train_spider.json and dev.json have corresponding natsql dataset
    natsql_dataset = json.load(open(opt.natsql_dataset_path, encoding='gbk'))
else:
    # empty natsql dataset
    natsql_dataset = [None for _ in range(len(dataset))]

db_schemas = get_db_schemas(all_db_infos)

preprocessed_dataset = []

for natsql_data, data in tqdm(zip(natsql_dataset, dataset), desc='run preprocessing...'):
    question = data["question"].replace("\u2018", "'").replace("\u2019", "'").replace("\u201c", "'").replace(
        "\u201d", "'").strip()
    db_id = data["db_id"]

    if opt.mode == "test":
        sql, norm_sql, sql_skeleton = "", "", ""
        sql_tokens = []

        natsql, norm_natsql, natsql_skeleton = "", "", ""
        natsql_used_columns, natsql_tokens = [], []
    else:
        sql = data["query"].strip()
        norm_sql = normalization(sql).strip()
        sql_skeleton = extract_skeleton(norm_sql, db_schemas[db_id]).strip()
        sql_tokens = norm_sql.split()

        if natsql_data is not None:
            natsql = natsql_data["NatSQL"].strip()
            norm_natsql = normalization(natsql).strip()
            natsql_skeleton = extract_skeleton(norm_natsql, db_schemas[db_id]).strip()
            natsql_used_columns = [token for token in norm_natsql.split() if "." in token and token != "@.@"]
            natsql_tokens = []
            for token in norm_natsql.split():
                # split table_name_original.column_name_original
                if "." in token:
                    natsql_tokens.extend(token.split("."))
                else:
                    natsql_tokens.append(token)
        else:
            natsql, norm_natsql, natsql_skeleton = "", "", ""
            natsql_used_columns, natsql_tokens = [], []

    preprocessed_data = {}
    preprocessed_data["question"] = question
    preprocessed_data["db_id"] = db_id

    preprocessed_data["sql"] = sql
    preprocessed_data["norm_sql"] = norm_sql
    preprocessed_data["sql_skeleton"] = sql_skeleton

    preprocessed_data["natsql"] = natsql
    preprocessed_data["norm_natsql"] = norm_natsql
    preprocessed_data["natsql_skeleton"] = natsql_skeleton

    preprocessed_data["db_schema"] = []
    preprocessed_data["pk"] = db_schemas[db_id]["pk"]
    preprocessed_data["fk"] = db_schemas[db_id]["fk"]
    preprocessed_data["table_labels"] = []
    preprocessed_data["column_labels"] = []

    # add database information (including table name, column name, ..., table_labels, and column labels)
    table_map_alias = {}
    for table in db_schemas[db_id]["schema_items"]:
        db_contents = get_db_contents(
            question,
            table["table_name_original"],
            table["column_names_original"],
            db_id,
            opt.db_path
        )

        preprocessed_data["db_schema"].append({
            "table_name_original": table["table_name_original"],
            "table_name": table["table_name"],
            "column_names": table["column_names"],
            "column_names_original": table["column_names_original"],
            "column_types": table["column_types"],
            "db_contents": db_contents
        })

        # extract table and column classification labels
        if opt.target_type == "sql":
            if table["table_name_original"] in sql_tokens:  # for used tables
                preprocessed_data["table_labels"].append(1)
                column_labels = []
                for column_name_original in table["column_names_original"]:
                    if column_name_original in sql_tokens or \
                            table[
                                "table_name_original"] + "." + column_name_original in sql_tokens:  # for used columns
                        column_labels.append(1)
                    else:
                        column_labels.append(0)
                preprocessed_data["column_labels"].append(column_labels)
            else:  # for unused tables and their columns
                preprocessed_data["table_labels"].append(0)
                preprocessed_data["column_labels"].append([0 for _ in range(len(table["column_names_original"]))])
        elif opt.target_type == "natsql":
            table_alias_idx = None
            if 'as' in sql_tokens:
                try:
                    table_idx = sql_tokens.index(table["table_name_original"])
                    as_idx = [i for i, char in enumerate(sql_tokens) if char == 'as']
                    if any(np.array(as_idx) > table_idx):
                        table_alias_idx = sql_tokens.index(table["table_name_original"]) + 2
                except ValueError:
                    pass
            is_in = False
            if 'as' in sql_tokens and table_alias_idx is not None:
                is_in = True
                if table["table_name_original"].lower() not in table_map_alias:
                    table_map_alias[table["table_name_original"].lower()] = sql_tokens[table_alias_idx]
            if table["table_name_original"] in natsql_tokens:  # for used tables
                preprocessed_data["table_labels"].append(1)
                column_labels = []
                global table_alias
                if is_in:
                    table_alias = table_map_alias[table["table_name_original"].lower()]
                for column_name_original in table["column_names_original"]:
                    if is_in and table_alias + "." + column_name_original in natsql_used_columns:
                        column_labels.append(1)
                    elif table[
                        "table_name_original"] + "." + column_name_original in natsql_used_columns:  # for used columns
                        column_labels.append(1)
                    else:
                        column_labels.append(0)
                preprocessed_data["column_labels"].append(column_labels)
            else:
                preprocessed_data["table_labels"].append(0)
                preprocessed_data["column_labels"].append([0 for _ in range(len(table["column_names_original"]))])
        else:
            raise ValueError("target_type should be ``sql'' or ``natsql''")

    preprocessed_dataset.append(preprocessed_data)
preprocessed_dataset_str = json.dumps(preprocessed_dataset, ensure_ascii=False)
# TODO 2: return output_dataset_path
# with open(opt.output_dataset_path, "w") as f:
#     preprocessed_dataset_str = json.dumps(preprocessed_dataset, indent=2, ensure_ascii=False)
#     f.write(preprocessed_dataset_str)
return preprocessed_dataset_str

def run_preprocessing(preprocessing_args, input_dataset):
preprocessed_dataset_str = main(preprocessing_args, input_dataset)
return json.loads(preprocessed_dataset_str)

def get_schema_item_classifier_args(test_file: str):
parser = argparse.ArgumentParser("command line arguments for fine-tuning schema item classifier.")

parser.add_argument('--batch_size', type=int, default=1,
                    help='input batch size.')
parser.add_argument('--gradient_descent_step', type=int, default=4,
                    help='perform gradient descent per "gradient_descent_step" steps.')
parser.add_argument('--device', type=str, default="0",
                    help='the id of used GPU device.')
parser.add_argument('--learning_rate', type=float, default=3e-5,
                    help='learning rate.')
# path where the fine-tuned model is saved
parser.add_argument('--save_path', type=str, default="models/cyerp/xlm_roberta_text2natsql_schema_item_classifier",
                    help='save path of best fine-tuned model on validation set.')

parser.add_argument('--dev_filepath', type=str, default="data/cyerp/preprocessed_data/preprocessed_dev_cyerp_natsql.json",
                    help='path of pre-processed development dataset.')
parser.add_argument('--output_filepath', type=str, default=test_file,
                    help='path of the output dataset (used in eval mode).')
parser.add_argument('--use_contents', default=True, action='store_true',
                    help='whether to integrate db contents into input sequence')






parser.add_argument('--gamma', type=float, default=1.0,
                    help='gamma parameter in the focal loss. Recommended: [0.0-2.0].')
parser.add_argument('--alpha', type=float, default=1.0,
                    help='alpha parameter in the focal loss. Must between [0.0-1.0].')
parser.add_argument('--epochs', type=int, default=1,
                    help='training epochs.')
parser.add_argument('--patience', type=int, default=32,
                    help='patience step in early stopping. -1 means no early stopping.')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed.')

parser.add_argument('--tensorboard_save_path', type=str, default=None,
                    help='save path of tensorboard log.')
parser.add_argument('--train_filepath', type=str, default="data/pre-processing/preprocessed_train_spider.json",
                    help='path of pre-processed training dataset.')

parser.add_argument('--model_name_or_path', type=str, default="models/cyerp/xlm_roberta_text2natsql_schema_item_classifier",
                    help='''pre-trained model name.''')

parser.add_argument('--add_fk_info', action='store_true',
                    help='whether to add [FK] tokens into input sequence')
parser.add_argument('--mode', type=str, default="test",
                    help='train, eval or test.')

opt = parser.parse_args()

return opt

def run_schema_item_classifier(opt, dev_filepath):
print('schema_item_classifier args----- ', opt)
if opt.mode in ["eval", "test"]:
total_table_pred_probs, total_column_pred_probs = _test(opt, dev_filepath, skip_eval=False)
# TODO bug infer dataset = dev file path
# with open(opt.dev_filepath, "r", encoding='gbk') as f:
# dataset = json.load(f)
dataset = dev_filepath

    # record predicted probability
    truncated_data_info = []
    for data_id, data in enumerate(dataset):
        table_num = len(data["table_labels"])
        if table_num == len(total_table_pred_probs[data_id]):
            table_pred_probs = total_table_pred_probs[data_id]
        else:
            table_pred_probs = total_table_pred_probs[data_id] + [-1 for _ in range(
                table_num - len(total_table_pred_probs[data_id]))]

        truncated_table_ids = []
        column_pred_probs = []
        for table_id in range(table_num):
            if table_id >= len(total_column_pred_probs[data_id]):
                truncated_table_ids.append(table_id)
                column_pred_probs.append([-1 for _ in range(len(data["column_labels"][table_id]))])
                continue
            if len(total_column_pred_probs[data_id][table_id]) == len(data["column_labels"][table_id]):
                column_pred_probs.append(total_column_pred_probs[data_id][table_id])
            else:
                truncated_table_ids.append(table_id)
                truncated_column_num = len(data["column_labels"][table_id]) - len(
                    total_column_pred_probs[data_id][table_id])
                column_pred_probs.append(
                    total_column_pred_probs[data_id][table_id] + [-1 for _ in range(truncated_column_num)])

        data["column_pred_probs"] = column_pred_probs
        data["table_pred_probs"] = table_pred_probs

        if len(truncated_table_ids) > 0:
            truncated_data_info.append([data_id, truncated_table_ids])

    # additionally, we need to consider and predict discarded tables and columns
    while len(truncated_data_info) != 0:
        truncated_dataset = []
        for truncated_data_id, truncated_table_ids in truncated_data_info:
            print(dataset[truncated_data_id]["question"])
            truncated_data = copy.deepcopy(dataset[truncated_data_id])
            truncated_data["db_schema"] = [truncated_data["db_schema"][table_id] for table_id in
                                           truncated_table_ids]
            truncated_data["table_labels"] = [truncated_data["table_labels"][table_id] for table_id in
                                              truncated_table_ids]
            truncated_data["column_labels"] = [truncated_data["column_labels"][table_id] for table_id in
                                               truncated_table_ids]
            truncated_data["table_pred_probs"] = [truncated_data["table_pred_probs"][table_id] for table_id in
                                                  truncated_table_ids]
            truncated_data["column_pred_probs"] = [truncated_data["column_pred_probs"][table_id] for table_id in
                                                   truncated_table_ids]

            truncated_dataset.append(truncated_data)
        # TODO infer dev file path
        # with open("./data/cyerp/pre-processing/truncated_dataset.json", "w") as f:
        #     f.write(json.dumps(truncated_dataset, indent=2, ensure_ascii=False))
        # opt.dev_filepath = "./data/cyerp/pre-processing/truncated_dataset.json"
        # TODO: _test loads the model
        total_table_pred_probs, total_column_pred_probs = _test(opt, truncated_dataset, skip_eval=False)

        for data_id, data in enumerate(truncated_dataset):
            table_num = len(data["table_labels"])
            if table_num == len(total_table_pred_probs[data_id]):
                table_pred_probs = total_table_pred_probs[data_id]
            else:
                table_pred_probs = total_table_pred_probs[data_id] + [-1 for _ in range(
                    table_num - len(total_table_pred_probs[data_id]))]

            column_pred_probs = []
            for table_id in range(table_num):
                if table_id >= len(total_column_pred_probs[data_id]):
                    column_pred_probs.append([-1 for _ in range(len(data["column_labels"][table_id]))])
                    continue
                if len(total_column_pred_probs[data_id][table_id]) == len(data["column_labels"][table_id]):
                    column_pred_probs.append(total_column_pred_probs[data_id][table_id])
                else:
                    truncated_column_num = len(data["column_labels"][table_id]) - len(
                        total_column_pred_probs[data_id][table_id])
                    column_pred_probs.append(
                        total_column_pred_probs[data_id][table_id] + [-1 for _ in range(truncated_column_num)])

            # fill the predicted probability into the dataset
            truncated_data_id = truncated_data_info[data_id][0]
            truncated_table_ids = truncated_data_info[data_id][1]
            for idx, truncated_table_id in enumerate(truncated_table_ids):
                dataset[truncated_data_id]["table_pred_probs"][truncated_table_id] = table_pred_probs[idx]
                dataset[truncated_data_id]["column_pred_probs"][truncated_table_id] = column_pred_probs[idx]

        # check if there are tables and columns in the new dataset that have not yet been predicted
        truncated_data_info = []
        for data_id, data in enumerate(dataset):
            table_num = len(data["table_labels"])

            truncated_table_ids = []
            for table_id in range(table_num):
                # the current table is not predicted
                if data["table_pred_probs"][table_id] == -1:
                    truncated_table_ids.append(table_id)
                # some columns in the current table are not predicted
                if data["table_pred_probs"][table_id] != -1 and -1 in data["column_pred_probs"][table_id]:
                    truncated_table_ids.append(table_id)

            if len(truncated_table_ids) > 0:
                truncated_data_info.append([data_id, truncated_table_ids])

        # os.remove("./data/cyerp/pre-processing/truncated_dataset.json")
    # TODO: write the inference output file
    # with open(opt.output_filepath, "w") as f:
    #     f.write(json.dumps(dataset, indent=2, ensure_ascii=False))
    test_with_probs_natsql_json = json.dumps(dataset, ensure_ascii=False)
    return test_with_probs_natsql_json

def get_text2sql_data_generator_args(test_file: str, out_file: str):
parser = argparse.ArgumentParser("command line arguments for generating the ranked dataset.")

parser.add_argument('--input_dataset_path', type=str, default=test_file,
                    help='filepath of the input dataset.')

parser.add_argument('--output_dataset_path', type=str, default=out_file,
                    help='filepath of the output dataset.')
parser.add_argument('--topk_table_num', type=int, default=3,
                    help='we only remain topk_table_num tables in the ranked dataset (k_1 in the paper).')
parser.add_argument('--topk_column_num', type=int, default=5,
                    help='we only remain topk_column_num columns for each table in the ranked dataset (k_2 in the paper).')
parser.add_argument('--mode', type=str, default="test",
                    help='type of the input dataset, options: train, eval, test.')
parser.add_argument('--noise_rate', type=float, default=0.08,
                    help='the noise rate in the ranked training dataset (needed when the mode = "train")')
parser.add_argument('--use_contents', default=True,action='store_true',
                    help='whether to add database contents in the input sequence.')
parser.add_argument('--add_fk_info', action='store_true',
                    help='whether to add foreign key in the input sequence.')
parser.add_argument('--output_skeleton', default=True, action='store_true',
                    help='whether to add skeleton in the output sequence.')
parser.add_argument("--target_type", type=str, default="natsql",
                    help="sql or natsql.")

opt = parser.parse_args()

return opt

def get_text2sql_args(out_file: str, tables_for_natsql: str):
parser = argparse.ArgumentParser("command line arguments for fine-tuning pre-trained language model.")

parser.add_argument('--batch_size', type=int, default=1,
                    help='input batch size.')
parser.add_argument('--gradient_descent_step', type=int, default=4,
                    help='perform gradient descent per "gradient_descent_step" steps.')
parser.add_argument('--device', type=str, default="0",
                    help='the id of used GPU device.')
parser.add_argument('--learning_rate', type=float, default=5e-5,
                    help='learning rate.')
parser.add_argument('--epochs', type=int, default=12,
                    help='training epochs.')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed.')
parser.add_argument('--save_path', type=str, default="models/cyerp/text2natsql-mt5-base-cspider/checkpoint-27501",
                    help='save path of best fine-tuned text2sql model.')
parser.add_argument('--tensorboard_save_path', type=str, default="tensorboard_log/cyerp/text2natsql",
                    help='save path of tensorboard log.')
parser.add_argument('--mode', type=str, default="eval",
                    help='train, eval or test.')
# TODO: test file
parser.add_argument('--dev_filepath', type=str, default=out_file,
                    help='file path of test2sql dev set.')
parser.add_argument('--original_dev_filepath', type=str, default="data/cyerp/cyerp_dev.json",
                    help='file path of the original dev set (for registing evaluator).')
parser.add_argument('--tables_for_natsql', type=str, default=tables_for_natsql,
                    help='file path of tables_for_natsql.json.')
parser.add_argument('--num_beams', type=int, default=8,
                    help='beam size in model.generate() function.')
parser.add_argument('--num_return_sequences', type=int, default=8,
                    help='the number of returned sequences in model.generate() function (num_return_sequences <= num_beams).')
parser.add_argument("--target_type", type=str, default="natsql",
                    help="sql or natsql.")
# TODO: model prediction results
parser.add_argument("--output", type=str, default="data/cyerp/predicted_cyerp_sql.txt",
                    help="save file of the predicted sqls.")




parser.add_argument('--model_name_or_path', type=str, default="models/cyerp/text2natsql-mt5-base-cspider/checkpoint-27501",
                    help=
                    '''
                    pre-trained model name. 
                    options: 
                        t5-base, https://huggingface.co/t5-base;
                        t5-large, https://huggingface.co/t5-large;
                        t5-3b, https://huggingface.co/t5-3b;
                    ''')
parser.add_argument('--use_adafactor', action='store_true',
                    help='whether to use adafactor optimizer.')

parser.add_argument('--train_filepath', type=str, default="data/preprocessed_data/resdsql_train_spider.json",
                    help='file path of test2sql training set.')

parser.add_argument('--db_path', type=str, default="database",
                    help='file path of database.')

opt = parser.parse_args()

return opt

def run_test2sql_data_generator(opt, test_with_probs_natsql_json):
resdsql_test_natsql = generate_eval_ranked_dataset(opt, test_with_probs_natsql_json)
return resdsql_test_natsql

def load_model(model_ckpt):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# tables = json.load(open(opt.tables_for_natsql,'r'))
# initialize tokenizer
tokenizer = T5TokenizerFast.from_pretrained(
model_ckpt,
add_prefix_space=True
)

if isinstance(tokenizer, T5TokenizerFast):
    tokenizer.add_tokens([AddedToken(" <="), AddedToken(" <")])

model_class = MT5ForConditionalGeneration if "mt5" in model_ckpt else T5ForConditionalGeneration

# initialize model
model = model_class.from_pretrained(model_ckpt)
if torch.cuda.is_available():
    model = model.cuda()

model.eval()

return model, tokenizer

class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description='HTTP status code')
message: str = pydantic.Field("successful", description='HTTP status message')

class Config:
    schema_extra = {
        "example": {
            "code": 200,
            "message": "success"
        }
    }

async def index(request: Request):
return templates.TemplateResponse(name="index.html", context={"request": request})

def api_start():
    app = FastAPI()
    # whether to enable cross-origin requests (CORS)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # register the endpoints
    app.get('/', response_model=BaseResponse)(index)
    app.post('/chat', response_model=ChatMessage)(chat)

    # start the service
    uvicorn.run(app, host='localhost', port=8000)

if __name__ == '__main__':
    # load the model
    global model, tokenizer, cross_encoder_ckpt
    model_ckpt = r'F:\python_project\RESDSQL\models\cyerp\text2natsql-mt5-base-cspider\checkpoint-4326'
    cross_encoder_ckpt = "models/cyerp/xlm_roberta_text2natsql_schema_item_classifier"
    model, tokenizer = load_model(model_ckpt=model_ckpt)

    api_start()
# # #### text2sql.py
# opt = get_text2sql_args(out_file, tables_for_natsql)
# _,_, pred_sqls = text2sql_test(opt, table_str, resdsql_test_natsql)
# print(pred_sqls)

```

CyanMystery commented on July 30, 2024

Oops, the code formatting got messed up, which is a bit embarrassing... Actually, you only need two files: tables.json and dev.json (the latter is just the queries you want to predict, in the expected format; you can construct it directly).
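
In code, constructing that dev.json entry directly looks roughly like the snippet below (it mirrors the input_dataset built inside chat() above; the word-level tokenization is a simplification, and "cyerp" is just the db_id used in that script):

```python
import json

question = "How many singers do we have?"
dev_entry = {
    "db_id": "cyerp",                           # the db_id used in the script above
    "question": question.strip(),
    "question_toks": question.strip().split(),  # the script above splits per character for Chinese
    "query": "",
    "query_toks": [],
    "query_toks_no_value": [],
    "sql": {},
}
with open("dev.json", "w", encoding="utf-8") as f:
    json.dump([dev_entry], f, ensure_ascii=False, indent=2)
```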
