scikit_learn_android_demo's Introduction

Using scikit-learn Models In Android Applications

This project demonstrates how to use a scikit-learn model in an Android app, with ONNX as the bridge between the two frameworks.

Demo of the app

Getting Started

  1. Clone this repository and open the resulting project in Android Studio:
>> git clone https://github.com/shubham0204/Scikit_Learn_Android_Demo
  2. Read the blog post Deploying Scikit-Learn Models In Android Apps With ONNX and follow its procedure for your own model. The code from the post is also available as a Google Colab notebook.

  3. Place your .ort model in the app/src/main/res/raw folder.

scikit_learn_android_demo's Issues

Need to convert to .ort format

FYI: It’s not necessary to convert to .ort format when using the ‘full’ ONNX Runtime package, onnxruntime-android; you can use the .onnx model directly.

The conversion to .ort format is only necessary if you use the smaller ‘mobile’ package, onnxruntime-mobile, which supports a limited set of operators and types (chosen from popular DNN models used in mobile scenarios) to keep the binary small. However, that package does not include the traditional ML operators that scikit-learn models tend to use, so it most likely wouldn't be able to run a model converted from scikit-learn.
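For reference, a minimal sketch of the two dependency choices in an app module's build.gradle.kts (Gradle Kotlin DSL). The artifact names are the two packages named above, under the com.microsoft.onnxruntime group; `latest.release` is only a placeholder, so pin the version you actually test against.

```kotlin
dependencies {
    // Full package: includes the traditional ML operators scikit-learn models rely on,
    // and can load a plain .onnx file, so no .ort conversion is required.
    implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")

    // Mobile package: smaller binary and reduced operator set; expects .ort models
    // and will likely be unable to run a model converted from scikit-learn.
    // implementation("com.microsoft.onnxruntime:onnxruntime-mobile:latest.release")
}
```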

More than one input

I'm not sure whether this is the correct way to predict with multiple inputs. I think something is wrong in "val inputTensor = OnnxTensor.createTensor( ortEnvironment , floatBufferInputs , longArrayOf( 1, 1 ) )"

class CalculatorActivity : AppCompatActivity() {
    private lateinit var binding: ActivityCalculatorBinding
    private var isGenderSelected = false
    private var previousSelectedPosition = 0
    private var selectedGenderValue = 0 // 0 for Female, 1 for Male

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        binding = ActivityCalculatorBinding.inflate(layoutInflater)
        setContentView(binding.root)

        val textInputEditTextAge = binding.tedAge
        val placeholderAge = "... Years"
        textInputEditTextAge.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextAge.hint = placeholderAge

        val textInputEditTextWeight = binding.tedWeight
        val placeholderWeight = "... KG"
        textInputEditTextWeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextWeight.hint = placeholderWeight

        val textInputEditTextHeight = binding.tedHeight
        val placeholderHeight = "... CM"
        textInputEditTextHeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHeight.hint = placeholderHeight

        val textInputEditTextHours = binding.tedHours
        val placeholderHours = "... Hours"
        textInputEditTextHours.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHours.hint = placeholderHours

        val textInputEditTextYears = binding.tedYears
        val placeholderYears = "... Years"
        textInputEditTextYears.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextYears.hint = placeholderYears

        val gender = resources.getStringArray(R.array.Gender)

        val spinner = binding.genderSpinner
        if (spinner != null) {
            val adapter = ArrayAdapter(this,
                android.R.layout.simple_spinner_dropdown_item, gender)
            spinner.adapter = adapter

            spinner.onItemSelectedListener = object :
                AdapterView.OnItemSelectedListener {
                override fun onItemSelected(parent: AdapterView<*>, view: View?, position: Int, id: Long) {
                    if (isGenderSelected) {
                        // A gender other than "-- Select your Gender --" has already been selected
                        if (position == 0) {
                            spinner.setSelection(previousSelectedPosition) // Set the spinner to the previous selected position
                        } else {
                            previousSelectedPosition = position // Update the previous selected position
                            selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                        }
                    } else {
                        isGenderSelected = true
                        previousSelectedPosition = position // Set the initial selected position
                        selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                    }
                }

                override fun onNothingSelected(parent: AdapterView<*>) {
                    // write code to perform some action
                }
            }
        }

        binding.btnPredict.setOnClickListener {
            val gender = selectedGenderValue.toFloat()
            val age = textInputEditTextAge.text.toString().toFloatOrNull()
            val weight = textInputEditTextWeight.text.toString().toFloatOrNull()
            val height = textInputEditTextHeight.text.toString().toFloatOrNull()
            val hours = textInputEditTextHours.text.toString().toFloatOrNull()
            val years = textInputEditTextYears.text.toString().toFloatOrNull()
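            // Compute BMI only when both values are present, so an empty field cannot throw a NullPointerException
            val bmi = if (weight != null && height != null && height > 0f) weight / ((height / 100) * (height / 100)) else null
            if (age != null && weight != null && height != null && hours != null && years != null && bmi != null) {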
                val ortEnvironment = OrtEnvironment.getEnvironment()
                val ortSession = createORTSession(ortEnvironment)
                val output = runPrediction(
                    gender, age, weight, height, hours, years, bmi,
                    ortSession, ortEnvironment
                )
                showOutputPopup(output)
            } else {
                Toast.makeText(this, "Please fill in all the inputs", Toast.LENGTH_LONG).show()
            }
        }
    }

    private fun createORTSession( ortEnvironment: OrtEnvironment) : OrtSession {
        val modelBytes = resources.openRawResource( R.raw.model ).readBytes()
        return ortEnvironment.createSession( modelBytes )
    }

    private fun runPrediction( genders: Float , age: Float , weight: Float , height: Float , hours: Float , years: Float , bmi : Float , ortSession: OrtSession , ortEnvironment: OrtEnvironment ) : Float {
        // Get the name of the input node
        val inputName = ortSession.inputNames?.iterator()?.next()
        // Make a FloatBuffer of the inputs
        val floatBufferInputs = FloatBuffer.wrap( floatArrayOf( genders, age, weight, height, hours, years, bmi ) )
        // Create input tensor with floatBufferInputs of shape ( 1 , 1 )
        val inputTensor = OnnxTensor.createTensor( ortEnvironment , floatBufferInputs , longArrayOf( 1, 1 ) )
        // Run the model
        val results = ortSession.run( mapOf( inputName to inputTensor ) )
        // Fetch and return the results
        val output = results[0].value as Array<FloatArray>
        return output[0][0]
    }

    fun showOutputPopup(output: Float) {
        // Inflate the custom layout for the popup
        val inflater = layoutInflater
        val popupView = inflater.inflate(R.layout.popup_output, null)

        // Find views within the custom layout
        val tvOutput = popupView.findViewById<TextView>(R.id.tvOutput)
        val btnClose = popupView.findViewById<Button>(R.id.btnClose)

        // Set the output text
        tvOutput.text = "Output is $output"

        // Create the dialog builder
        val builder = AlertDialog.Builder(this)
        builder.setView(popupView)

        // Create and show the dialog
        val dialog = builder.create()
        dialog.show()

        // Handle button click
        btnClose.setOnClickListener {
            dialog.dismiss() // Close the dialog when the button is clicked
        }
    }
}
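A likely answer to the shape question: the shape passed to OnnxTensor.createTensor describes how the values in the buffer are laid out, so its element count (the product of its dimensions) has to equal the number of floats in the buffer. Seven features for one sample means a shape of (1, 7), not (1, 1). The sketch below uses only the Kotlin standard library and java.nio.FloatBuffer to build the buffer and the shape together so they cannot disagree. It assumes the model was exported with a single float input of shape [None, 7]; `singleRowInput` and `elementCount` are hypothetical helper names.

```kotlin
import java.nio.FloatBuffer

// Number of values a tensor of this shape holds, e.g. [1, 7] -> 7, [1, 1] -> 1.
fun elementCount(shape: LongArray): Long = shape.fold(1L) { acc, dim -> acc * dim }

// Wraps one sample's features in a FloatBuffer and returns it together with the
// matching shape [1, numFeatures] (1 row = 1 sample, 1 column per feature).
fun singleRowInput(features: FloatArray): Pair<FloatBuffer, LongArray> {
    require(features.isNotEmpty()) { "at least one feature is required" }
    val shape = longArrayOf(1L, features.size.toLong())
    val buffer = FloatBuffer.wrap(features)
    // The buffer must hold exactly as many values as the shape describes.
    check(elementCount(shape) == buffer.remaining().toLong())
    return buffer to shape
}
```

In runPrediction, the pair could replace the hand-written shape: call singleRowInput(floatArrayOf(genders, age, weight, height, hours, years, bmi)) and pass its first and second components to OnnxTensor.createTensor(ortEnvironment, buffer, shape).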

Running with different variable types

I need to run the prediction with the following variable types:
gender : Integer, age : Integer, weight : Integer, height : Integer, hours : Integer, years : Integer, bmi : Float

How can I do that? Do I need to add more input variables?

class CalculatorActivity : AppCompatActivity() {

    private lateinit var binding: ActivityCalculatorBinding
    private var isGenderSelected = false
    private var previousSelectedPosition = 0
    private var selectedGenderValue = 0 // 0 for Female, 1 for Male

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        binding = ActivityCalculatorBinding.inflate(layoutInflater)
        setContentView(binding.root)

        val textInputEditTextAge = binding.tedAge
        val placeholderAge = "... Years"
        textInputEditTextAge.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextAge.hint = placeholderAge

        val textInputEditTextWeight = binding.tedWeight
        val placeholderWeight = "... KG"
        textInputEditTextWeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextWeight.hint = placeholderWeight

        val textInputEditTextHeight = binding.tedHeight
        val placeholderHeight = "... CM"
        textInputEditTextHeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHeight.hint = placeholderHeight

        val textInputEditTextHours = binding.tedHours
        val placeholderHours = "... Hours"
        textInputEditTextHours.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHours.hint = placeholderHours

        val textInputEditTextYears = binding.tedYears
        val placeholderYears = "... Years"
        textInputEditTextYears.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextYears.hint = placeholderYears

        val gender = resources.getStringArray(R.array.Gender)

        val spinner = binding.genderSpinner
        if (spinner != null) {
            val adapter = ArrayAdapter(this,
                android.R.layout.simple_spinner_dropdown_item, gender)
            spinner.adapter = adapter

            spinner.onItemSelectedListener = object :
                AdapterView.OnItemSelectedListener {
                override fun onItemSelected(parent: AdapterView<*>, view: View?, position: Int, id: Long) {
                    if (isGenderSelected) {
                        // A gender other than "-- Select your Gender --" has already been selected
                        if (position == 0) {
                            spinner.setSelection(previousSelectedPosition) // Set the spinner to the previous selected position
                        } else {
                            previousSelectedPosition = position // Update the previous selected position
                            selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                        }
                    } else {
                        isGenderSelected = true
                        previousSelectedPosition = position // Set the initial selected position
                        selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                    }
                }

                override fun onNothingSelected(parent: AdapterView<*>) {
                    // write code to perform some action
                }
            }
        }

        binding.btnPredict.setOnClickListener {
            val gender = selectedGenderValue.toFloat()
            // toFloatOrNull() returns null for empty fields instead of throwing NumberFormatException
            val age = textInputEditTextAge.text.toString().toFloatOrNull()
            val weight = textInputEditTextWeight.text.toString().toFloatOrNull()
            val height = textInputEditTextHeight.text.toString().toFloatOrNull()
            val hours = textInputEditTextHours.text.toString().toFloatOrNull()
            val years = textInputEditTextYears.text.toString().toFloatOrNull()
            if (age != null && weight != null && height != null && height > 0f && hours != null && years != null) {
                // BMI = weight (kg) / height (m)^2
                val heightInMeters = height / 100
                val bmi = weight / (heightInMeters * heightInMeters)
                val inputs = floatArrayOf(gender, age, weight, height, hours, years, bmi)
                val ortEnvironment = OrtEnvironment.getEnvironment()
                val ortSession = createORTSession(ortEnvironment)
                val output = runPrediction(inputs, ortSession, ortEnvironment)
                showOutputPopup(output)
            } else {
                Toast.makeText(this, "Please fill in all the inputs", Toast.LENGTH_LONG).show()
            }
        }
    }

    private fun createORTSession( ortEnvironment: OrtEnvironment) : OrtSession {
        val modelBytes = resources.openRawResource( R.raw.model1 ).readBytes()
        return ortEnvironment.createSession( modelBytes )
    }

    private fun runPrediction(input : FloatArray, ortSession: OrtSession , ortEnvironment: OrtEnvironment ) : Long {
        // Get the name of the input node
        val inputName = ortSession.inputNames?.iterator()?.next()
        // Make a FloatBuffer of the inputs
        val floatBufferInputs = FloatBuffer.wrap(input)
        // Create input tensor of shape (1, 7): one sample (row) with seven features
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, floatBufferInputs, longArrayOf(1, 7))
        // Run the model
        val results = ortSession.run( mapOf( inputName to inputTensor ) )
        // Fetch and return the results
        val output = results[0].value as LongArray
        return output[0]
    }

    fun showOutputPopup(output: Long) {
        // Inflate the custom layout for the popup
        val inflater = layoutInflater
        val popupView = inflater.inflate(R.layout.popup_output, null)

        // Find views within the custom layout
        val tvOutput = popupView.findViewById<TextView>(R.id.tvOutput)
        val btnClose = popupView.findViewById<Button>(R.id.btnClose)

        // Set the output text
        tvOutput.text = "Output is $output"

        // Create the dialog builder
        val builder = AlertDialog.Builder(this)
        builder.setView(popupView)

        // Create and show the dialog
        val dialog = builder.create()
        dialog.show()

        // Handle button click
        btnClose.setOnClickListener {
            dialog.dismiss() // Close the dialog when the button is clicked
        }
    }
}
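On the integer question: a model converted with skl2onnx using a single FloatTensorType input (the usual pattern, and an assumption here) takes every feature as a float, so integer values such as age or height can simply be converted with toFloat(); no extra input is needed. You can confirm what your model actually declares by inspecting ortSession.inputInfo. Only if the model was exported with separate typed inputs would you need one tensor per input (for example an int64 tensor built with OnnxTensor.createTensor(env, LongBuffer, shape)). The sketch below is plain Kotlin; `PatientFeatures`, `toFloatRow` and `bmiOf` are hypothetical names, and the feature order must match the column order used during training.

```kotlin
// Hypothetical record holding the form values with the types from the question.
data class PatientFeatures(
    val gender: Int,
    val age: Int,
    val weight: Int, // kg
    val height: Int, // cm
    val hours: Int,
    val years: Int,
    val bmi: Float
)

// BMI = weight (kg) / height (m)^2, with height entered in centimetres.
fun bmiOf(weightKg: Int, heightCm: Int): Float {
    require(heightCm > 0) { "height must be positive" }
    val heightM = heightCm / 100f
    return weightKg / (heightM * heightM)
}

// Flattens the record into one float row. The order must match the
// column order of the training data (assumed here to be the order above).
fun PatientFeatures.toFloatRow(): FloatArray = floatArrayOf(
    gender.toFloat(), age.toFloat(), weight.toFloat(), height.toFloat(),
    hours.toFloat(), years.toFloat(), bmi
)
```

The resulting row can then go through the same FloatBuffer.wrap(...) and longArrayOf(1, 7) path that runPrediction already uses.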
