Building a Machine Learning API with FastAPI: Sepsis Prediction

Bright Eshun
12 min read · Jun 10, 2023


1. Introduction

In the field of healthcare, accurate and timely prediction of medical conditions can be crucial for patient outcomes. One such condition is sepsis, a life-threatening response to infection. In this article, we will explore how to build a Sepsis Prediction API using the FastAPI framework. FastAPI provides a high-performance and easy-to-use platform for developing robust APIs with Python.

1.1 Introduction to FastAPI

FastAPI is a modern and high-performance web framework for building APIs with Python. It offers a user-friendly development experience and scalability, making it a popular choice among developers. With FastAPI, developers can harness the power of Python to create robust APIs that handle complex logic, process large amounts of data, and provide real-time responses. It combines the strengths of popular frameworks like Flask and Django while introducing its own unique advantages. Built on top of the Starlette asynchronous web framework, FastAPI can efficiently handle concurrent requests and high traffic loads.
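
As a quick illustration (a generic toy example, not part of the Sepsis project), a few lines are enough to get a validated, documented endpoint: type hints drive both request validation and the interactive docs served at /docs.

# Minimal FastAPI example (illustrative only, not part of the Sepsis project)
from typing import Optional
from fastapi import FastAPI

app = FastAPI(title="Demo API")

@app.get("/items/{item_id}")
async def read_item(item_id: int, q: Optional[str] = None):
    # item_id is validated as an integer; q is an optional query parameter
    return {"item_id": item_id, "q": q}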

1.2 Why FastAPI

  • High Performance: FastAPI utilizes asynchronous capabilities and type hints for efficient web framework performance.
  • Easy to Use: FastAPI’s declarative approach and auto-generated documentation simplify API development and maintenance.
  • Modern Features: FastAPI supports dependency injection, OAuth2, request validation, and WebSocket for modern web development.
  • Type Safety: FastAPI’s type hints ensure data validation and serialization, improving code robustness and readability.
  • Compatibility: FastAPI seamlessly integrates with popular Python libraries and frameworks like SQLAlchemy and Pydantic.
  • Scalability: FastAPI handles high loads and concurrent requests through its asynchronous architecture.
  • Active Community: FastAPI has an active community and comprehensive documentation for support and learning.

2. Preparing the Model

2.1 Overview of the existing Sepsis prediction model

The prediction model integrated into this API is a classification model specifically designed for sepsis prediction. It utilizes 23 input features and provides a binary prediction for sepsis status, categorizing it as either 1 (indicating a positive prediction for sepsis) or 0 (representing a negative prediction for sepsis).

2.2 Converting the model, pipeline and other features into pickle files

The classification model, column transformer, and other relevant properties from the model building process were saved as pickle files. This allows for persistence, as the model can be easily saved and loaded at any time, preserving its state and parameters. The platform-independent nature of pickle files ensures portability, enabling the model to be shared and deployed across different operating systems and environments. Integration is seamless, as the pickle files can be easily incorporated into various applications or frameworks.

To save the model, the code below was run in the notebook:

# Save the model and the ColumnTransformer
import os
import pickle

# Build the path to the assets directory
directory = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'src', 'assets', 'ml_components'))

# Save the trained model (best_lgr) as a pickle file
filename = os.path.join(directory, 'model-1.pkl')
with open(filename, 'wb') as f:
    pickle.dump(best_lgr, f)

To save the ColumnTransformer and other components, the code below was run in the notebook:

# Save the ColumnTransformer (preprocessing pipeline) as a pickle file
filename2 = os.path.join(directory, 'preprocessor.pkl')
with open(filename2, 'wb') as f:
    pickle.dump(full_pipeline, f)

Add the pickle files to your project folder.

2.3 Exporting model dependencies and requirements

The model was built with specific dependencies, so exporting them is essential for successful deployment. A requirements file captures the libraries, packages, and versions needed to run the model, ensures reproducibility, and allows others to recreate the environment so the model can be integrated elsewhere without compatibility issues.

To export the model’s dependencies into a requirements.txt file, run the following command in the notebook:

pip freeze > requirements.txt

Save the requirements.txt file in the project’s base folder.

3. Setting Up the Project

Before diving into the development process, it’s essential to set up the project environment. We’ll need to install FastAPI and any additional dependencies required for our API. We can use virtual environments to isolate the project dependencies and ensure a clean development environment.

3.1 Creating Environment

Creating a virtual environment is essential for project development: it isolates project dependencies, lets you pin specific package versions, keeps the development environment clean, and gives every team member a reproducible setup, which in turn simplifies collaboration, deployment, and project migration.

To create a virtual environment:
1. Open the command prompt.
2. Navigate to the project directory using the `cd` command.

Create a virtual environment in Windows using the following command:


python -m venv myenv

Activate the virtual environment in Windows using the following command:

myenv\Scripts\activate

Create a virtual environment in Linux using the following command:

python3 -m venv myenv

Activate the virtual environment in Linux using the following command:

source myenv/bin/activate

3.2 Installing dependencies and libraries

After creating and activating the environment, the next step is to install the necessary dependencies for your project. Dependencies are additional packages or libraries that your project relies on to function properly. Installing dependencies ensures that all the required components are available and accessible. To install dependencies, you can use a package manager such as pip.

To install dependencies using pip, you can use the following command:

pip install -r requirements.txt

By installing dependencies, you ensure that your project has access to the required functionality and libraries. It’s important to keep track of the dependencies and their versions to maintain consistency and avoid compatibility issues.

4. Building the API

4.1 Designing the API

To begin building our Sepsis Prediction API, we need to design the endpoints and data models that will be exposed. The API will include endpoints for Root, health checks, individual predictions, model information, batch prediction and data upload. We’ll utilize the power of FastAPI’s decorators and data validation with Pydantic to define the endpoints and models.

4.1.1 API Endpoints

Root: This endpoint displays a welcome message: “Welcome to the Sepsis API...”.

Health Check Endpoint: A crucial aspect of any API is a health check endpoint that allows us to monitor its status. We’ll implement a simple endpoint that returns the API’s status, ensuring it is running correctly.

Model Information Endpoint: This endpoint gives information about the model. It returns the model’s name, parameters and both the categorical and numerical features used in training.

Individual Prediction Endpoint: For individual predictions, we’ll create an endpoint that accepts the required input parameters and returns the prediction for a single data record. This endpoint will be useful when users want to make predictions on a single data point.

Batch Prediction Endpoint: To make predictions on a batch of input data, we’ll create an endpoint that accepts a list of input records. This endpoint will process the data using our trained ML model and return the predictions for each input.

Data Upload Endpoint: In real-world scenarios, users may want to upload their own data for prediction. We’ll implement an endpoint that allows users to upload data files or data in JSON format or CSV format. The API will process the data and return predictions using our ML model.

4.2 Integrating the Model into the API

To make accurate sepsis predictions, we need to integrate a trained ML model into our API. We’ll explore how to load the model and use it to make predictions. Additionally, we’ll discuss techniques for optimizing the model loading process, such as caching the loaded model to avoid unnecessary overhead.

4.2.1 Loading the Pickled Model and ColumnTransformer

We create a function load_pickle that takes a filename as input. It opens the file in binary mode, reads its contents using pickle, and returns the loaded contents. The function is used to load the pickled model, pipeline, and other components.

import os
import pickle

def load_pickle(filename):
    # Open the pickle file in binary mode and return its contents
    with open(filename, 'rb') as file:
        contents = pickle.load(file)
    return contents

# DIRPATH is the directory of the current script (defined elsewhere in the project)
model_path = os.path.join(DIRPATH, '..', 'assets', 'ml_components', 'model-1.pkl')
transformer_path = os.path.join(DIRPATH, '..', 'assets', 'ml_components', 'preprocessor.pkl')
other_components_path = os.path.join(DIRPATH, '..', 'assets', 'ml_components', 'other_compents.pkl')

# Load the trained model, preprocessing pipeline, and other properties
model = load_pickle(model_path)
transformer = load_pickle(transformer_path)
other_components = load_pickle(other_components_path)

4.2.2 Handling Data and Preprocessing

Before making predictions, we often need to handle data preprocessing. In the case of sepsis prediction, we may need to perform certain preprocessing steps on the input data to ensure it is in the proper format for our ML model. This could involve tasks such as handling missing values, scaling features, or encoding categorical variables.
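
The preprocessor used by this API was built and pickled in the modelling notebook rather than in the API code. As a rough illustration only, a preprocessing pipeline of this kind might be assembled with scikit-learn roughly as follows; the column names and steps here are assumptions, not the project’s actual pipeline.

# Hypothetical sketch of how a preprocessor like preprocessor.pkl could be built.
# The actual column lists and steps live in the project notebook.
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder

numerical_cols = ['Plasma Glucose', 'Blood Pressure', 'Age']        # assumed subset
categorical_cols = ['Insurance', 'Age Group', 'BMI_range']          # assumed subset

numeric_pipe = Pipeline([
    ('imputer', SimpleImputer(strategy='median')),  # fill missing values
    ('scaler', StandardScaler()),                   # scale numeric features
])

full_pipeline = ColumnTransformer([
    ('num', numeric_pipe, numerical_cols),
    ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_cols),
])
# full_pipeline.fit(train_df) would then be pickled as preprocessor.pkl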

The following steps were taken to handle and preprocess data:

i. Collecting and validating inputs

In the Sepsis Prediction API, data collection and validation play a crucial role in ensuring the accuracy and reliability of predictions. Depending on the endpoint being used, different methods are employed to collect and validate the input data.

Individual Prediction Endpoint: For individual predictions, the input data is directly passed into the endpoint as variables.

Batch Prediction Endpoint: In the case of batch predictions, a Pydantic BaseModel validates the structured batch input for consistency and accuracy (a sketch of such a model is shown after this list).

Data Upload Endpoint: The data upload endpoint allows users to upload JSON or CSV files for prediction; the uploads are validated for format and structure to ensure compatibility and handle potential errors.
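
The article does not reproduce the Pydantic model here; a minimal sketch of what such a batch-input schema could look like is shown below, with field names assumed from the /predict parameters covered later (the real schema lives in the GitHub repository).

# Hypothetical Pydantic models for the batch prediction endpoint.
# Field names mirror the /predict parameters; the real schema is in the repository.
from typing import List
from pydantic import BaseModel

class PatientRecord(BaseModel):
    plasma_glucose: float
    blood_work_result_1: float
    blood_pressure: float
    blood_work_result_2: float
    blood_work_result_3: float
    body_mass_index: float
    blood_work_result_4: float
    age: int
    insurance: bool

class PatientBatch(BaseModel):
    records: List[PatientRecord]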

ii. Feature Engineering

After collecting and validating the data, the next step is to create the features that were used to train the model. This involves transforming and manipulating the raw data to extract or create the same input variables for the prediction model.

import pandas as pd

def feature_engineering(data):
    # Cast Insurance to a string category
    data['Insurance'] = data['Insurance'].astype(int).astype(str)

    # Multiply all numerical features into a single interaction feature
    data['All-Product'] = (data['Blood Work Result-4'] * data['Blood Work Result-1'] *
                           data['Blood Work Result-2'] * data['Blood Work Result-3'] *
                           data['Plasma Glucose'] * data['Blood Pressure'] *
                           data['Age'] * data['Body Mass Index'])

    # Bin the interaction feature into ranges
    all_labels = ['{0}-{1}'.format(i, i + 500000000000) for i in range(0, round(2714705253292.0312), 500000000000)]
    data['All-Product_range'] = pd.cut(data['All-Product'], bins=range(0, 3500000000000, 500000000000),
                                       right=False, labels=all_labels)

    # Create a categorical age-group feature
    age_labels = ['{0}-{1}'.format(i, i + 20) for i in range(0, 83, 20)]
    data['Age Group'] = pd.cut(data['Age'], bins=range(0, 120, 20), right=False, labels=age_labels)

    # Create a categorical feature for body mass index
    labels = ['{0}-{1}'.format(i, i + 30) for i in range(0, round(67.1), 30)]
    data['BMI_range'] = pd.cut(data['Body Mass Index'], bins=range(0, 120, 30), right=False, labels=labels)

iii. Handling JSON and CSV files

The function below is used to process uploaded JSON and CSV files into a format that can be converted into a dataframe.

from io import StringIO
import pandas as pd

def process_json_csv(contents, file_type, valid_formats):
    # Decode the uploaded byte string into a regular string
    contents = contents.decode()
    new_columns = return_columns()  # expected column names (helper defined in the project)

    # Read the uploaded file into a dataframe
    if file_type == valid_formats[0]:
        data = pd.read_csv(StringIO(contents))   # CSV file
    elif file_type == valid_formats[1]:
        data = pd.read_json(contents)            # JSON file

    data = data.drop(columns=['ID'])                          # drop the ID column
    dict_new_old_cols = dict(zip(data.columns, new_columns))  # map uploaded columns to expected names
    data = data.rename(columns=dict_new_old_cols)             # rename columns
    return data

iv. Making a Prediction

Now we need to make a prediction using the loaded model and transformer. The make_prediction function was created for this purpose.

The `make_prediction` function takes in data, transformer, and model as input. It performs various data preprocessing steps, including renaming columns, feature engineering, and transforming the data using the provided transformer. It then combines categorical and numerical features, makes a prediction using the model, and returns the predicted label and the maximum probability among the predicted probabilities.

def make_prediction(data, transformer, model):
    # Rename incoming columns to match the training data
    new_columns = return_columns()
    dict_new_old_cols = dict(zip(data.columns, new_columns))
    data = data.rename(columns=dict_new_old_cols)

    feature_engineering(data)                         # create new features in place
    transformed_data = transformer.transform(data)    # apply the saved preprocessing pipeline
    combine_cats_nums(transformed_data, transformer)  # rebuild a dataframe from the transformed data

    # Make a prediction and return the label with the highest probability
    label = model.predict(transformed_data)
    probs = model.predict_proba(transformed_data)
    return label, probs.max()

v. Formatting outputs

The outputs returned when a POST request is made to the API need to be formatted in a readable, easily accessible way. The output_batch function takes data1 and labels as input, converts both into dictionaries, and iterates over the indices of the labels dictionary to append each input/output pair to results_list.

# Formats the API response
def output_batch(data1, labels):
    data_labels = pd.DataFrame(labels, columns=['Predicted Label'])             # convert labels into a dataframe
    data_labels['Predicted Label'] = data_labels.apply(process_label, axis=1)   # map labels to readable strings
    results_list = []                                                           # create an empty list
    x = data1.to_dict('index')                                                  # convert input dataframe into a dictionary
    y = data_labels.to_dict('index')                                            # convert label dataframe into a dictionary
    for i in range(len(y)):
        results_list.append({i: {'inputs': x[i], 'output': y[i]}})              # pair each input with its prediction

    final_dict = {'results': results_list}
    return final_dict

4.2.3 Building Endpoints

Root-/

The ‘/’ URL returns an HTML response containing the message ‘Welcome to the Sepsis API’.

@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
return templates.TemplateResponse("index.html", {'request': request})

Health Check Endpoint-/health

The /health endpoint returns a message indicating the health status of the API.

@app.get("/health")
def check_health():
return {"status": "ok"}

Model Information Endpoint-/model-info

The /model-info endpoint returns information about the model in use.

@app.post('/model-info')
async def model_info():
    model_name = model.__class__.__name__
    model_params = model.get_params()
    features = properties['train features']  # training features saved with the other pickled components
    model_information = {'model info': {
        'model name': model_name,
        'model parameters': model_params,
        'train features': features}
    }
    return model_information

Individual Prediction Endpoint-/predict

The /predict endpoint receives inputs and makes a single prediction. It returns both input and output data.

# Prediction endpoint
@app.post('/predict')
async def predict(plasma_glucose: float, blood_work_result_1: float,
                  blood_pressure: float, blood_work_result_2: float,
                  blood_work_result_3: float, body_mass_index: float,
                  blood_work_result_4: float, age: int, insurance: bool):

    # Create a dataframe from the inputs
    data = pd.DataFrame([[plasma_glucose, blood_work_result_1, blood_pressure,
                          blood_work_result_2, blood_work_result_3, body_mass_index,
                          blood_work_result_4, age, insurance]], columns=return_columns())

    labels, prob = make_prediction(data, transformer, model)  # get the predicted label and probability
    response = output_batch(data, labels)                     # format the results
    return response

Refer to the project on GitHub for the Batch Prediction Endpoint-/predict-batch and the Data Upload Endpoint-/upload-data.
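
For orientation, here is a rough sketch of what such a batch endpoint could look like, reusing the hypothetical PatientBatch model sketched earlier together with the make_prediction and output_batch helpers; the actual implementation in the repository may differ.

# Illustrative sketch only; see the GitHub repository for the real /predict-batch endpoint.
@app.post('/predict-batch')
async def predict_batch(batch: PatientBatch):
    # Build a dataframe from the validated records
    data = pd.DataFrame([record.dict() for record in batch.records])
    data.columns = return_columns()                            # align column names with the training data
    labels, prob = make_prediction(data, transformer, model)   # predict for the whole batch
    return output_batch(data, labels)                          # format the results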

5. Testing the API

It is crucial to test the API to ensure its functionality, reliability, and adherence to requirements, identifying and addressing any potential issues or bugs. Testing helps to validate the API’s behavior, improve its performance, and enhance the overall user experience, ensuring it meets the desired standards and objectives.

5.1 Running the API locally

To run the API locally, we need to start the FastAPI development server by running the file that contains our API definition, which in this case is the app.py file.

To do this we can run the following command on the terminal:

python 'path/app.py'

Make sure the server is up and running.
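
Note that running the file directly only starts a server if app.py itself launches one. A common pattern, assumed here rather than taken from the project, is a __main__ guard that calls uvicorn.run; alternatively, the server can be started with uvicorn app:app --reload from the directory containing app.py.

# Typical pattern (assumed) that makes `python app.py` start the development server
import uvicorn

if __name__ == '__main__':
    uvicorn.run('app:app', host='127.0.0.1', port=8000, reload=True)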

5.2 Testing the API endpoints

To test the API endpoints, you can follow these steps:

  1. Open a web browser and go to http://127.0.0.1:8000/docs to send HTTP requests to the API endpoints through the interactive documentation.
  2. Test the health check endpoint by sending a GET request to the health check URL-/health. Verify that the response indicates that the API is running correctly.
  3. Test the model information endpoint by sending a POST request to the model information URL-/model-info.
  4. Test the individual prediction endpoint by sending a POST request to the prediction URL-/predict, providing the required input values (FastAPI exposes them as query parameters; a scripted example is shown after this list). Ensure that the input data is valid and in the expected format. Capture the response, which will include the predicted label and associated probability.
  5. Test the batch prediction endpoint by sending a POST request to the batch prediction URL-/predict-batch, providing a batch of input data in the request body. The input data should be structured according to the defined Pydantic BaseModel and match the expected format. Capture the response, which will include the predicted labels and associated probabilities for each input record.
  6. Test the data upload endpoint by sending a POST request to the data upload URL-/upload-data, attaching a file (JSON or CSV) containing the input data. Validate that the uploaded file is in the correct format and matches the expected schema. Capture the response, which will include the predicted labels and associated probabilities for the uploaded data.
  7. Analyze the test results to ensure the predictions and probabilities align with the expected outcomes. Make any necessary adjustments to the API or the input data if required.
  8. Repeat the testing process with various input scenarios to thoroughly validate the functionality and performance of the Sepsis Prediction API.
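
The endpoints can also be exercised from a script instead of the browser. Because /predict declares its inputs as plain function parameters, FastAPI exposes them as query parameters; the snippet below is a small illustrative smoke test using the requests library, with made-up input values.

# Quick smoke test of the individual prediction endpoint (illustrative values)
import requests

params = {
    'plasma_glucose': 120.0,
    'blood_work_result_1': 85.0,
    'blood_pressure': 80.0,
    'blood_work_result_2': 30.0,
    'blood_work_result_3': 25.0,
    'body_mass_index': 28.5,
    'blood_work_result_4': 0.5,
    'age': 45,
    'insurance': True,
}

response = requests.post('http://127.0.0.1:8000/predict', params=params)
print(response.status_code)   # expect 200 when the server is running
print(response.json())        # formatted inputs and predicted label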

5.3 Handling Errors and Exceptions

While building an API, it’s crucial to handle errors and exceptions gracefully. FastAPI provides robust error handling capabilities that allow developers to catch specific exceptions, such as validation errors or model prediction errors, and provide informative error messages to the API users.

The following code was used to return a custom error message in case a wrong file type is uploaded to the data upload endpoint:

valid_formats = ['text/csv', 'application/json']
if file_type not in valid_formats:
    return JSONResponse(content={"error": f"Invalid file format. Must be one of: {', '.join(valid_formats)}"})
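
An alternative, sketched here rather than taken from the project, is to raise FastAPI’s HTTPException so the client also receives an explicit 400 status code along with the message:

# Alternative sketch: signal an invalid upload with an explicit HTTP status code
from fastapi import HTTPException

if file_type not in valid_formats:
    raise HTTPException(
        status_code=400,
        detail=f"Invalid file format. Must be one of: {', '.join(valid_formats)}",
    )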

5.4 Optimizing the model loading process

It is worthwhile to cache the model loading process. Caching is crucial for improving performance, especially when loading the model is time-consuming, and it reduces overhead in scenarios where the model is accessed frequently. The lru_cache decorator from the functools module lets the load_pickle function, which loads a pickled object from a file, cache its results: subsequent calls with the same filename retrieve the cached result, eliminating redundant file I/O operations.

from functools import lru_cache

@lru_cache(maxsize=100)
def load_pickle(filename):
    # Load and cache the pickled object so repeated calls avoid redundant disk I/O
    with open(filename, 'rb') as file:
        contents = pickle.load(file)
    return contents

Screenshot of the API App

Conclusions and Resources

Conclusion

In conclusion, leveraging FastAPI to build APIs provides organizations with a powerful and efficient solution for their application needs. FastAPI’s high performance, easy-to-use nature, and modern features make it an ideal framework for developing robust and scalable APIs.

Throughout this article, we explored various aspects of building the Sepsis Prediction API with FastAPI, including project setup, designing endpoints, integrating machine learning models, handling data, and implementing error handling. By following the discussed steps and best practices, developers can create a reliable and accurate Sepsis Prediction API.

Resources

  1. FastAPI Documentation
  2. FastAPI Tutorial
  3. FastAPI Cookbook

You can find the complete code for the project on my GitHub, where it is publicly available. I encourage and welcome feedback and comments from the community to improve and enhance the project. Feel free to explore the code, provide suggestions, and contribute to its development. Your input is highly valued and appreciated.

