しばやん雑記

Azure とメイドさんが大好きなフリーランスのプログラマーのブログ

Microsoft.Extensions.AI 向けに ONNX Runtime Generative AI を使うクラスを書いてみた

前回試してみた Microsoft.Extensions.AI (MEAI) は OpenAI や Azure AI Inference 向けにはライブラリが用意されているので、非常に簡単に Generative AI の機能を利用出来るようになっています。抽象化されたインターフェースにより、OpenAI や Azure AI Inference といった違いを吸収しているのは大きなメリットです。

MEAI は抽象化されたインターフェースが独立した NuGet パッケージで提供されているので、それを利用して実装すると他の API 向けにも実装できます。

以下のように ASP.NET Core などと同様に Abstractions パッケージが提供されています。

公式ブログや Abstractions パッケージの README では固定値を返すサンプル実装が紹介されていますが、今回は自分がこれまで弄っていた ONNX Runtime Generative AI 向けのクライアントを実装してみました。

ONNX Runtime Generative AI を使った SLM の利用については以前書いたので、これをベースに MEAI 向けにクライアントを実装するという流れです。

MEAI は IChatClient というインターフェースを実装すれば良いので、それに合うように ONNX Runtime Generative AI を呼び出すというシンプルな実装です。ONNX Runtime Generative AI を使った場合でもストリーミングは対応できるので実装してみました。

実装した OnnxRuntimeChatClient の全体は以下のようになります。まあまあの雑実装になっていますが、必要な機能は一通り入れているつもりです。

using System.Text;

using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

public class OnnxRuntimeChatClient : IChatClient
{
    public OnnxRuntimeChatClient(string modelPath)
    {
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);
    }

    private readonly Model _model;
    private readonly Tokenizer _tokenizer;

    public void Dispose()
    {
        _tokenizer.Dispose();
        _model.Dispose();
    }

    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = new CancellationToken())
    {
        var sequences = _tokenizer.Encode(BuildPrompt(chatMessages));

        var generatorParams = new GeneratorParams(_model);

        if (options?.MaxOutputTokens is not null)
        {
            generatorParams.SetSearchOption("max_length", options.MaxOutputTokens.Value);
        }

        if (options?.Temperature is not null)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options?.TopP is not null)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        generatorParams.SetInputSequences(sequences);
        generatorParams.TryGraphCaptureWithMaxBatchSize(1);

        var outputSequences = _model.Generate(generatorParams);
        var outputText = _tokenizer.Decode(outputSequences[0]);

        return new(new ChatMessage
        {
            Role = ChatRole.Assistant,
            Text = outputText
        });
    }

    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = new CancellationToken())
    {
        var sequences = _tokenizer.Encode(BuildPrompt(chatMessages));

        var generatorParams = new GeneratorParams(_model);

        if (options?.MaxOutputTokens is not null)
        {
            generatorParams.SetSearchOption("max_length", options.MaxOutputTokens.Value);
        }

        if (options?.Temperature is not null)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options?.TopP is not null)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        generatorParams.SetInputSequences(sequences);
        generatorParams.TryGraphCaptureWithMaxBatchSize(1);

        using var tokenizerStream = _tokenizer.CreateStream();
        using var generator = new Generator(_model, generatorParams);

        while (!generator.IsDone())
        {
            generator.ComputeLogits();
            generator.GenerateNextToken();

            var outputText = tokenizerStream.Decode(generator.GetSequence(0)[^1]);

            yield return new StreamingChatCompletionUpdate
            {
                Role = ChatRole.Assistant,
                Text = outputText
            };
        }
    }

    public TService? GetService<TService>(object? key = null) where TService : class => this as TService;

    public ChatClientMetadata Metadata { get; }

    private string BuildPrompt(IList<ChatMessage> chatMessages)
    {
        var prompt = new StringBuilder();

        foreach (var chatMessage in chatMessages)
        {
            if (chatMessage.Role == ChatRole.System)
            {
                prompt.Append($"<|system|>{chatMessage.Text}<|end|>");
            }
            else if (chatMessage.Role == ChatRole.User)
            {
                prompt.Append($"<|user|>{chatMessage.Text}<|end|>");
            }
            else if (chatMessage.Role == ChatRole.Assistant)
            {
                prompt.Append($"<|assistant|>{chatMessage.Text}<|end|>");
            }
        }

        prompt.Append("<|assistant|>");

        return prompt.ToString();
    }
}

実装としては難しくないので説明はしませんが、基本は CompleteAsyncCompleteStreamingAsync を実装するだけで動きます。ONNX Runtime は非同期メソッドが用意されていないので、この実装でも使っていませんが IAsyncEnumerable<T> の実装方法が分からなかったので async を付けて対応しています。

今回は動作確認のためにリリースされたばかりの Phi-3.5 mini を使ってみます。ONNX バージョンが出たのでダウンロードするだけで簡単に試せるようになりました。

このクラスを利用して推論を行うサンプルコードは以下のようにシンプルなものになります。非同期イテレーターに対応しているので、使い勝手も良いですね。

using Microsoft.Extensions.AI;

var client = new OnnxRuntimeChatClient(@".\Phi-3.5-mini-instruct-onnx\cpu_and_mobile\cpu-int4-awq-block-128-acc-level-4");

await foreach (var update in client.CompleteStreamingAsync("Microsoft について簡単に説明してください"))
{
    Console.Write(update);
}

このコードを実行してみるとストリーミングで Phi-3.5 mini の推論結果が表示されていきます。結果の内容は例によって SLM っぽく適当な感じですが、OpenAI の時と同じコードで SLM を利用できています。

正直なところ公式で ONNX Runtime Generative AI に対応したライブラリがリリースされるのがベストですが、抽象化されたインターフェースが提供されているので独自のライブラリを簡単に実装できるのはかなり大きなメリットだと思います。