Comments (9)

chrispan68 commented on May 9, 2024
import yaml
import argparse
import subprocess
import torch
import time
import asyncio
import sys
import atexit
import json

from data import *
from metrics import *
from utils import *
from queries import *

def get_metric(metric: str, query: LMQLQuery, metric_config):
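    """Map the metric name to its metric class and instantiate it with the query and its section of metric_config."""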
    metric = slugify(metric)
    metric_to_class_map = {
        "zero_to_few_shot_generalization": ZeroToFewShotGeneralizationMetric,
        "selection": SelectionMetric,
        "permutation": PermutationMetric
    }
    if metric not in metric_to_class_map:
        raise KeyError(f"Unrecognized metric {metric}")
    
    metric_class = metric_to_class_map[metric]
    return metric_class(query=query,
                        **metric_config[metric])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Model name")
    parser.add_argument("--metric", type=str, required=False, help="Metric name")
    parser.add_argument("--metric_config", type=str, required=True, help="Path the metric config file")
    parser.add_argument("--query_filename", type=str, required=True, help="Path the prompt lmql query")
    parser.add_argument("--type", type=str, default="classification", choices=["classification"])
    parser.add_argument("--log_dir", type=str, required=False, help="Logging directory", default="logs/results")
    parser.add_argument("--ping_interval", type=int, required=False, help="How often to ping the model before inference. Set to 0 if you don't want to ping model", default=5)
    args = parser.parse_args()

    with open(args.metric_config, 'r') as f:
        metric_config = yaml.safe_load(f)
    
    if args.type == "classification":
        query = LMQLClassificationQuery(args.query_filename, args.model)
    else:
        raise ValueError(f"Invalid task type: {args.type}")
    
    if (args.metric is not None) and (args.metric not in metric_config):
        raise AssertionError(f"{args.metric_config} is not a valid config file for metric {args.metric}.")
    
    metric_name = args.metric or list(metric_config.keys())[0]
    metric = get_metric(metric_name, query, metric_config)

    # serve the language model
    serve_model_process_args = ["python", "-m", "lmql.model.serve", args.model]
    if torch.cuda.is_available():
        serve_model_process_args.append("--cuda")
    if args.ping_interval:
        serve_model_process_args.append("--wait_until_ready")
    serve_model_process = subprocess.Popen(serve_model_process_args)
    atexit.register(serve_model_process.terminate)
    
    # don't proceed with inference until the hosted model is ready.
    if args.ping_interval:
        while not lmql_model_server_running():
            print("Waiting for model server to run...")
            sys.stdout.flush()
            time.sleep(args.ping_interval)
    
    inputs = metric.create_inputs()
    result = metric.evaluate(inputs)

    serve_model_process.terminate()
    
    header = f"model: {query.model}\ndataset: {query.dataset_name}\nmetric: {metric_name}\n" \
                f"metric_config: {metric_config}\nquery_filename: {args.query_filename}"
    
    # Write results to disk
    write_results(args.log_dir, header, result)

    print(json.dumps(result))

chrispan68 commented on May 9, 2024

If it's an issue with the LMQL codebase, I'd be more than happy to look into it and submit a PR.

lbeurerkellner commented on May 9, 2024

Thanks for offering to help. Definitely feel free to explore the code base; we are accepting PRs. I will have a closer look myself, or did you already fix this in #15?

chrispan68 commented on May 9, 2024

#15 doesn't fix this issue; I can look into it and submit a separate PR in the near future.

praveenv commented on May 9, 2024

@chrispan68 Do you mind sharing a code snippet of how you use subprocess for the model serving and the client calls? I've tried using the flag that you added in #15, but am probably doing something wrong.
An example would be much appreciated!

chrispan68 commented on May 9, 2024

In the script above, I spawn a model-serving process and terminate it when the evaluation script exits.
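
For reference, the serve-and-terminate part boils down to the pattern below (a minimal sketch; serve_model and the "gpt2" model name are illustrative, and the readiness check is left out):

import atexit
import subprocess

import torch

def serve_model(model_name: str) -> subprocess.Popen:
    # Launch the LMQL model server as a child process, as in the full script.
    args = ["python", "-m", "lmql.model.serve", model_name]
    if torch.cuda.is_available():
        args.append("--cuda")
    process = subprocess.Popen(args)
    # Make sure the server is terminated even if the evaluation script exits early.
    atexit.register(process.terminate)
    return process

serve_process = serve_model("gpt2")  # hypothetical model name
# ... wait until the server is reachable, then run the LMQL queries ...
serve_process.terminate()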

chrispan68 commented on May 9, 2024

Let me know if this isn't relevant. Good luck!

praveenv commented on May 9, 2024

Thank you so much!
A follow-up question: what does lmql_model_server_running() do?
Do you keep checking whether port 8080 can be connected to? I was trying to do that using the socket library.
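
For illustration, the socket-based check I had in mind looks something like this (a sketch assuming the server listens on localhost:8080; not necessarily what lmql_model_server_running() actually does):

import socket
import time

def port_open(host: str = "localhost", port: int = 8080) -> bool:
    # Try to open a TCP connection; if it succeeds, something is listening on the port.
    try:
        with socket.create_connection((host, port), timeout=1):
            return True
    except OSError:
        return False

# Poll until the model server accepts connections.
while not port_open():
    print("Waiting for model server to run...")
    time.sleep(5)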

lbeurerkellner commented on May 9, 2024

Since the latest serve-model implementation no longer relies on multiprocessing and child processes, this issue should be fixed with the latest version.
