TorchServe with gRPC

Created: June 4, 2022 11:27 AM

This page documents a template for serving PyTorch models with TorchServe and communicating with them over gRPC. Most of the content available online focuses on the HTTP API, so this page is aimed at readers who would rather use gRPC, typically to get lower-overhead inference requests.

This page does not explain the theory behind TorchServe or gRPC. Instead, it lays out the steps for serving a model and implementing a gRPC client.

Directory structure

project_root_folder
|__torchserve_grpc
	|__ model_store
		|__digit-model.mar
	|__model_weights
		|__digitcnn_state_dict.pth
	|__images
		|__test.png
	|__handler.py
	|__model.py
	|__inference.proto
	|__management.proto
	|__inference_pb2.py
	|__inference_pb2_grpc.py
	|__management_pb2.py
	|__management_pb2_grpc.py
	|__grpc_client.py

Model Definition file (model.py)

import torch
import torchvision

class DigitCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = self._model_prep()

    def _model_prep(self):
        # Note: torchvision 0.13+ prefers the weights= argument over pretrained=True.
        model = torchvision.models.mobilenet_v3_small(pretrained=True, progress=True)
        classification_layer = torch.nn.Sequential(
            torch.nn.Linear(576, out_features=1024),
            torch.nn.Hardswish(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10),
            torch.nn.Softmax(dim=1)  # normalize over the class dimension
        )
        model.classifier = classification_layer

        for param in model.features[:11].parameters():
            param.requires_grad = False
        return model

    def forward(self, x):
        return self.model(x)

Custom Handler (handler.py)

from ts.torch_handler.vision_handler import VisionHandler
import torch
from PIL import Image
from torchvision import transforms
import logging
import io
import base64

class CustomHandler(VisionHandler):
    def __init__(self):
        super(CustomHandler, self).__init__()
        self.image_processing = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor()])
    
    def preprocess(self, data):
        images = []
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of Torchserve didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # A string payload is assumed to be base64-encoded image bytes.
                image = base64.b64decode(image)

            # The image was sent as raw bytes (or was just decoded from base64).
            if isinstance(image, (bytearray, bytes)):
                image = Image.open(io.BytesIO(image)).convert("RGB")
                image = self.image_processing(image)
                logging.info(f"image shape after preprocess: {image.shape}")
            else:
                # Otherwise the image is assumed to be a nested list of pixel values.
                image = torch.FloatTensor(image)
            
            images.append(image)

        return torch.stack(images).to(self.device)
    
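    def postprocess(self, data):
        logging.info("Inside Post Process")
        logging.info(f"Outputs: {data}")
        # Take the highest-scoring class per image; the + 1 shifts the zero-based
        # index by one (presumably to match the label numbering used in training).
        predictions = torch.argmax(data, dim=1) + 1
        logging.info(predictions.tolist())
        return predictions.tolist()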
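
The payload-decoding branches in preprocess can be exercised without TorchServe or PyTorch. The sketch below mirrors only that step, using the standard library. The helper name extract_image_bytes is hypothetical and is not part of the handler above; unlike the handler, it rejects list payloads instead of converting them to a tensor.

```python
import base64


def extract_image_bytes(row):
    """Return raw image bytes from one request row (hypothetical helper)."""
    # Same lookup order as the handler: "data" first, then "body".
    payload = row.get("data") or row.get("body")
    if isinstance(payload, str):
        # A string payload is assumed to be base64-encoded image bytes.
        payload = base64.b64decode(payload)
    if isinstance(payload, (bytes, bytearray)):
        return bytes(payload)
    raise TypeError(f"unsupported payload type: {type(payload).__name__}")


raw = b"\x89PNG fake image bytes"
encoded = base64.b64encode(raw).decode("ascii")
print(extract_image_bytes({"data": raw}) == raw)
print(extract_image_bytes({"body": encoded}) == raw)
```

Both calls return the same bytes, which is why the handler can treat raw and base64-encoded requests identically after this step.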

Archiving the model

torch-model-archiver --model-name digit-model --export-path torchserve_grpc/model_store --version 1.0 --model-file torchserve_grpc/model.py --serialized-file torchserve_grpc/model_weights/digitcnn_state_dict.pth --handler torchserve_grpc/handler.py --force
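
If the archiving step needs to be scripted, the same command can be assembled in Python. This is a minimal sketch: build_archiver_command is a hypothetical helper, and it only builds the argument list from the flags used above without running anything. The archive is named after --model-name, so these arguments produce digit-model.mar in the export path.

```python
import shlex


def build_archiver_command(model_name, export_path, model_file,
                           serialized_file, handler, version="1.0", force=True):
    """Assemble the torch-model-archiver call as an argument list."""
    command = [
        "torch-model-archiver",
        "--model-name", model_name,
        "--export-path", export_path,
        "--version", version,
        "--model-file", model_file,
        "--serialized-file", serialized_file,
        "--handler", handler,
    ]
    if force:
        # --force overwrites an existing archive with the same name.
        command.append("--force")
    return command


command = build_archiver_command(
    model_name="digit-model",
    export_path="torchserve_grpc/model_store",
    model_file="torchserve_grpc/model.py",
    serialized_file="torchserve_grpc/model_weights/digitcnn_state_dict.pth",
    handler="torchserve_grpc/handler.py",
)
print(shlex.join(command))
# To actually run it: subprocess.run(command, check=True)
```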

Serving the model

torchserve --start --model-store torchserve_grpc/model_store --models digitmodel=digit-model.mar --no-config-snapshot
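
Once TorchServe is running, it listens for gRPC on ports 7070 (inference) and 7071 (management) by default. A quick way to confirm this before writing any client code is a plain TCP check. The sketch below uses only the standard library; port_is_open is a hypothetical helper, and an open port only shows that something is listening, not that it is TorchServe.

```python
import socket


def port_is_open(host, port, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    for name, port in (("inference", 7070), ("management", 7071)):
        state = "open" if port_is_open("localhost", port) else "closed"
        print(f"gRPC {name} port {port}: {state}")
```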

Protocol buffer files

Source: serve/inference.proto on the master branch of pytorch/serve

inference.proto

syntax = "proto3";
package org.pytorch.serve.grpc.inference;

import "google/protobuf/empty.proto";

option java_multiple_files = true;

message PredictionsRequest {
    // Name of model.
    string model_name = 1; //required

    // Version of model to run prediction on.
    string model_version = 2; //optional

    // input data for model prediction
    map<string, bytes> input = 3; //required
}

message PredictionResponse {
    // Response content for prediction
    bytes prediction = 1;
}

message TorchServeHealthResponse {
    // TorchServe health
    string health = 1;
}

service InferenceAPIsService {
    rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {}

    // Predictions entry point to get inference using default model version.
    rpc Predictions(PredictionsRequest) returns (PredictionResponse) {}
}

Source: serve/management.proto on the master branch of pytorch/serve

management.proto

syntax = "proto3";

package org.pytorch.serve.grpc.management;

option java_multiple_files = true;

message ManagementResponse {
    // Response string of different management API calls.
    string msg = 1;
}

message DescribeModelRequest {
    // Name of model to describe.
    string model_name = 1; //required
    // Version of model to describe.
    string model_version = 2; //optional
    // Customized metadata
    bool customized = 3; //optional
}

message ListModelsRequest {
    // Use this parameter to specify the maximum number of items to return. When this value is present, TorchServe does not return more than the specified number of items, but it might return fewer. This value is optional. If you include a value, it must be between 1 and 1000, inclusive. If you do not include a value, it defaults to 100.
    int32 limit = 1; //optional

    // The token to retrieve the next set of results. TorchServe provides the token when the response from a previous call has more results than the maximum page size.
    int32 next_page_token = 2; //optional
}

message RegisterModelRequest {
    // Inference batch size, default: 1.
    int32 batch_size = 1; //optional

    // Inference handler entry-point. This value will override handler in MANIFEST.json if present.
    string handler = 2; //optional

    // Number of initial workers, default: 0.
    int32 initial_workers = 3; //optional

    // Maximum delay for batch aggregation, default: 100.
    int32 max_batch_delay = 4; //optional

    // Name of model. This value will override modelName in MANIFEST.json if present.
    string model_name = 5; //optional

    // Maximum time, in seconds, the TorchServe waits for a response from the model inference code, default: 120.
    int32 response_timeout = 6; //optional

    // Runtime for the model custom service code. This value will override runtime in MANIFEST.json if present.
    string runtime = 7; //optional

    // Decides whether creation of workers is synchronous or not, default: false.
    bool synchronous = 8; //optional

    // Model archive download url, support local file or HTTP(s) protocol.
    string url = 9; //required

    // Decides whether S3 SSE KMS enabled or not, default: false.
    bool s3_sse_kms = 10; //optional
}

message ScaleWorkerRequest {

    // Name of model to scale workers.
    string model_name = 1; //required

    // Model version.
    string model_version = 2; //optional

    // Maximum number of worker processes.
    int32 max_worker = 3; //optional

    // Minimum number of worker processes.
    int32 min_worker = 4; //optional

    // Number of GPU worker processes to create.
    int32 number_gpu = 5; //optional

    // Decides whether the call is synchronous or not, default: false.
    bool synchronous = 6; //optional

    // Waiting up to the specified wait time if necessary for a worker to complete all pending requests. Use 0 to terminate backend worker process immediately. Use -1 for wait infinitely.
    int32 timeout = 7; //optional
}

message SetDefaultRequest {
    // Name of model whose default version needs to be updated.
    string model_name = 1; //required

    // Version of model to be set as default version for the model
    string model_version = 2; //required
}

message UnregisterModelRequest {
    // Name of model to unregister.
    string model_name = 1; //required

    // Version of model to unregister.
    string model_version = 2; //optional
}

service ManagementAPIsService {
    // Provides detailed information about the default version of a model.
    rpc DescribeModel(DescribeModelRequest) returns (ManagementResponse) {}

    // List registered models in TorchServe.
    rpc ListModels(ListModelsRequest) returns (ManagementResponse) {}

    // Register a new model in TorchServe.
    rpc RegisterModel(RegisterModelRequest) returns (ManagementResponse) {}

    // Configure number of workers for a default version of a model. This is an asynchronous call by default. The caller needs to call DescribeModel to check whether the model workers have been changed.
    rpc ScaleWorker(ScaleWorkerRequest) returns (ManagementResponse) {}

    // Set default version of a model
    rpc SetDefault(SetDefaultRequest) returns (ManagementResponse) {}

    // Unregister the default version of a model from TorchServe if it is the only version available. This is an asynchronous call by default. The caller can call ListModels to confirm the model is unregistered.
    rpc UnregisterModel(UnregisterModelRequest) returns (ManagementResponse) {}
}

The Python files can be generated from the .proto files with the protocol buffer compiler, which is available through the grpcio-tools package:

python -m grpc_tools.protoc --proto_path=torchserve_grpc/ --python_out=torchserve_grpc --grpc_python_out=torchserve_grpc torchserve_grpc/management.proto torchserve_grpc/inference.proto

The above command should generate four Python files:

  • inference_pb2.py
  • inference_pb2_grpc.py
  • management_pb2.py
  • management_pb2_grpc.py

These files are needed to construct the request messages and to call the services exposed by the gRPC server. The gRPC server itself is already implemented by TorchServe and runs alongside it.
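
Before moving on to the client, it can help to confirm that the compiler really produced all four modules. The following is a small sketch using only the standard library; missing_generated_files is a hypothetical helper, and the directory name torchserve_grpc is taken from the directory structure above.

```python
from pathlib import Path

EXPECTED_MODULES = (
    "inference_pb2.py",
    "inference_pb2_grpc.py",
    "management_pb2.py",
    "management_pb2_grpc.py",
)


def missing_generated_files(directory):
    """Return the expected generated modules that are absent from directory."""
    root = Path(directory)
    return [name for name in EXPECTED_MODULES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_generated_files("torchserve_grpc")
    print("all generated files present" if not missing else f"missing: {missing}")
```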

Writing a gRPC client

  1. We need a stub to make requests to the server.
  2. The gRPC server is already implemented by TorchServe and starts along with it.
  3. The inference.proto and management.proto files tell us which functions can be called and which parameters they require.

We’ll build a gRPC client modeled on the one shared in the official documentation. The client file can be found here.

First, we import the generated modules in grpc_client.py

import grpc
import inference_pb2
import inference_pb2_grpc
import management_pb2
import management_pb2_grpc
import json

By default, ports 7070 and 7071 are exposed for gRPC inference and management, respectively. Therefore we need two stubs: one for inference requests and another for management requests.

def get_inference_stub():
    channel = grpc.insecure_channel('localhost:7070')
    stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
    return stub

def get_management_stub():
    channel = grpc.insecure_channel('localhost:7071')
    stub = management_pb2_grpc.ManagementAPIsServiceStub(channel)
    return stub

Note that you can replace localhost with the private IP (or another reachable IP) of a remote server.
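
One way to avoid hardcoding the host is to read it from the environment. The sketch below is only an illustration: grpc_target is a hypothetical helper and TORCHSERVE_HOST is a variable name invented for this example, not something TorchServe defines. Keep in mind that grpc.insecure_channel sends traffic unencrypted, which may not be acceptable outside a private network.

```python
import os


def grpc_target(port, host=None):
    """Build the 'host:port' string passed to grpc.insecure_channel."""
    # TORCHSERVE_HOST is a hypothetical variable name chosen for this sketch.
    host = host or os.environ.get("TORCHSERVE_HOST", "localhost")
    return f"{host}:{port}"


print(grpc_target(7070))
print(grpc_target(7071, host="10.0.0.5"))
```

The stub functions above could then call grpc.insecure_channel(grpc_target(7070)) instead of using a fixed string.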

Notice that the stubs are created with the gRPC modules generated by the protocol buffer compiler. For example, the management stub can now make the RPC calls defined in the management.proto file.

Now we can write any function that makes a request to a service exposed by TorchServe. The exposed services are given by the protocol buffer files (.proto) above. Let’s write a function that lists all the models being served.

The management.proto file has a ListModels RPC, which takes a ListModelsRequest message that is defined in the same file. The message object can be created via the management_pb2 module, and the RPC can be made with the stub.

def list_models(management_stub):
    list_model_request_object = management_pb2.ListModelsRequest(limit=10)
    return management_stub.ListModels(list_model_request_object)

Testing this:

if __name__ == '__main__':

    inference_stub = get_inference_stub()
    management_stub = get_management_stub()

    output = list_models(management_stub)
    print(f"output: {output}")

    message = json.loads(output.msg)
    print(f"message: {message}")

Output:

output: msg: "{\n  \"models\": [\n    {\n      \"modelName\": \"digitmodel\",\n      \"modelUrl\": \"digit-model.mar\"\n    }\n  ]\n}"

message: {'models': [{'modelName': 'digitmodel', 'modelUrl': 'digit-model.mar'}]}
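
Since the management response carries its payload as a JSON string in msg, a small helper can turn it into something easier to work with. This is a minimal sketch; model_names is a hypothetical helper, and sample_msg reproduces the msg value from the output above as a Python literal so the example runs without a server.

```python
import json


def model_names(list_models_msg):
    """Extract model names from the JSON string in ManagementResponse.msg."""
    payload = json.loads(list_models_msg)
    return [entry["modelName"] for entry in payload.get("models", [])]


# The msg string from the output above, written as a Python literal.
sample_msg = (
    '{\n  "models": [\n    {\n      "modelName": "digitmodel",\n'
    '      "modelUrl": "digit-model.mar"\n    }\n  ]\n}'
)
print(model_names(sample_msg))  # ['digitmodel']
```

With a live server, the same helper would be called as model_names(output.msg).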

Finally, to make a prediction, we use the request message and RPC defined in the inference.proto file. We build a PredictionsRequest object with inference_pb2. It requires the model name and the input as a map from strings to bytes; model_version is optional.


def make_prediction(inference_stub, image_path):
    with open(image_path, "rb") as f:
        image_bytes = f.read()
    input_data = {"data": image_bytes}
    prediction_request = inference_pb2.PredictionsRequest(model_name="digitmodel", input=input_data)
    
    prediction = inference_stub.Predictions(prediction_request)
    return prediction

Testing this:

if __name__ == '__main__':
    inference_stub = get_inference_stub()
    management_stub = get_management_stub()

    prediction = make_prediction(inference_stub, "torchserve_grpc/images/test.png")

    print(prediction)

Output:

prediction: "8"
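
The prediction field is declared as bytes in inference.proto, so the client usually has to decode it before using the value. The sketch below shows one way to do that; decode_prediction is a hypothetical helper, and a SimpleNamespace stands in for the real PredictionResponse so the example runs without a server.

```python
import json
from types import SimpleNamespace


def decode_prediction(response):
    """Convert PredictionResponse.prediction (bytes) into a Python value."""
    text = response.prediction.decode("utf-8")
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the raw text if the handler did not return JSON.
        return text


# Stand-in for the gRPC response shown above; a real call returns a
# PredictionResponse whose `prediction` field holds the same bytes.
fake_response = SimpleNamespace(prediction=b"8")
print(decode_prediction(fake_response))  # 8
```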

If you have any questions about this page, please feel free to email me. If this helped you, please consider buying me a coffee.

Written on April 20, 2023