diff --git a/docs/modules/ROOT/pages/tutorials/llama3-tool-calling.adoc b/docs/modules/ROOT/pages/tutorials/llama3-tool-calling.adoc index ea8948dd..a84016f9 100644 --- a/docs/modules/ROOT/pages/tutorials/llama3-tool-calling.adoc +++ b/docs/modules/ROOT/pages/tutorials/llama3-tool-calling.adoc @@ -1,30 +1,219 @@ = Llama 3 / 3.1 / 3.2 Tool Calling -:description: How kllama formats tool-calling prompts for the Llama 3 family and how it parses the model's responses, including which of Meta's two response formats to pick. +:description: End-to-end guide for adding Llama 3 tool calling to your own application — register tools, run the agent loop, and observe each round's prompt and assistant response. -This page describes how `kllama` formats tool-calling prompts for the Llama 3 family and how it parses the model's responses. It also explains the two response formats Meta has documented and which one to pick for which model. +This page is a getting-started guide for wiring Llama 3 / 3.1 / 3.2 tool calling into your own application. It covers the full path: load a GGUF, register custom tools, run the agent loop, and observe what the model sees and emits each round. The latter half explains the two prompt/response formats Meta documents and why the default works for Llama 3.2. [TIP] ==== For Llama 3.2 1B / 3B (and any Llama 3.x in 2025) leave the defaults alone. The default format is `Llama3ToolFormat.JSON`, which is what Llama 3.2 was fine-tuned on for custom tools. Switch to `Llama3ToolFormat.FUNCTION_TAG` only if you are running an older Llama 3.1 prompt that expects the tag-wrapped form. ==== -== Quick start +== Quick start: try it from the CLI + +Run the bundled demo against a Llama 3.x GGUF — useful as a sanity check before embedding the same code in your app. [source,bash] ---- -# Build -./gradlew :llm-apps:kllama-cli:shadowJar +./gradlew :llm-apps:kllama-cli:run --quiet \ + --args='-m /path/to/Llama-3.2-1B-Instruct-Q8_0.gguf --demo -s 256 -k 0.7 \ + "What files are in /tmp?"' +---- + +The demo registers two tools (`list_files`, `calculator`), prints the rendered prompt and tool schemas, and runs the agent loop until the model produces a final assistant message. Expect output like: + +---- +[Tools] (2) + - list_files: List files and directories in a local folder. ... + - calculator: Evaluate a mathematical expression. ... +[Prompt → Round 1] (1553 chars) +┌──────────────────────────────────────────────────────────────────────┐ +│ <|begin_of_text|><|start_header_id|>system<|end_header_id|> +│ ...full Llama 3 tool-calling system prompt with both function schemas... +│ <|eot_id|><|start_header_id|>user<|end_header_id|> +│ What files are in /tmp?<|eot_id|>... +└──────────────────────────────────────────────────────────────────────┘ +[Raw Assistant → Round 1] {"name": "list_files", "parameters": {"path": "/tmp"}} +[Tool Call] list_files({"path":"/tmp"}) +[Tool Result] list_files -> [dir] .ICE-unix ... and 4647 more entries +---- + +The agent loop then runs round 2, feeding the tool result back so the model can summarise. + +== Use it from your own Kotlin app + +The pieces you need live in three modules: + +* `llm-runtime-kllama` — `KLlamaJava.loadGGUF(path)` builds the runtime + tokenizer in one call (Java-friendly facade; works fine from Kotlin too). +* `llm-agent` — `ChatSession`, `AgentLoop`, `Tool`, `ToolRegistry`, `AgentListener`. +* `llm-core` — pulled in transitively. + +=== Step 1 — Add the dependency + +[source,kotlin] +---- +dependencies { + implementation("sk.ainet.transformers:llm-runtime-kllama:0.23.2") + implementation("sk.ainet.transformers:llm-agent:0.23.2") +} +---- + +The runtime needs the Java Vector API at launch: + +[source] +---- +--enable-preview --add-modules jdk.incubator.vector +---- + +=== Step 2 — Load the model + +[source,kotlin] +---- +import sk.ainet.apps.kllama.java.KLlamaJava +import java.nio.file.Path + +val session = KLlamaJava.loadGGUF(Path.of("models/Llama-3.2-1B-Instruct-Q8_0.gguf")) +// session.runtime : InferenceRuntime +// session.tokenizer: Tokenizer +// session is AutoCloseable — close it to release the Arena. +---- + +`KLlamaJava.loadGGUF` accepts Llama / Mistral GGUFs and bundles the loader, tokenizer, and runtime construction. For SafeTensors checkpoints use `loadSafeTensors(modelDir)`. + +=== Step 3 — Define your tool + +A tool is a `ToolDefinition` (name + JSON-Schema `parameters`) plus an `execute` function. + +[source,kotlin] +---- +import kotlinx.serialization.json.* +import sk.ainet.apps.kllama.chat.Tool +import sk.ainet.apps.kllama.chat.ToolDefinition + +class WeatherTool : Tool { + override val definition = ToolDefinition( + name = "get_weather", + description = "Get the current weather for a city.", + parameters = buildJsonObject { + put("type", "object") + putJsonObject("properties") { + putJsonObject("city") { + put("type", "string") + put("description", "City name, e.g. 'Bratislava'.") + } + } + putJsonArray("required") { add(JsonPrimitive("city")) } + } + ) + + override fun execute(arguments: JsonObject): String { + val city = arguments["city"]?.jsonPrimitive?.content + ?: return "Error: missing 'city'" + // Real call to your weather backend goes here. + return """{"city":"$city","tempC":22,"condition":"sunny"}""" + } +} +---- + +The schema is the contract the model sees in the system prompt — keep it tight, mark required fields, and make `description` something the model can actually act on. + +=== Step 4 — Wire `ChatSession` + `AgentLoop` -# Run the demo against a Llama 3.x GGUF (auto-detects the family) -java --enable-preview --add-modules jdk.incubator.vector \ - -jar llm-apps/kllama-cli/build/libs/kllama-all.jar \ - -m models/Llama-3.2-1B-Instruct-Q8_0.gguf \ - --demo --template=llama3 \ - -s 256 -k 0.0 \ - "What files are in /tmp?" +[source,kotlin] +---- +import sk.ainet.apps.kllama.chat.* + +val chat = ChatSession( + runtime = session.runtime, + tokenizer = session.tokenizer, + // family="llama" auto-resolves to Llama3ToolCallingSupport with the + // bare-JSON format Llama 3.2 was fine-tuned on. Override only if you + // know you need FUNCTION_TAG (see "The two formats" below). + metadata = ModelMetadata(family = "llama", architecture = "llama"), +) + +val tools = ToolRegistry().apply { + register(WeatherTool()) +} + +val loop = chat.createAgentLoop( + toolRegistry = tools, + maxTokens = 256, + temperature = 0.7f, +) + +val messages = mutableListOf( + ChatMessage( + role = ChatRole.SYSTEM, + content = "You are a helpful assistant with access to tools. " + + "Always call get_weather when asked about weather — never guess." + ), + ChatMessage(role = ChatRole.USER, content = "What's the weather in Bratislava?"), +) + +val finalAnswer = loop.runWithEncoder( + messages = messages, + encode = { chat.encode(it) }, +) +println(finalAnswer) ---- -The demo registers two tools (`list_files`, `calculator`) and runs the agent loop until the model produces a final assistant message. +The loop renders the chat template with your tools embedded, generates until EOS, parses the assistant's reply for a tool call, executes the tool, appends the result to `messages`, and re-runs — up to `AgentConfig.maxToolRounds` (default 5). + +=== Step 5 — Observe what the model sees and emits + +Pass an `AgentListener` to log prompts, raw responses, tool invocations, and results. This is the same listener `ToolCallingDemo` uses for the CLI output above. + +[source,kotlin] +---- +val listener = object : AgentListener { + override fun onToken(token: String) { print(token) } + override fun onAssistantMessage(text: String) { + println("\n[raw assistant] $text") + } + override fun onToolCalls(calls: List) { + for (c in calls) println("[tool call] ${c.name}(${c.arguments})") + } + override fun onToolResult(call: ToolCall, result: String) { + println("[tool result] ${call.name} -> $result") + } + override fun onToolCallValidationFailed(call: ToolCall, reason: String) { + println("[tool call invalid] ${call.name}: $reason") + } + override fun onComplete(finalResponse: String) {} +} + +loop.runWithEncoder(messages, encode = { chat.encode(it) }, listener = listener) +---- + +To see the *prompt* the model receives at the start of each round (not just the response), render the template yourself before calling the loop: + +[source,kotlin] +---- +val rendered = chat.chatTemplate.apply( + messages = messages, + tools = tools.definitions(), + addGenerationPrompt = true, +) +println("[prompt] (${rendered.length} chars)\n$rendered") +---- + +[NOTE] +==== +Llama 3.2 1B sometimes wraps its tool-call JSON in a markdown code fence (```` ``` ````) even though the system prompt asks for bare JSON. `Llama31ToolCallParserStrategy` peels one layer of fencing automatically, so both `{"name":"x", ...}` and ` ```{"name":"x", ...}``` ` parse the same way. +==== + +=== Verify it's working + +You should see exactly this sequence in your listener output for the weather example: + +. `onToken` fires repeatedly as the model generates `{"name": "get_weather", "parameters": {"city": "Bratislava"}}`. +. `onAssistantMessage` fires once with that full text. +. `onToolCalls` fires with `[ToolCall(name="get_weather", arguments={"city":"Bratislava"})]`. +. `onToolResult` fires with your stub's JSON response. +. The loop spins again — the model now sees the tool result in its context and produces a natural-language answer. +. `onComplete` fires with the final user-facing answer. + +If `onToolCalls` *never* fires and `onComplete` returns the raw JSON instead, the model emitted a call but the parser missed it — file an issue with the `[raw assistant]` text. The bare-JSON parser handles `<|python_tag|>` prefixes, code fences, and trailing prose, but novel surface forms slip through. == The two formats @@ -66,6 +255,7 @@ Parser (`Llama31ToolCallParserStrategy`) accepts: * The Meta-documented `"parameters"` key, or `"arguments"` (Hermes-style alias). * A leading `<|python_tag|>` marker (used by Llama 3.2's built-in tools; tolerated here too). +* A surrounding markdown code fence (```` ```json ```` / ```` ``` ````) — Llama 3.2 1B occasionally fences its JSON despite the system-prompt instruction. * Trailing prose after the JSON object (small models often append "I hope that helps!"). === `Llama3ToolFormat.FUNCTION_TAG` (Llama 3.1 legacy) diff --git a/gradle.properties b/gradle.properties index 0e59b640..be20174e 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ GROUP=sk.ainet.transformers -VERSION_NAME=0.23.1 +VERSION_NAME=0.23.2 POM_DESCRIPTION=SKaiNET-transformers diff --git a/llm-agent/api/jvm/llm-agent.api b/llm-agent/api/jvm/llm-agent.api index 5898fab5..edde6a76 100644 --- a/llm-agent/api/jvm/llm-agent.api +++ b/llm-agent/api/jvm/llm-agent.api @@ -45,7 +45,9 @@ public final class sk/ainet/apps/kllama/chat/AgentConfig { public abstract interface class sk/ainet/apps/kllama/chat/AgentListener { public fun onAssistantMessage (Ljava/lang/String;)V public fun onComplete (Ljava/lang/String;)V + public fun onThinking (Ljava/lang/String;)V public fun onToken (Ljava/lang/String;)V + public fun onToolCallValidationFailed (Lsk/ainet/apps/kllama/chat/ToolCall;Ljava/lang/String;)V public fun onToolCalls (Ljava/util/List;)V public fun onToolResult (Lsk/ainet/apps/kllama/chat/ToolCall;Ljava/lang/String;)V } @@ -53,7 +55,9 @@ public abstract interface class sk/ainet/apps/kllama/chat/AgentListener { public final class sk/ainet/apps/kllama/chat/AgentListener$DefaultImpls { public static fun onAssistantMessage (Lsk/ainet/apps/kllama/chat/AgentListener;Ljava/lang/String;)V public static fun onComplete (Lsk/ainet/apps/kllama/chat/AgentListener;Ljava/lang/String;)V + public static fun onThinking (Lsk/ainet/apps/kllama/chat/AgentListener;Ljava/lang/String;)V public static fun onToken (Lsk/ainet/apps/kllama/chat/AgentListener;Ljava/lang/String;)V + public static fun onToolCallValidationFailed (Lsk/ainet/apps/kllama/chat/AgentListener;Lsk/ainet/apps/kllama/chat/ToolCall;Ljava/lang/String;)V public static fun onToolCalls (Lsk/ainet/apps/kllama/chat/AgentListener;Ljava/util/List;)V public static fun onToolResult (Lsk/ainet/apps/kllama/chat/AgentListener;Lsk/ainet/apps/kllama/chat/ToolCall;Ljava/lang/String;)V } @@ -67,11 +71,57 @@ public final class sk/ainet/apps/kllama/chat/AgentLoop { public static synthetic fun runWithEncoder$default (Lsk/ainet/apps/kllama/chat/AgentLoop;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lsk/ainet/apps/kllama/chat/AgentListener;ILjava/lang/Object;)Ljava/lang/String; } +public final class sk/ainet/apps/kllama/chat/ApertusChatTemplate : sk/ainet/apps/kllama/chat/ChatTemplate { + public static final field ASSISTANT_END Ljava/lang/String; + public static final field ASSISTANT_START Ljava/lang/String; + public static final field Companion Lsk/ainet/apps/kllama/chat/ApertusChatTemplate$Companion; + public static final field DEVELOPER_END Ljava/lang/String; + public static final field DEVELOPER_START Ljava/lang/String; + public static final field INNER_PREFIX Ljava/lang/String; + public static final field INNER_SUFFIX Ljava/lang/String; + public static final field SYSTEM_END Ljava/lang/String; + public static final field SYSTEM_START Ljava/lang/String; + public static final field TOOLS_PREFIX Ljava/lang/String; + public static final field TOOLS_SUFFIX Ljava/lang/String; + public static final field USER_END Ljava/lang/String; + public static final field USER_START Ljava/lang/String; + public fun ()V + public fun (Ljava/lang/String;Ljava/lang/String;Z)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; + public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; + public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; + public final fun withCurrentDate (Ljava/lang/String;)Lsk/ainet/apps/kllama/chat/ApertusChatTemplate; +} + +public final class sk/ainet/apps/kllama/chat/ApertusChatTemplate$Companion { +} + +public final class sk/ainet/apps/kllama/chat/ApertusToolCallParserStrategy : sk/ainet/apps/kllama/chat/ToolCallParserStrategy { + public static final field INSTANCE Lsk/ainet/apps/kllama/chat/ApertusToolCallParserStrategy; + public fun containsToolCall (Ljava/lang/String;)Z + public fun getFormatName ()Ljava/lang/String; + public fun parse (Ljava/lang/String;)Ljava/util/List; +} + +public final class sk/ainet/apps/kllama/chat/ApertusToolCallingSupport : sk/ainet/apps/kllama/chat/ToolCallingSupport { + public fun ()V + public fun createChatTemplate ()Lsk/ainet/apps/kllama/chat/ChatTemplate; + public fun getFamily ()Ljava/lang/String; + public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun supports (Lsk/ainet/apps/kllama/chat/ModelMetadata;)Z + public fun toolCallingMode (Lsk/ainet/apps/kllama/chat/ModelMetadata;)Lsk/ainet/apps/kllama/chat/ToolCallingMode; +} + public final class sk/ainet/apps/kllama/chat/ChatMLTemplate : sk/ainet/apps/kllama/chat/ChatTemplate { public fun ()V public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; } public final class sk/ainet/apps/kllama/chat/ChatMLToolCallingSupport : sk/ainet/apps/kllama/chat/ToolCallingSupport { @@ -112,24 +162,83 @@ public final class sk/ainet/apps/kllama/chat/ChatRole : java/lang/Enum { public static fun values ()[Lsk/ainet/apps/kllama/chat/ChatRole; } +public final class sk/ainet/apps/kllama/chat/ChatSession { + public static final field Companion Lsk/ainet/apps/kllama/chat/ChatSession$Companion; + public static final field DEFAULT_SYSTEM_PROMPT Ljava/lang/String; + public fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ModelMetadata;Ljava/lang/String;Ljava/lang/String;)V + public synthetic fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ModelMetadata;Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun createAgentLoop (Lsk/ainet/apps/kllama/chat/ToolRegistry;IF)Lsk/ainet/apps/kllama/chat/AgentLoop; + public static synthetic fun createAgentLoop$default (Lsk/ainet/apps/kllama/chat/ChatSession;Lsk/ainet/apps/kllama/chat/ToolRegistry;IFILjava/lang/Object;)Lsk/ainet/apps/kllama/chat/AgentLoop; + public final fun decode (I)Ljava/lang/String; + public final fun encode (Ljava/lang/String;)[I + public final fun getChatTemplate ()Lsk/ainet/apps/kllama/chat/ChatTemplate; + public final fun getDefaultSystemPrompt ()Ljava/lang/String; + public final fun getMetadata ()Lsk/ainet/apps/kllama/chat/ModelMetadata; + public final fun getProviderFamily ()Ljava/lang/String; + public final fun getRuntime ()Lsk/ainet/apps/llm/InferenceRuntime; + public final fun getTokenizer ()Lsk/ainet/apps/llm/Tokenizer; + public final fun runSingleTurn (Ljava/lang/String;Ljava/util/List;IFLjava/lang/String;Lsk/ainet/apps/kllama/chat/AgentListener;)Ljava/lang/String; + public static synthetic fun runSingleTurn$default (Lsk/ainet/apps/kllama/chat/ChatSession;Ljava/lang/String;Ljava/util/List;IFLjava/lang/String;Lsk/ainet/apps/kllama/chat/AgentListener;ILjava/lang/Object;)Ljava/lang/String; +} + +public final class sk/ainet/apps/kllama/chat/ChatSession$Companion { +} + public abstract interface class sk/ainet/apps/kllama/chat/ChatTemplate { public abstract fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; public static synthetic fun apply$default (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/util/List;Ljava/util/List;ZILjava/lang/Object;)Ljava/lang/String; public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; } public final class sk/ainet/apps/kllama/chat/ChatTemplate$DefaultImpls { public static synthetic fun apply$default (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/util/List;Ljava/util/List;ZILjava/lang/Object;)Ljava/lang/String; public static fun containsToolCall (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/lang/String;)Z + public static fun parseThinkingBlocks (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/lang/String;)Ljava/util/List; public static fun parseToolCalls (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/lang/String;)Ljava/util/List; + public static fun stripThinking (Lsk/ainet/apps/kllama/chat/ChatTemplate;Ljava/lang/String;)Ljava/lang/String; +} + +public final class sk/ainet/apps/kllama/chat/Gemma4ChatTemplate : sk/ainet/apps/kllama/chat/ChatTemplate { + public static final field BOS Ljava/lang/String; + public static final field CHANNEL_CLOSE Ljava/lang/String; + public static final field CHANNEL_OPEN Ljava/lang/String; + public static final field QUOTE Ljava/lang/String; + public static final field TOOL_CALL_CLOSE Ljava/lang/String; + public static final field TOOL_CALL_OPEN Ljava/lang/String; + public static final field TOOL_CLOSE Ljava/lang/String; + public static final field TOOL_OPEN Ljava/lang/String; + public static final field TOOL_RESPONSE_CLOSE Ljava/lang/String; + public static final field TOOL_RESPONSE_OPEN Ljava/lang/String; + public static final field TURN_CLOSE Ljava/lang/String; + public fun ()V + public fun (Z)V + public synthetic fun (ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; + public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; + public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; +} + +public final class sk/ainet/apps/kllama/chat/Gemma4ToolCallingSupport : sk/ainet/apps/kllama/chat/ToolCallingSupport { + public fun ()V + public fun createChatTemplate ()Lsk/ainet/apps/kllama/chat/ChatTemplate; + public fun getFamily ()Ljava/lang/String; + public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun supports (Lsk/ainet/apps/kllama/chat/ModelMetadata;)Z + public fun toolCallingMode (Lsk/ainet/apps/kllama/chat/ModelMetadata;)Lsk/ainet/apps/kllama/chat/ToolCallingMode; } public final class sk/ainet/apps/kllama/chat/GemmaChatTemplate : sk/ainet/apps/kllama/chat/ChatTemplate { public fun ()V public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; } public final class sk/ainet/apps/kllama/chat/GemmaToolCallParserStrategy : sk/ainet/apps/kllama/chat/ToolCallParserStrategy { @@ -159,13 +268,19 @@ public final class sk/ainet/apps/kllama/chat/GenericToolCallingSupport : sk/aine public final class sk/ainet/apps/kllama/chat/Llama3ChatTemplate : sk/ainet/apps/kllama/chat/ChatTemplate { public fun ()V + public fun (Lsk/ainet/apps/kllama/chat/Llama3ToolFormat;)V + public synthetic fun (Lsk/ainet/apps/kllama/chat/Llama3ToolFormat;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; } public final class sk/ainet/apps/kllama/chat/Llama3ToolCallingSupport : sk/ainet/apps/kllama/chat/ToolCallingSupport { public fun ()V + public fun (Lsk/ainet/apps/kllama/chat/Llama3ToolFormat;)V + public synthetic fun (Lsk/ainet/apps/kllama/chat/Llama3ToolFormat;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun createChatTemplate ()Lsk/ainet/apps/kllama/chat/ChatTemplate; public fun getFamily ()Ljava/lang/String; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; @@ -173,6 +288,14 @@ public final class sk/ainet/apps/kllama/chat/Llama3ToolCallingSupport : sk/ainet public fun toolCallingMode (Lsk/ainet/apps/kllama/chat/ModelMetadata;)Lsk/ainet/apps/kllama/chat/ToolCallingMode; } +public final class sk/ainet/apps/kllama/chat/Llama3ToolFormat : java/lang/Enum { + public static final field FUNCTION_TAG Lsk/ainet/apps/kllama/chat/Llama3ToolFormat; + public static final field JSON Lsk/ainet/apps/kllama/chat/Llama3ToolFormat; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/apps/kllama/chat/Llama3ToolFormat; + public static fun values ()[Lsk/ainet/apps/kllama/chat/Llama3ToolFormat; +} + public final class sk/ainet/apps/kllama/chat/ModelMetadata { public fun ()V public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;)V @@ -198,7 +321,9 @@ public final class sk/ainet/apps/kllama/chat/QwenChatTemplate : sk/ainet/apps/kl public fun ()V public fun apply (Ljava/util/List;Ljava/util/List;Z)Ljava/lang/String; public fun containsToolCall (Ljava/lang/String;)Z + public fun parseThinkingBlocks (Ljava/lang/String;)Ljava/util/List; public fun parseToolCalls (Ljava/lang/String;)Ljava/util/List; + public fun stripThinking (Ljava/lang/String;)Ljava/lang/String; } public final class sk/ainet/apps/kllama/chat/QwenToolCallingSupport : sk/ainet/apps/kllama/chat/ToolCallingSupport { @@ -260,6 +385,29 @@ public abstract interface class sk/ainet/apps/kllama/chat/ToolCallParserStrategy public abstract fun parse (Ljava/lang/String;)Ljava/util/List; } +public abstract class sk/ainet/apps/kllama/chat/ToolCallValidationResult { +} + +public final class sk/ainet/apps/kllama/chat/ToolCallValidationResult$Invalid : sk/ainet/apps/kllama/chat/ToolCallValidationResult { + public fun (Ljava/lang/String;)V + public final fun component1 ()Ljava/lang/String; + public final fun copy (Ljava/lang/String;)Lsk/ainet/apps/kllama/chat/ToolCallValidationResult$Invalid; + public static synthetic fun copy$default (Lsk/ainet/apps/kllama/chat/ToolCallValidationResult$Invalid;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/apps/kllama/chat/ToolCallValidationResult$Invalid; + public fun equals (Ljava/lang/Object;)Z + public final fun getReason ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/apps/kllama/chat/ToolCallValidationResult$Valid : sk/ainet/apps/kllama/chat/ToolCallValidationResult { + public static final field INSTANCE Lsk/ainet/apps/kllama/chat/ToolCallValidationResult$Valid; +} + +public final class sk/ainet/apps/kllama/chat/ToolCallValidator { + public static final field INSTANCE Lsk/ainet/apps/kllama/chat/ToolCallValidator; + public final fun validate (Lsk/ainet/apps/kllama/chat/ToolCall;Lsk/ainet/apps/kllama/chat/ToolDefinition;)Lsk/ainet/apps/kllama/chat/ToolCallValidationResult; +} + public final class sk/ainet/apps/kllama/chat/ToolCallingMode : java/lang/Enum { public static final field GENERIC Lsk/ainet/apps/kllama/chat/ToolCallingMode; public static final field NATIVE Lsk/ainet/apps/kllama/chat/ToolCallingMode; diff --git a/llm-agent/src/commonMain/kotlin/sk/ainet/apps/kllama/chat/ToolCallParser.kt b/llm-agent/src/commonMain/kotlin/sk/ainet/apps/kllama/chat/ToolCallParser.kt index 3dc387aa..fb41c5a6 100644 --- a/llm-agent/src/commonMain/kotlin/sk/ainet/apps/kllama/chat/ToolCallParser.kt +++ b/llm-agent/src/commonMain/kotlin/sk/ainet/apps/kllama/chat/ToolCallParser.kt @@ -132,7 +132,7 @@ internal class Llama31ToolCallParserStrategy : ToolCallParserStrategy { override val formatName: String = "llama3-json" override fun parse(text: String): List { - val candidate = stripPythonTag(text).trim() + val candidate = stripCodeFence(stripPythonTag(text)).trim() if (!candidate.startsWith("{") || !candidate.contains("\"name\"")) return emptyList() val firstObject = extractFirstJsonObject(candidate) ?: return emptyList() val parsed = ToolCallParser.parseJsonToolCall(firstObject) ?: return emptyList() @@ -140,7 +140,7 @@ internal class Llama31ToolCallParserStrategy : ToolCallParserStrategy { } override fun containsToolCall(text: String): Boolean { - val candidate = stripPythonTag(text).trim() + val candidate = stripCodeFence(stripPythonTag(text)).trim() return candidate.startsWith("{") && candidate.contains("\"name\"") && (candidate.contains("\"arguments\"") || candidate.contains("\"parameters\"")) @@ -151,6 +151,19 @@ internal class Llama31ToolCallParserStrategy : ToolCallParserStrategy { return if (trimmed.startsWith("<|python_tag|>")) trimmed.removePrefix("<|python_tag|>") else text } + // Llama-3 instruct models sometimes wrap their tool-call JSON in a markdown + // code fence (```...``` or ```json...```), even though the prompt asks for + // bare JSON. Peel one layer of fencing so the parser still sees the object. + private fun stripCodeFence(text: String): String { + val trimmed = text.trim() + if (!trimmed.startsWith("```")) return text + val firstNewline = trimmed.indexOf('\n') + if (firstNewline == -1) return text + val withoutOpening = trimmed.substring(firstNewline + 1) + val closingIdx = withoutOpening.lastIndexOf("```") + return if (closingIdx >= 0) withoutOpening.substring(0, closingIdx) else withoutOpening + } + /** Find the first `{...}` block at the start of [text], respecting brace nesting and string literals. */ private fun extractFirstJsonObject(text: String): String? { if (!text.startsWith("{")) return null diff --git a/llm-agent/src/commonTest/kotlin/sk/ainet/apps/kllama/chat/ToolCallParserTest.kt b/llm-agent/src/commonTest/kotlin/sk/ainet/apps/kllama/chat/ToolCallParserTest.kt index 2e4551d9..2e713414 100644 --- a/llm-agent/src/commonTest/kotlin/sk/ainet/apps/kllama/chat/ToolCallParserTest.kt +++ b/llm-agent/src/commonTest/kotlin/sk/ainet/apps/kllama/chat/ToolCallParserTest.kt @@ -191,4 +191,39 @@ class ToolCallParserTest { fun containsToolCallDetectsParametersKey() { assertTrue(ToolCallParser.containsToolCall("""{"name": "x", "parameters": {}}""")) } + + // --- Llama 3.2 sometimes wraps its JSON in a markdown code fence --- + + @Test + fun parseLlama32JsonInsideTripleBacktickFence() { + val text = """ + ``` + {"name": "list_files", "parameters": {"path": "/tmp"}} + ``` + """.trimIndent() + val calls = ToolCallParser.parse(text) + assertEquals(1, calls.size) + assertEquals("list_files", calls[0].name) + assertEquals("/tmp", calls[0].arguments["path"]?.toString()?.trim('"')) + } + + @Test + fun parseLlama32JsonInsideJsonTaggedFence() { + val text = """ + ```json + {"name": "calc", "parameters": {"x": 1}} + ``` + """.trimIndent() + val calls = ToolCallParser.parse(text) + assertEquals(1, calls.size) + assertEquals("calc", calls[0].name) + } + + @Test + fun containsToolCallDetectsFencedJson() { + val text = """``` +{"name": "x", "parameters": {}} +```""" + assertTrue(ToolCallParser.containsToolCall(text)) + } } diff --git a/llm-apps/kbert-cli/build.gradle.kts b/llm-apps/kbert-cli/build.gradle.kts index 8eac6f03..0a59dc43 100644 --- a/llm-apps/kbert-cli/build.gradle.kts +++ b/llm-apps/kbert-cli/build.gradle.kts @@ -1,6 +1,11 @@ plugins { kotlin("jvm") alias(libs.plugins.shadow) + application +} + +application { + mainClass.set("sk.ainet.apps.bert.cli.MainKt") } dependencies { diff --git a/llm-core/api/jvm/llm-core.api b/llm-core/api/jvm/llm-core.api index 0e77d91f..947716ab 100644 --- a/llm-core/api/jvm/llm-core.api +++ b/llm-core/api/jvm/llm-core.api @@ -5,6 +5,7 @@ public abstract class sk/ainet/apps/llm/DecoderRuntime : sk/ainet/apps/llm/Infer protected abstract fun embedToken (I)Lsk/ainet/lang/tensor/Tensor; protected final fun expectFloatBuffer (Lsk/ainet/lang/tensor/Tensor;)[F public fun forward (I)Lsk/ainet/lang/tensor/Tensor; + public fun forwardBatched ([I)Lsk/ainet/lang/tensor/Tensor; public fun generate ([IIFLkotlin/jvm/functions/Function1;)V protected abstract fun getBosToken ()I protected abstract fun getDim ()I @@ -46,8 +47,8 @@ public final class sk/ainet/apps/llm/GGUFModelInfo { } public final class sk/ainet/apps/llm/GenerateExtensionsKt { - public static final fun generate (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun generate$default (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun generate (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lsk/ainet/apps/llm/PrefillStrategy;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun generate$default (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lsk/ainet/apps/llm/PrefillStrategy;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun sampleFromTensor (Lsk/ainet/lang/tensor/Tensor;FLkotlin/random/Random;)I public static synthetic fun sampleFromTensor$default (Lsk/ainet/lang/tensor/Tensor;FLkotlin/random/Random;ILjava/lang/Object;)I } @@ -75,9 +76,14 @@ public final class sk/ainet/apps/llm/HybridTransformerBlock : sk/ainet/lang/nn/M public abstract interface class sk/ainet/apps/llm/InferenceRuntime { public abstract fun forward (I)Lsk/ainet/lang/tensor/Tensor; + public fun forwardBatched ([I)Lsk/ainet/lang/tensor/Tensor; public abstract fun reset ()V } +public final class sk/ainet/apps/llm/InferenceRuntime$DefaultImpls { + public static fun forwardBatched (Lsk/ainet/apps/llm/InferenceRuntime;[I)Lsk/ainet/lang/tensor/Tensor; +} + public abstract interface class sk/ainet/apps/llm/KvCache { public abstract fun getKey (IIII)F public abstract fun getKvDim ()I @@ -133,6 +139,7 @@ public final class sk/ainet/apps/llm/OptimizedLLMRuntime : sk/ainet/apps/llm/Inf public final fun compileWith (ILsk/ainet/compile/opt/GraphOptimizationPipeline;)Ljava/util/List; public static synthetic fun compileWith$default (Lsk/ainet/apps/llm/OptimizedLLMRuntime;ILsk/ainet/compile/opt/GraphOptimizationPipeline;ILjava/lang/Object;)Ljava/util/List; public fun forward (I)Lsk/ainet/lang/tensor/Tensor; + public fun forwardBatched ([I)Lsk/ainet/lang/tensor/Tensor; public final fun generate ([IIFLkotlin/jvm/functions/Function1;)V public final fun getBos ()I public final fun getDim ()I @@ -151,9 +158,42 @@ public final class sk/ainet/apps/llm/OptimizedLLMRuntimeInternalKt { public static final fun getLLMOptimizationPipeline ()Lsk/ainet/compile/opt/GraphOptimizationPipeline; } +public abstract interface class sk/ainet/apps/llm/PrefillStrategy { +} + +public final class sk/ainet/apps/llm/PrefillStrategy$Autoregressive : sk/ainet/apps/llm/PrefillStrategy { + public static final field INSTANCE Lsk/ainet/apps/llm/PrefillStrategy$Autoregressive; +} + +public final class sk/ainet/apps/llm/PrefillStrategy$Batched : sk/ainet/apps/llm/PrefillStrategy { + public fun ()V + public fun (I)V + public synthetic fun (IILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()I + public final fun copy (I)Lsk/ainet/apps/llm/PrefillStrategy$Batched; + public static synthetic fun copy$default (Lsk/ainet/apps/llm/PrefillStrategy$Batched;IILjava/lang/Object;)Lsk/ainet/apps/llm/PrefillStrategy$Batched; + public fun equals (Ljava/lang/Object;)Z + public final fun getMaxBatch ()I + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/apps/llm/RopeType : java/lang/Enum { + public static final field Companion Lsk/ainet/apps/llm/RopeType$Companion; + public static final field HALF_SPLIT Lsk/ainet/apps/llm/RopeType; + public static final field INTERLEAVED Lsk/ainet/apps/llm/RopeType; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/apps/llm/RopeType; + public static fun values ()[Lsk/ainet/apps/llm/RopeType; +} + +public final class sk/ainet/apps/llm/RopeType$Companion { + public final fun forArchitecture (Ljava/lang/String;)Lsk/ainet/apps/llm/RopeType; +} + public final class sk/ainet/apps/llm/RopeUtilsKt { - public static final fun applyRopeRotation ([FIIIIF[F[FILjava/lang/Float;)V - public static synthetic fun applyRopeRotation$default ([FIIIIF[F[FILjava/lang/Float;ILjava/lang/Object;)V + public static final fun applyRopeRotation ([FIIIIF[F[FILjava/lang/Float;Lsk/ainet/apps/llm/RopeType;)V + public static synthetic fun applyRopeRotation$default ([FIIIIF[F[FILjava/lang/Float;Lsk/ainet/apps/llm/RopeType;ILjava/lang/Object;)V public static final fun ropeCos (IIIF)F public static synthetic fun ropeCos$default (IIIFILjava/lang/Object;)F public static final fun ropeFrequency (IIIF)F @@ -233,6 +273,11 @@ public final class sk/ainet/apps/llm/compile/LLMFusionPass : sk/ainet/compile/op public fun getName ()Ljava/lang/String; } +public final class sk/ainet/apps/llm/diag/PlatformDiag_jvmKt { + public static final fun dumpStats (Ljava/lang/String;Lsk/ainet/lang/tensor/Tensor;)V + public static final fun envFlag (Ljava/lang/String;)Z +} + public final class sk/ainet/apps/llm/graph/LLMFusedOpHandlers { public static final field INSTANCE Lsk/ainet/apps/llm/graph/LLMFusedOpHandlers; public final fun registerAll ()V @@ -248,7 +293,7 @@ public final class sk/ainet/apps/llm/tokenizer/BPEStrategy : sk/ainet/apps/llm/T public final class sk/ainet/apps/llm/tokenizer/GGUFTokenizer : sk/ainet/apps/llm/Tokenizer { public static final field Companion Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion; - public synthetic fun (Ljava/util/List;[FIIILsk/ainet/apps/llm/TokenizerStrategy;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Ljava/util/List;[FIIILsk/ainet/apps/llm/TokenizerStrategy;[IZLkotlin/jvm/internal/DefaultConstructorMarker;)V public fun decode (I)Ljava/lang/String; public fun decode ([I)Ljava/lang/String; public fun encode (Ljava/lang/String;)[I @@ -267,6 +312,8 @@ public final class sk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion { public static synthetic fun fromRandomAccessSource$default (Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion;Lsk/ainet/io/RandomAccessSource;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; public final fun fromSource (Lkotlinx/io/Source;Z)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; public static synthetic fun fromSource$default (Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion;Lkotlinx/io/Source;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; + public final fun fromStreamingFields (Ljava/util/Map;Z)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; + public static synthetic fun fromStreamingFields$default (Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion;Ljava/util/Map;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; public final fun fromTokenizerJson (Ljava/lang/String;Z)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; public static synthetic fun fromTokenizerJson$default (Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer$Companion;Ljava/lang/String;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; } @@ -285,6 +332,24 @@ public final class sk/ainet/apps/llm/tokenizer/HuggingFaceBPETokenizerKt { public static synthetic fun createHuggingFaceBPETokenizerFromJson$default (Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/HuggingFaceBPETokenizer; } +public final class sk/ainet/apps/llm/tokenizer/SentencePieceSpecialTokens : sk/ainet/apps/llm/Tokenizer { + public static final field Companion Lsk/ainet/apps/llm/tokenizer/SentencePieceSpecialTokens$Companion; + public fun (Lsk/ainet/io/tokenizer/SentencePieceTokenizer;Ljava/util/Map;Ljava/lang/Integer;Ljava/lang/Integer;)V + public synthetic fun (Lsk/ainet/io/tokenizer/SentencePieceTokenizer;Ljava/util/Map;Ljava/lang/Integer;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun decode (I)Ljava/lang/String; + public fun decode ([I)Ljava/lang/String; + public fun encode (Ljava/lang/String;)[I + public fun getBosTokenId ()I + public fun getEosTokenId ()I + public fun getVocabSize ()I +} + +public final class sk/ainet/apps/llm/tokenizer/SentencePieceSpecialTokens$Companion { + public final fun fromGgufFields (Ljava/util/Map;)Lsk/ainet/apps/llm/Tokenizer; + public final fun fromTokenizerJson (Ljava/lang/String;Ljava/lang/String;)Lsk/ainet/apps/llm/Tokenizer; + public static synthetic fun fromTokenizerJson$default (Lsk/ainet/apps/llm/tokenizer/SentencePieceSpecialTokens$Companion;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/apps/llm/Tokenizer; +} + public final class sk/ainet/apps/llm/tokenizer/SentencePieceStrategy : sk/ainet/apps/llm/TokenizerStrategy { public static final field INSTANCE Lsk/ainet/apps/llm/tokenizer/SentencePieceStrategy; public fun getSpaceMarker ()Ljava/lang/String; @@ -295,12 +360,10 @@ public final class sk/ainet/apps/llm/tokenizer/SentencePieceStrategy : sk/ainet/ public final class sk/ainet/apps/llm/tokenizer/TokenizerFactory { public static final field INSTANCE Lsk/ainet/apps/llm/tokenizer/TokenizerFactory; - public final fun fromGGUF (Lsk/ainet/io/RandomAccessSource;Z)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; - public static synthetic fun fromGGUF$default (Lsk/ainet/apps/llm/tokenizer/TokenizerFactory;Lsk/ainet/io/RandomAccessSource;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; - public final fun fromHuggingFace (Ljava/lang/String;Ljava/lang/String;)Lsk/ainet/apps/llm/Tokenizer; - public static synthetic fun fromHuggingFace$default (Lsk/ainet/apps/llm/tokenizer/TokenizerFactory;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/apps/llm/Tokenizer; - public final fun fromTokenizerJson (Ljava/lang/String;Z)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; - public static synthetic fun fromTokenizerJson$default (Lsk/ainet/apps/llm/tokenizer/TokenizerFactory;Ljava/lang/String;ZILjava/lang/Object;)Lsk/ainet/apps/llm/tokenizer/GGUFTokenizer; + public final fun fromGgufFields (Ljava/util/Map;)Lsk/ainet/apps/llm/Tokenizer; + public final fun fromGgufSource (Lsk/ainet/io/RandomAccessSource;)Lsk/ainet/apps/llm/Tokenizer; + public final fun fromTokenizerJsonString (Ljava/lang/String;Ljava/lang/String;)Lsk/ainet/apps/llm/Tokenizer; + public static synthetic fun fromTokenizerJsonString$default (Lsk/ainet/apps/llm/tokenizer/TokenizerFactory;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/apps/llm/Tokenizer; } public final class sk/ainet/apps/llm/tokenizer/UnknownStrategy : sk/ainet/apps/llm/TokenizerStrategy { @@ -475,8 +538,8 @@ public final class sk/ainet/lang/nn/dsl/ATTENTION$DefaultImpls { } public final class sk/ainet/lang/nn/dsl/AttentionImpl : sk/ainet/lang/nn/dsl/ATTENTION { - public fun (Lsk/ainet/context/ExecutionContext;IIIZZZLjava/lang/String;Ljava/lang/Integer;)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;IIIZZZLjava/lang/String;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/context/ExecutionContext;IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Ljava/lang/Integer;)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun create ()Lsk/ainet/lang/nn/transformer/MultiHeadAttention; public fun getExecutionContext ()Lsk/ainet/context/ExecutionContext; public fun kvCache (III)V @@ -493,16 +556,16 @@ public final class sk/ainet/lang/nn/dsl/TransformerDslKt { public static final fun geGluFFN (Lsk/ainet/lang/nn/dsl/StageImpl;IILjava/lang/String;)V public static synthetic fun geGluFFN$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IILjava/lang/String;ILjava/lang/Object;)V public static synthetic fun geGluFFN$default (Lsk/ainet/lang/nn/dsl/StageImpl;IILjava/lang/String;ILjava/lang/Object;)V - public static final fun multiHeadAttention (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IIIZZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;)V - public static final fun multiHeadAttention (Lsk/ainet/lang/nn/dsl/StageImpl;IIIZZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun multiHeadAttention$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IIIZZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static synthetic fun multiHeadAttention$default (Lsk/ainet/lang/nn/dsl/StageImpl;IIIZZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun multiHeadAttention (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IIIZZFZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;)V + public static final fun multiHeadAttention (Lsk/ainet/lang/nn/dsl/StageImpl;IIIZZZFLjava/lang/Float;ZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun multiHeadAttention$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IIIZZFZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun multiHeadAttention$default (Lsk/ainet/lang/nn/dsl/StageImpl;IIIZZZFLjava/lang/Float;ZZLjava/lang/String;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun residual (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;)V public static final fun residual (Lsk/ainet/lang/nn/dsl/StageImpl;)V - public static final fun rmsNorm (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IFLjava/lang/String;)V - public static final fun rmsNorm (Lsk/ainet/lang/nn/dsl/StageImpl;IFLjava/lang/String;)V - public static synthetic fun rmsNorm$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IFLjava/lang/String;ILjava/lang/Object;)V - public static synthetic fun rmsNorm$default (Lsk/ainet/lang/nn/dsl/StageImpl;IFLjava/lang/String;ILjava/lang/Object;)V + public static final fun rmsNorm (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IFLjava/lang/String;Z)V + public static final fun rmsNorm (Lsk/ainet/lang/nn/dsl/StageImpl;IFLjava/lang/String;Z)V + public static synthetic fun rmsNorm$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IFLjava/lang/String;ZILjava/lang/Object;)V + public static synthetic fun rmsNorm$default (Lsk/ainet/lang/nn/dsl/StageImpl;IFLjava/lang/String;ZILjava/lang/Object;)V public static final fun swiGluFFN (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IILjava/lang/String;)V public static final fun swiGluFFN (Lsk/ainet/lang/nn/dsl/StageImpl;IILjava/lang/String;)V public static synthetic fun swiGluFFN$default (Lsk/ainet/lang/nn/dsl/NeuralNetworkDslImpl;IILjava/lang/String;ILjava/lang/Object;)V @@ -513,6 +576,21 @@ public final class sk/ainet/lang/nn/dsl/TransformerDslKt { public static synthetic fun xielu$default (Lsk/ainet/lang/nn/dsl/StageImpl;Ljava/lang/String;ILjava/lang/Object;)V } +public abstract interface class sk/ainet/lang/nn/dsl/decoder/DecoderModelMetadata { + public abstract fun getBlockCount ()I + public abstract fun getBosTokenId ()I + public abstract fun getContextLength ()I + public abstract fun getEmbeddingLength ()I + public abstract fun getEosTokenId ()I + public abstract fun getFeedForwardLength ()I + public abstract fun getHeadCount ()I + public abstract fun getKvHeadCount ()I + public abstract fun getRmsNormEps ()F + public abstract fun getRopeDimensionCount ()Ljava/lang/Integer; + public abstract fun getRopeFreqBase ()F + public abstract fun getVocabSize ()I +} + public final class sk/ainet/lang/nn/layers/Embedding : sk/ainet/lang/nn/DualModule, sk/ainet/lang/nn/topology/ModuleParameters { public static final field Companion Lsk/ainet/lang/nn/layers/Embedding$Companion; public fun (IILsk/ainet/lang/tensor/Tensor;Ljava/lang/Integer;Ljava/lang/String;)V @@ -570,8 +648,8 @@ public abstract interface class sk/ainet/lang/nn/normalization/FusedRmsNormOps { } public final class sk/ainet/lang/nn/normalization/RMSNormalization : sk/ainet/lang/nn/Module, sk/ainet/lang/nn/topology/ModuleParameters { - public fun ([IDLjava/lang/String;Lsk/ainet/lang/tensor/Tensor;)V - public synthetic fun ([IDLjava/lang/String;Lsk/ainet/lang/tensor/Tensor;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun ([IDLjava/lang/String;Lsk/ainet/lang/tensor/Tensor;Z)V + public synthetic fun ([IDLjava/lang/String;Lsk/ainet/lang/tensor/Tensor;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun forward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/context/ExecutionContext;)Lsk/ainet/lang/tensor/Tensor; public fun getModules ()Ljava/util/List; public fun getName ()Ljava/lang/String; @@ -624,8 +702,9 @@ public final class sk/ainet/lang/nn/transformer/LinearProjectionKt { } public final class sk/ainet/lang/nn/transformer/MultiHeadAttention : sk/ainet/lang/nn/Module, sk/ainet/lang/nn/topology/ModuleParameters { - public fun (IIIZZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;)V - public synthetic fun (IIIZZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;)V + public synthetic fun (IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getAttentionScale ()Ljava/lang/Float; public final fun getBias ()Z public final fun getCausal ()Z public final fun getDim ()I @@ -641,12 +720,30 @@ public final class sk/ainet/lang/nn/transformer/MultiHeadAttention : sk/ainet/la public final fun getQDim ()I public final fun getQNorm ()Lsk/ainet/lang/nn/normalization/RMSNormalization; public final fun getQkNorm ()Z + public final fun getQkNormEps ()D + public final fun getQkNormUnitOffset ()Z public final fun getRope ()Lsk/ainet/lang/nn/transformer/RoPE; public final fun getSlidingWindow ()Ljava/lang/Integer; + public final fun getVNormNoScale ()Z public final fun setKvCache (Lsk/ainet/lang/nn/transformer/KVCache;)V public final fun setRope (Lsk/ainet/lang/nn/transformer/RoPE;)V } +public final class sk/ainet/lang/nn/transformer/MultiHeadAttentionDiag { + public static final field INSTANCE Lsk/ainet/lang/nn/transformer/MultiHeadAttentionDiag; + public final fun getShouldDumpThisCall ()Z + public final fun setShouldDumpThisCall (Z)V +} + +public final class sk/ainet/lang/nn/transformer/OwnerReadOnlyKVCache : sk/ainet/lang/nn/transformer/KVCache { + public fun (Lsk/ainet/lang/nn/transformer/PositionalKVCache;Ljava/lang/String;)V + public synthetic fun (Lsk/ainet/lang/nn/transformer/PositionalKVCache;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getDelegate ()Lsk/ainet/lang/nn/transformer/PositionalKVCache; + public fun getPosition ()I + public fun reset ()V + public fun update (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/context/ExecutionContext;)Lkotlin/Pair; +} + public final class sk/ainet/lang/nn/transformer/PaddedSharedPositionalKVCache : sk/ainet/lang/nn/transformer/KVCache { public fun (Lsk/ainet/lang/nn/transformer/PositionalKVCache;ILjava/lang/String;)V public synthetic fun (Lsk/ainet/lang/nn/transformer/PositionalKVCache;ILjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/llm-inference/apertus/api/jvm/apertus.api b/llm-inference/apertus/api/jvm/apertus.api index aed125ea..59b4770f 100644 --- a/llm-inference/apertus/api/jvm/apertus.api +++ b/llm-inference/apertus/api/jvm/apertus.api @@ -1,29 +1,9 @@ -public abstract interface class sk/ainet/models/apertus/ApertusAttentionBackend { - public abstract fun attention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public fun batchAttention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public abstract fun reset ()V -} - -public final class sk/ainet/models/apertus/ApertusAttentionBackend$DefaultImpls { - public static fun batchAttention (Lsk/ainet/models/apertus/ApertusAttentionBackend;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; -} - public final class sk/ainet/models/apertus/ApertusConfigParser { public static final field INSTANCE Lsk/ainet/models/apertus/ApertusConfigParser; public final fun isTiedEmbeddings (Ljava/lang/String;)Z public final fun parse (Ljava/lang/String;)Lsk/ainet/models/apertus/ApertusModelMetadata; } -public final class sk/ainet/models/apertus/ApertusCpuAttentionBackend : sk/ainet/models/apertus/ApertusAttentionBackend { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusModelMetadata;Lkotlin/reflect/KClass;[FLsk/ainet/apps/llm/KvCache;F)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusModelMetadata;Lkotlin/reflect/KClass;[FLsk/ainet/apps/llm/KvCache;FILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;F)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;FILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun attention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public fun batchAttention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public fun reset ()V -} - public final class sk/ainet/models/apertus/ApertusHfTensorNameMapper { public static final field INSTANCE Lsk/ainet/models/apertus/ApertusHfTensorNameMapper; public final fun toCanonical (Ljava/lang/String;)Ljava/lang/String; @@ -60,6 +40,10 @@ public final class sk/ainet/models/apertus/ApertusLayerWeights { public fun toString ()Ljava/lang/String; } +public final class sk/ainet/models/apertus/ApertusMemSegConverterKt { + public static final fun convertApertusWeightsToMemSeg (Lsk/ainet/models/apertus/ApertusWeights;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;)Lsk/ainet/models/apertus/ApertusWeights; +} + public final class sk/ainet/models/apertus/ApertusModelMetadata { public fun (Ljava/lang/String;IIIIIILjava/lang/Integer;IFFZLjava/lang/String;ZII)V public synthetic fun (Ljava/lang/String;IIIIIILjava/lang/Integer;IFFZLjava/lang/String;ZIIILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -203,12 +187,6 @@ public final class sk/ainet/models/apertus/ApertusQuantizedLayerWeights { public fun toString ()Ljava/lang/String; } -public final class sk/ainet/models/apertus/ApertusQuantizedRuntime : sk/ainet/apps/llm/DecoderRuntime { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusQuantizedRuntimeWeights;Lsk/ainet/models/apertus/ApertusAttentionBackend;FLkotlin/random/Random;)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusQuantizedRuntimeWeights;Lsk/ainet/models/apertus/ApertusAttentionBackend;FLkotlin/random/Random;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun getWeights ()Lsk/ainet/models/apertus/ApertusQuantizedRuntimeWeights; -} - public final class sk/ainet/models/apertus/ApertusQuantizedRuntimeWeights { public fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/models/apertus/QuantizedTensor;Lsk/ainet/lang/tensor/Tensor;)V public synthetic fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/models/apertus/QuantizedTensor;Lsk/ainet/lang/tensor/Tensor;ILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -254,17 +232,6 @@ public final class sk/ainet/models/apertus/ApertusQuantizedWeights { public fun toString ()Ljava/lang/String; } -public final class sk/ainet/models/apertus/ApertusRuntime : sk/ainet/apps/llm/DecoderRuntime { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusRuntimeWeights;Lsk/ainet/models/apertus/ApertusAttentionBackend;Lkotlin/reflect/KClass;FLkotlin/random/Random;)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/apertus/ApertusRuntimeWeights;Lsk/ainet/models/apertus/ApertusAttentionBackend;Lkotlin/reflect/KClass;FLkotlin/random/Random;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun getWeights ()Lsk/ainet/models/apertus/ApertusRuntimeWeights; -} - -public final class sk/ainet/models/apertus/ApertusRuntimeKt { - public static final fun softplus (F)F - public static final fun xielu ([FLsk/ainet/models/apertus/ApertusXIELUParams;)V -} - public final class sk/ainet/models/apertus/ApertusRuntimeWeights { public fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Z)V public synthetic fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -359,23 +326,34 @@ public final class sk/ainet/models/apertus/ApertusWeightMapper { } public final class sk/ainet/models/apertus/ApertusWeights { - public fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;Z)V - public synthetic fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZLjava/util/Map;Ljava/util/Map;Ljava/util/Map;)V + public synthetic fun (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZLjava/util/Map;Ljava/util/Map;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lsk/ainet/models/apertus/ApertusModelMetadata; public final fun component2 ()Ljava/util/Map; public final fun component3 ()Ljava/util/Map; public final fun component4 ()Z - public final fun copy (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;Z)Lsk/ainet/models/apertus/ApertusWeights; - public static synthetic fun copy$default (Lsk/ainet/models/apertus/ApertusWeights;Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZILjava/lang/Object;)Lsk/ainet/models/apertus/ApertusWeights; + public final fun component5 ()Ljava/util/Map; + public final fun component6 ()Ljava/util/Map; + public final fun component7 ()Ljava/util/Map; + public final fun copy (Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZLjava/util/Map;Ljava/util/Map;Ljava/util/Map;)Lsk/ainet/models/apertus/ApertusWeights; + public static synthetic fun copy$default (Lsk/ainet/models/apertus/ApertusWeights;Lsk/ainet/models/apertus/ApertusModelMetadata;Ljava/util/Map;Ljava/util/Map;ZLjava/util/Map;Ljava/util/Map;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/apertus/ApertusWeights; public fun equals (Ljava/lang/Object;)Z + public final fun getLogicalShapes ()Ljava/util/Map; public final fun getMetadata ()Lsk/ainet/models/apertus/ApertusModelMetadata; public final fun getPreTransposed ()Z + public final fun getQuantBytes ()Ljava/util/Map; + public final fun getQuantTypes ()Ljava/util/Map; public final fun getTensors ()Ljava/util/Map; public final fun getXieluParams ()Ljava/util/Map; public fun hashCode ()I public fun toString ()Ljava/lang/String; } +public final class sk/ainet/models/apertus/ApertusXIELUKt { + public static final fun softplus (F)F + public static final fun xielu ([FLsk/ainet/models/apertus/ApertusXIELUParams;)V +} + public final class sk/ainet/models/apertus/ApertusXIELUParams { public fun (FFFF)V public final fun component1 ()F diff --git a/llm-inference/bert/api/jvm/bert.api b/llm-inference/bert/api/jvm/bert.api index 347d5011..746501fd 100644 --- a/llm-inference/bert/api/jvm/bert.api +++ b/llm-inference/bert/api/jvm/bert.api @@ -170,7 +170,9 @@ public final class sk/ainet/models/bert/HuggingFaceTokenizer : sk/ainet/apps/llm public fun encode (Ljava/lang/String;)[I public final fun encodeWithMetadata (Ljava/lang/String;)Lsk/ainet/models/bert/TokenizerOutput; public final fun encodeWithMetadata (Ljava/lang/String;Ljava/lang/String;)Lsk/ainet/models/bert/TokenizerOutput; - public final fun getVocabSize ()I + public fun getBosTokenId ()I + public fun getEosTokenId ()I + public fun getVocabSize ()I } public final class sk/ainet/models/bert/HuggingFaceTokenizer$Companion { @@ -180,6 +182,35 @@ public final class sk/ainet/models/bert/HuggingFaceTokenizer$Companion { public static synthetic fun fromVocabTxt$default (Lsk/ainet/models/bert/HuggingFaceTokenizer$Companion;Ljava/lang/String;ZILjava/lang/Object;)Lsk/ainet/models/bert/HuggingFaceTokenizer; } +public final class sk/ainet/models/bert/PooledExecutionContext : sk/ainet/context/ExecutionContext { + public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/lang/tensor/scratch/ScratchPool;)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/lang/tensor/scratch/ScratchPool;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun fromByteArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[B)Lsk/ainet/lang/tensor/Tensor; + public fun fromData (Lsk/ainet/lang/tensor/data/TensorData;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor; + public fun fromFloatArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[F)Lsk/ainet/lang/tensor/Tensor; + public fun fromIntArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[I)Lsk/ainet/lang/tensor/Tensor; + public fun full (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;Ljava/lang/Number;)Lsk/ainet/lang/tensor/Tensor; + public fun getExecutionStats ()Lsk/ainet/context/ExecutionStats; + public fun getHooks ()Lsk/ainet/lang/nn/hooks/ForwardHooks; + public fun getInTraining ()Z + public fun getMemoryInfo ()Lsk/ainet/context/MemoryInfo; + public fun getMemoryPlanner ()Lsk/ainet/lang/tensor/storage/MemoryPlanner; + public fun getMemoryTracker ()Lsk/ainet/lang/tensor/storage/MemoryTracker; + public fun getObservers ()Lsk/ainet/context/ExecutionObserverRegistry; + public fun getOps ()Lsk/ainet/lang/tensor/ops/TensorOps; + public fun getPhase ()Lsk/ainet/context/Phase; + public fun getScratch ()Lsk/ainet/lang/tensor/scratch/ScratchPool; + public fun getTensorDataFactory ()Lsk/ainet/lang/tensor/data/TensorDataFactory; + public fun ones (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor; + public fun placeholder (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor; + public fun registerObserver (Lsk/ainet/context/ExecutionObserver;)V + public fun unregisterObserver (Lsk/ainet/context/ExecutionObserver;)V + public fun wrapByteArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[B)Lsk/ainet/lang/tensor/Tensor; + public fun wrapFloatArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[F)Lsk/ainet/lang/tensor/Tensor; + public fun wrapIntArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[I)Lsk/ainet/lang/tensor/Tensor; + public fun zeros (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor; +} + public final class sk/ainet/models/bert/TokenizerOutput { public fun ([I[I[I)V public final fun component1 ()[I diff --git a/llm-inference/gemma/api/jvm/gemma.api b/llm-inference/gemma/api/jvm/gemma.api index f549609f..83302d75 100644 --- a/llm-inference/gemma/api/jvm/gemma.api +++ b/llm-inference/gemma/api/jvm/gemma.api @@ -541,8 +541,8 @@ public final class sk/ainet/models/gemma/Gemma4ModelMetadata { public static final field DEFAULT_HEAD_DIM I public static final field DEFAULT_KV_SHARED_LAYERS I public static final field DEFAULT_SLIDING_WINDOW I - public fun (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;III)V - public synthetic fun (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIF)V + public synthetic fun (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIFILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; public final fun component10 ()I public final fun component11 ()I @@ -557,6 +557,7 @@ public final class sk/ainet/models/gemma/Gemma4ModelMetadata { public final fun component2 ()I public final fun component20 ()I public final fun component21 ()I + public final fun component22 ()F public final fun component3 ()I public final fun component4 ()I public final fun component5 ()I @@ -564,8 +565,8 @@ public final class sk/ainet/models/gemma/Gemma4ModelMetadata { public final fun component7 ()I public final fun component8 ()I public final fun component9 ()I - public final fun copy (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;III)Lsk/ainet/models/gemma/Gemma4ModelMetadata; - public static synthetic fun copy$default (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIILjava/lang/Object;)Lsk/ainet/models/gemma/Gemma4ModelMetadata; + public final fun copy (Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIF)Lsk/ainet/models/gemma/Gemma4ModelMetadata; + public static synthetic fun copy$default (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/lang/String;IIIIIIIIIIILjava/util/List;Lsk/ainet/models/gemma/Gemma4RopeConfig;Lsk/ainet/models/gemma/Gemma4RopeConfig;IILjava/util/List;IIIFILjava/lang/Object;)Lsk/ainet/models/gemma/Gemma4ModelMetadata; public fun equals (Ljava/lang/Object;)Z public final fun getArchitecture ()Ljava/lang/String; public final fun getBlockCount ()I @@ -575,6 +576,7 @@ public final class sk/ainet/models/gemma/Gemma4ModelMetadata { public final fun getEffectiveWindow (I)I public final fun getEmbeddingLength ()I public final fun getEosTokenId ()I + public final fun getFinalLogitSoftcapping ()F public final fun getGlobalHeadDim ()I public final fun getHeadCount ()I public final fun getHeadDim ()I @@ -622,13 +624,6 @@ public final class sk/ainet/models/gemma/Gemma4RopeConfig { public fun toString ()Ljava/lang/String; } -public final class sk/ainet/models/gemma/Gemma4Runtime : sk/ainet/apps/llm/DecoderRuntime { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/gemma/Gemma4RuntimeWeights;Lsk/ainet/models/gemma/AttentionBackend;Lkotlin/reflect/KClass;Lsk/ainet/models/gemma/Gemma4Config;FLkotlin/random/Random;)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/gemma/Gemma4RuntimeWeights;Lsk/ainet/models/gemma/AttentionBackend;Lkotlin/reflect/KClass;Lsk/ainet/models/gemma/Gemma4Config;FLkotlin/random/Random;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun getCurrentPosition ()I - public final fun getWeights ()Lsk/ainet/models/gemma/Gemma4RuntimeWeights; -} - public final class sk/ainet/models/gemma/Gemma4RuntimeWeights { public fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/Map;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)V public synthetic fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/List;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Ljava/util/Map;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;ILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -676,12 +671,21 @@ public final class sk/ainet/models/gemma/Gemma4RuntimeWeightsKt { public static synthetic fun loadGemma4RuntimeWeightsStreaming$default (Lsk/ainet/context/ExecutionContext;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZLkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public final class sk/ainet/models/gemma/Gemma4SafeTensorsMappedPle { + public static final field INSTANCE Lsk/ainet/models/gemma/Gemma4SafeTensorsMappedPle; + public final fun injectIfMissing (Lsk/ainet/models/gemma/Gemma4Weights;Ljava/lang/String;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;)Lsk/ainet/models/gemma/Gemma4Weights; +} + public final class sk/ainet/models/gemma/Gemma4SafeTensorsWeightLoader { public static final field HF_EMBED_TOKENS Ljava/lang/String; + public static final field HF_EMBED_TOKENS_PER_LAYER Ljava/lang/String; public static final field HF_OUTPUT_NORM Ljava/lang/String; public static final field HF_PER_LAYER_MODEL_PROJ Ljava/lang/String; + public static final field HF_PER_LAYER_MODEL_PROJECTION Ljava/lang/String; + public static final field HF_PER_LAYER_PROJECTION_NORM Ljava/lang/String; public static final field HF_PER_LAYER_PROJ_NORM Ljava/lang/String; public static final field HF_PER_LAYER_TOKEN_EMBD Ljava/lang/String; + public static final field MAX_BYTES_PER_TENSOR J public fun (Ljava/lang/String;)V public final fun loadToMap (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } @@ -711,9 +715,15 @@ public final class sk/ainet/models/gemma/Gemma4TensorNames { public final fun ffnGate (I)Ljava/lang/String; public final fun ffnUp (I)Ljava/lang/String; public final fun inputLayernorm (I)Ljava/lang/String; + public final fun layerOutputScale (I)Ljava/lang/String; public final fun perLayerInput (I)Ljava/lang/String; public final fun perLayerOutput (I)Ljava/lang/String; + public final fun pleInpGate (I)Ljava/lang/String; + public final fun plePostNorm (I)Ljava/lang/String; + public final fun pleProj (I)Ljava/lang/String; public final fun postAttentionLayernorm (I)Ljava/lang/String; + public final fun postAttentionNorm (I)Ljava/lang/String; + public final fun postFfwNorm (I)Ljava/lang/String; } public final class sk/ainet/models/gemma/Gemma4WeightLoader { @@ -733,14 +743,16 @@ public final class sk/ainet/models/gemma/Gemma4WeightMapper { } public final class sk/ainet/models/gemma/Gemma4Weights { - public fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;)V - public synthetic fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;Ljava/util/Map;)V + public synthetic fun (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lsk/ainet/models/gemma/Gemma4ModelMetadata; public final fun component2 ()Ljava/util/Map; public final fun component3 ()Ljava/util/Map; - public final fun copy (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;)Lsk/ainet/models/gemma/Gemma4Weights; - public static synthetic fun copy$default (Lsk/ainet/models/gemma/Gemma4Weights;Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/gemma/Gemma4Weights; + public final fun component4 ()Ljava/util/Map; + public final fun copy (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;Ljava/util/Map;)Lsk/ainet/models/gemma/Gemma4Weights; + public static synthetic fun copy$default (Lsk/ainet/models/gemma/Gemma4Weights;Lsk/ainet/models/gemma/Gemma4ModelMetadata;Ljava/util/Map;Ljava/util/Map;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/gemma/Gemma4Weights; public fun equals (Ljava/lang/Object;)Z + public final fun getLogicalShapes ()Ljava/util/Map; public final fun getMetadata ()Lsk/ainet/models/gemma/Gemma4ModelMetadata; public final fun getQuantTypes ()Ljava/util/Map; public final fun getTensors ()Ljava/util/Map; @@ -752,9 +764,29 @@ public final class sk/ainet/models/gemma/GemmaMemSegConverterKt { public static final fun convertGemmaWeightsToMemSeg (Lsk/ainet/models/gemma/Gemma4Weights;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;)Lsk/ainet/models/gemma/Gemma4Weights; } +public final class sk/ainet/models/gemma/GemmaModel : sk/ainet/lang/nn/Module { + public static final field Companion Lsk/ainet/models/gemma/GemmaModel$Companion; + public fun (Lsk/ainet/lang/nn/layers/EmbeddingAdapter;Lsk/ainet/models/gemma/PerLayerEmbedding;Ljava/util/List;Lsk/ainet/lang/nn/normalization/RMSNormalization;Lsk/ainet/lang/nn/transformer/VoidDense;Lkotlin/reflect/KClass;FFLjava/lang/String;)V + public synthetic fun (Lsk/ainet/lang/nn/layers/EmbeddingAdapter;Lsk/ainet/models/gemma/PerLayerEmbedding;Ljava/util/List;Lsk/ainet/lang/nn/normalization/RMSNormalization;Lsk/ainet/lang/nn/transformer/VoidDense;Lkotlin/reflect/KClass;FFLjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getBlocks ()Ljava/util/List; + public final fun getDtype ()Lkotlin/reflect/KClass; + public final fun getEmbedScale ()F + public final fun getFinalLogitSoftcapping ()F + public final fun getLmHead ()Lsk/ainet/lang/nn/transformer/VoidDense; + public fun getModules ()Ljava/util/List; + public fun getName ()Ljava/lang/String; + public final fun getOutputNorm ()Lsk/ainet/lang/nn/normalization/RMSNormalization; + public final fun getPle ()Lsk/ainet/models/gemma/PerLayerEmbedding; + public final fun getTokenEmbedding ()Lsk/ainet/lang/nn/layers/EmbeddingAdapter; +} + +public final class sk/ainet/models/gemma/GemmaModel$Companion { + public final fun findHookIn (Ljava/util/List;)Lsk/ainet/models/gemma/PerLayerInputBlockHook; +} + public final class sk/ainet/models/gemma/GemmaNetworkDefKt { - public static final fun gemmaNetwork (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lkotlin/reflect/KClass;I)Lsk/ainet/lang/nn/Module; - public static synthetic fun gemmaNetwork$default (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lkotlin/reflect/KClass;IILjava/lang/Object;)Lsk/ainet/lang/nn/Module; + public static final fun gemmaNetwork (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lkotlin/reflect/KClass;IZZZZZ)Lsk/ainet/lang/nn/Module; + public static synthetic fun gemmaNetwork$default (Lsk/ainet/models/gemma/Gemma4ModelMetadata;Lkotlin/reflect/KClass;IZZZZZILjava/lang/Object;)Lsk/ainet/lang/nn/Module; } public final class sk/ainet/models/gemma/GemmaNetworkLoader { @@ -831,6 +863,20 @@ public final class sk/ainet/models/gemma/GemmaNetworkLoaderKt { public static final fun applyWeightsToNetworkNonReified (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/gemma/Gemma4Weights;Lkotlin/reflect/KClass;Z)Lsk/ainet/lang/nn/Module; } +public final class sk/ainet/models/gemma/GemmaPerLayerTokenEmbedTensorData : sk/ainet/lang/tensor/data/TensorData, sk/ainet/models/gemma/RowDequantSource { + public fun (Lsk/ainet/lang/tensor/Shape;Lsk/ainet/io/gguf/GGMLQuantizationType;[B)V + public fun copyToFloatArray ()[F + public fun dequantRow (I)[F + public fun get ([I)Ljava/lang/Float; + public synthetic fun get ([I)Ljava/lang/Object; + public final fun getBytesPerRow ()I + public final fun getPackedBytes ()[B + public final fun getQuantType ()Lsk/ainet/io/gguf/GGMLQuantizationType; + public fun getShape ()Lsk/ainet/lang/tensor/Shape; + public fun set ([IF)V + public synthetic fun set ([ILjava/lang/Object;)V +} + public final class sk/ainet/models/gemma/HeapGemma3nKvCache : sk/ainet/models/gemma/Gemma3nKvCache { public static final field Companion Lsk/ainet/models/gemma/HeapGemma3nKvCache$Companion; public fun (IIILjava/util/List;I)V @@ -875,6 +921,50 @@ public final class sk/ainet/models/gemma/LayerType : java/lang/Enum { public static fun values ()[Lsk/ainet/models/gemma/LayerType; } +public final class sk/ainet/models/gemma/PerLayerEmbedding : sk/ainet/lang/nn/Module, sk/ainet/lang/nn/topology/ModuleParameters { + public fun (IIIIFLjava/lang/String;)V + public synthetic fun (IIIIFLjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun compute (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor; + public final fun getHiddenSize ()I + public fun getModules ()Ljava/util/List; + public fun getName ()Ljava/lang/String; + public final fun getNumLayers ()I + public fun getParams ()Ljava/util/List; + public final fun getPerLayerDim ()I + public final fun getProjectionNorm ()Lsk/ainet/lang/nn/normalization/RMSNormalization; + public final fun getRmsEps ()F + public final fun getVocabSize ()I +} + +public final class sk/ainet/models/gemma/PerLayerInputBlockHook : sk/ainet/lang/nn/Module, sk/ainet/lang/nn/topology/ModuleParameters { + public fun (IIZLjava/lang/String;)V + public synthetic fun (IIZLjava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getHiddenSize ()I + public fun getModules ()Ljava/util/List; + public fun getName ()Ljava/lang/String; + public fun getParams ()Ljava/util/List; + public final fun getPerLayerDim ()I + public final fun getPerLayerInput ()Lsk/ainet/lang/tensor/Tensor; + public final fun getPostNorm ()Lsk/ainet/lang/nn/normalization/RMSNormalization; + public final fun getSideChannelOnly ()Z + public final fun setPerLayerInput (Lsk/ainet/lang/tensor/Tensor;)V +} + +public abstract interface class sk/ainet/models/gemma/RowDequantSource { + public abstract fun dequantRow (I)[F +} + +public final class sk/ainet/models/gemma/SafeTensorsPerLayerTokenEmbedTensorData : sk/ainet/lang/tensor/data/TensorData, sk/ainet/models/gemma/RowDequantSource { + public fun (Lsk/ainet/lang/tensor/Shape;Ljava/lang/foreign/MemorySegment;Ljava/lang/String;)V + public fun copyToFloatArray ()[F + public fun dequantRow (I)[F + public fun get ([I)Ljava/lang/Float; + public synthetic fun get ([I)Ljava/lang/Object; + public fun getShape ()Lsk/ainet/lang/tensor/Shape; + public fun set ([IF)V + public synthetic fun set ([ILjava/lang/Object;)V +} + public final class sk/ainet/models/gemma/multimodal/VisionEncoder { public static final field Companion Lsk/ainet/models/gemma/multimodal/VisionEncoder$Companion; public static final field DEFAULT_IMAGE_SIZE I diff --git a/llm-inference/llama/api/jvm/llama.api b/llm-inference/llama/api/jvm/llama.api index 1efc44a6..925fb710 100644 --- a/llm-inference/llama/api/jvm/llama.api +++ b/llm-inference/llama/api/jvm/llama.api @@ -8,6 +8,53 @@ public final class sk/ainet/models/llama/AttentionBackend$DefaultImpls { public static fun batchAttention (Lsk/ainet/models/llama/AttentionBackend;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; } +public final class sk/ainet/models/llama/DecoderGgufMemSegConverter { + public static final field INSTANCE Lsk/ainet/models/llama/DecoderGgufMemSegConverter; + public final fun convert (Lsk/ainet/models/llama/DecoderGgufWeights;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;)Lsk/ainet/models/llama/DecoderGgufWeights; +} + +public final class sk/ainet/models/llama/DecoderGgufMemSegConverterKt { + public static final fun loadDecoderGgufWeightsNative (Lkotlin/jvm/functions/Function0;Ljava/util/Set;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class sk/ainet/models/llama/DecoderGgufWeightLoader { + public static final field Dequant Lsk/ainet/models/llama/DecoderGgufWeightLoader$Dequant; + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Ljava/util/Set;)V + public synthetic fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Ljava/util/Set;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lkotlin/jvm/functions/Function0;ZLsk/ainet/io/model/QuantPolicy;Ljava/util/Set;)V + public synthetic fun (Lkotlin/jvm/functions/Function0;ZLsk/ainet/io/model/QuantPolicy;Ljava/util/Set;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun load (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadStreaming (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadToMap (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadToMapStreaming (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class sk/ainet/models/llama/DecoderGgufWeightLoader$Dequant { +} + +public final class sk/ainet/models/llama/DecoderGgufWeights { + public fun (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;)V + public synthetic fun (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun component2 ()Ljava/util/Map; + public final fun component3 ()Ljava/util/Map; + public final fun copy (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;)Lsk/ainet/models/llama/DecoderGgufWeights; + public static synthetic fun copy$default (Lsk/ainet/models/llama/DecoderGgufWeights;Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/llama/DecoderGgufWeights; + public fun equals (Ljava/lang/Object;)Z + public final fun getMetadata ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getQuantTypes ()Ljava/util/Map; + public final fun getTensors ()Ljava/util/Map; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/llama/DecoderSafeTensorsLoader { + public fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun load (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaRuntimeWeights; + public final fun loadToMap (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/DecoderGgufWeights; +} + public abstract interface class sk/ainet/models/llama/GraphAccelerator { public abstract fun close ()V public abstract fun runFFN (ILsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor; @@ -88,7 +135,7 @@ public final class sk/ainet/models/llama/LlamaLayerWeights { public fun toString ()Ljava/lang/String; } -public final class sk/ainet/models/llama/LlamaModelMetadata { +public final class sk/ainet/models/llama/LlamaModelMetadata : sk/ainet/lang/nn/dsl/decoder/DecoderModelMetadata { public fun (Ljava/lang/String;IIIIIILjava/lang/Integer;IFFII)V public synthetic fun (Ljava/lang/String;IIIIIILjava/lang/Integer;IFFIIILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; @@ -108,18 +155,18 @@ public final class sk/ainet/models/llama/LlamaModelMetadata { public static synthetic fun copy$default (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/lang/String;IIIIIILjava/lang/Integer;IFFIIILjava/lang/Object;)Lsk/ainet/models/llama/LlamaModelMetadata; public fun equals (Ljava/lang/Object;)Z public final fun getArchitecture ()Ljava/lang/String; - public final fun getBlockCount ()I - public final fun getBosTokenId ()I - public final fun getContextLength ()I - public final fun getEmbeddingLength ()I - public final fun getEosTokenId ()I - public final fun getFeedForwardLength ()I - public final fun getHeadCount ()I - public final fun getKvHeadCount ()I - public final fun getRmsNormEps ()F - public final fun getRopeDimensionCount ()Ljava/lang/Integer; - public final fun getRopeFreqBase ()F - public final fun getVocabSize ()I + public fun getBlockCount ()I + public fun getBosTokenId ()I + public fun getContextLength ()I + public fun getEmbeddingLength ()I + public fun getEosTokenId ()I + public fun getFeedForwardLength ()I + public fun getHeadCount ()I + public fun getKvHeadCount ()I + public fun getRmsNormEps ()F + public fun getRopeDimensionCount ()Ljava/lang/Integer; + public fun getRopeFreqBase ()F + public fun getVocabSize ()I public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -171,12 +218,12 @@ public final class sk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Gguf } public final class sk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded : sk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider { - public fun (Lsk/ainet/models/llama/LlamaWeights;)V - public final fun component1 ()Lsk/ainet/models/llama/LlamaWeights; - public final fun copy (Lsk/ainet/models/llama/LlamaWeights;)Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded; - public static synthetic fun copy$default (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded;Lsk/ainet/models/llama/LlamaWeights;ILjava/lang/Object;)Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded; + public fun (Lsk/ainet/models/llama/DecoderGgufWeights;)V + public final fun component1 ()Lsk/ainet/models/llama/DecoderGgufWeights; + public final fun copy (Lsk/ainet/models/llama/DecoderGgufWeights;)Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded; + public static synthetic fun copy$default (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded;Lsk/ainet/models/llama/DecoderGgufWeights;ILjava/lang/Object;)Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider$Preloaded; public fun equals (Ljava/lang/Object;)Z - public final fun getWeights ()Lsk/ainet/models/llama/LlamaWeights; + public final fun getWeights ()Lsk/ainet/models/llama/DecoderGgufWeights; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -213,6 +260,7 @@ public abstract interface class sk/ainet/models/llama/LlamaRuntimeInterface : sk } public final class sk/ainet/models/llama/LlamaRuntimeInterface$DefaultImpls { + public static fun forwardBatched (Lsk/ainet/models/llama/LlamaRuntimeInterface;[I)Lsk/ainet/lang/tensor/Tensor; public static synthetic fun generate$default (Lsk/ainet/models/llama/LlamaRuntimeInterface;[IIFLkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } @@ -257,13 +305,6 @@ public final class sk/ainet/models/llama/LlamaRuntimeWeightsKt { public static synthetic fun loadLlamaRuntimeWeightsStreaming$default (Lsk/ainet/context/ExecutionContext;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZLkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } -public final class sk/ainet/models/llama/LlamaSafeTensorsLoader { - public fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun load (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaRuntimeWeights; - public final fun loadToMap (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaWeights; -} - public final class sk/ainet/models/llama/LlamaTensorNames { public static final field INSTANCE Lsk/ainet/models/llama/LlamaTensorNames; public static final field OUTPUT_NORM Ljava/lang/String; @@ -284,40 +325,9 @@ public final class sk/ainet/models/llama/LlamaTensorNames { public final fun ffnUp (I)Ljava/lang/String; } -public final class sk/ainet/models/llama/LlamaWeightLoader { - public static final field Dequant Lsk/ainet/models/llama/LlamaWeightLoader$Dequant; - public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Ljava/util/Set;)V - public synthetic fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Ljava/util/Set;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun (Lkotlin/jvm/functions/Function0;ZLsk/ainet/io/model/QuantPolicy;Ljava/util/Set;)V - public synthetic fun (Lkotlin/jvm/functions/Function0;ZLsk/ainet/io/model/QuantPolicy;Ljava/util/Set;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun load (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun loadStreaming (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun loadToMap (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun loadToMapStreaming (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; -} - -public final class sk/ainet/models/llama/LlamaWeightLoader$Dequant { -} - public final class sk/ainet/models/llama/LlamaWeightMapper { public static final field INSTANCE Lsk/ainet/models/llama/LlamaWeightMapper; - public final fun map (Lsk/ainet/models/llama/LlamaWeights;)Lsk/ainet/models/llama/LlamaRuntimeWeights; -} - -public final class sk/ainet/models/llama/LlamaWeights { - public fun (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;)V - public synthetic fun (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun component1 ()Lsk/ainet/models/llama/LlamaModelMetadata; - public final fun component2 ()Ljava/util/Map; - public final fun component3 ()Ljava/util/Map; - public final fun copy (Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;)Lsk/ainet/models/llama/LlamaWeights; - public static synthetic fun copy$default (Lsk/ainet/models/llama/LlamaWeights;Lsk/ainet/models/llama/LlamaModelMetadata;Ljava/util/Map;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/llama/LlamaWeights; - public fun equals (Ljava/lang/Object;)Z - public final fun getMetadata ()Lsk/ainet/models/llama/LlamaModelMetadata; - public final fun getQuantTypes ()Ljava/util/Map; - public final fun getTensors ()Ljava/util/Map; - public fun hashCode ()I - public fun toString ()Ljava/lang/String; + public final fun map (Lsk/ainet/models/llama/DecoderGgufWeights;)Lsk/ainet/models/llama/LlamaRuntimeWeights; } public final class sk/ainet/models/llama/MemSegWeightConverter { @@ -329,7 +339,7 @@ public final class sk/ainet/models/llama/MmapLlamaLoader : java/lang/AutoCloseab public fun (Ljava/nio/file/Path;Ljava/util/Set;)V public synthetic fun (Ljava/nio/file/Path;Ljava/util/Set;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close ()V - public final fun loadToMap (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;)Lsk/ainet/models/llama/LlamaWeights; + public final fun loadToMap (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;)Lsk/ainet/models/llama/DecoderGgufWeights; } public final class sk/ainet/models/llama/QuantizedTensorFactory { diff --git a/llm-inference/llama/src/jvmMain/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverter.kt b/llm-inference/llama/src/jvmMain/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverter.kt index 074c6cf1..d5a60eb1 100644 --- a/llm-inference/llama/src/jvmMain/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverter.kt +++ b/llm-inference/llama/src/jvmMain/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverter.kt @@ -5,6 +5,7 @@ import sk.ainet.io.RandomAccessSource import sk.ainet.io.gguf.GGMLQuantizationType import sk.ainet.io.gguf.dequant.DequantOps import sk.ainet.io.model.QuantPolicy +import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.tensor.data.IntArrayTensorData import sk.ainet.lang.tensor.data.Q4MemorySegmentTensorData @@ -23,16 +24,30 @@ import java.lang.foreign.Arena * * Behavior per quant type: * - **Q4_0 / Q8_0** → wrapped as [Q4MemorySegmentTensorData] / - * [Q8MemorySegmentTensorData]. Upstream `DefaultCpuOpsJvm.matmul` detects - * their markers and dispatches SIMD quant kernels at forward time. + * [Q8MemorySegmentTensorData] with the **logical** matrix shape derived + * from metadata. Upstream `DefaultCpuOpsJvm.matmul` and `transpose` + * detect the markers and dispatch quant-aware kernels at forward time. * - **Q4_K / Q5_K / Q6_K** → dequantized to FP32. The packed K-quant kernels * are MemSeg-only on a hot path the DSL doesn't yet route through, so this * trades memory for correctness. Same trade-off the legacy converter * makes for K-quants. + * - **token_embd.weight** → always dequantized to FP32 regardless of quant + * type. The Embedding layer consumes this via `gather`, not matmul, so it + * needs real floats with the logical 2D shape — packed quant bytes would + * be misread as FP32 values, and the loader's intermediate Int8 wrapper + * stores a 1D byte-count shape that `gather` rejects. * - **FP32 (no entry in `quantTypes`)** → passed through unchanged. * - **Other quant types** → warning logged, passed through (will fail later * if the model actually hits them via matmul). * + * Why logical shape matters here: the loader stores raw quant bytes via + * `ctx.fromByteArray(Shape(bytes.size), Int8, bytes)` — a 1D byte-count + * shape, because the Int8 factory requires `shape.volume == bytes.size` + * and packed Q4/Q8 have more bytes than logical floats. The Q4/Q8 MemSeg + * tensor data classes, in contrast, hold the logical shape independently + * from the byte buffer, which is what `gather` / `transpose` / `matmul` + * need. + * * Unlike the legacy [MemSegWeightConverter], this one does NOT pre-transpose * weights to `[in, out]`. The DSL's [sk.ainet.lang.nn.transformer.linearProject] * always calls `ops.transpose(weight)` at forward time; for Q4/Q8 MemSeg @@ -48,7 +63,8 @@ public object DecoderGgufMemSegConverter { /** * Return a copy of [weights] with Q4_0/Q8_0 tensors wrapped as MemSeg - * variants and K-quants dequantized to FP32. No-op if [weights] has no + * variants with logical shapes, K-quants dequantized to FP32, and the + * token embedding always dequantized. No-op if [weights] has no * quantized tensors. */ public fun convert( @@ -58,6 +74,13 @@ public object DecoderGgufMemSegConverter { ): DecoderGgufWeights { if (weights.quantTypes.isEmpty()) return weights + val meta = weights.metadata + val dim = meta.embeddingLength + val headSize = dim / meta.headCount + val kvDim = meta.kvHeadCount * headSize + val ffnDim = meta.feedForwardLength + val vocab = meta.vocabSize + val newTensors = LinkedHashMap>(weights.tensors.size) for ((name, tensor) in weights.tensors) { val quantType = weights.quantTypes[name] @@ -65,7 +88,16 @@ public object DecoderGgufMemSegConverter { newTensors[name] = tensor continue } - newTensors[name] = convertOne(name, tensor, quantType, ctx, arena) + val logicalShape = logicalShapeFor(name, dim, kvDim, ffnDim, vocab) + if (logicalShape == null) { + println( + "WARNING: DecoderGgufMemSegConverter: no logical shape for '$name'; " + + "passing through quantized — forward pass may fail.", + ) + newTensors[name] = tensor + continue + } + newTensors[name] = convertOne(name, tensor, quantType, logicalShape, ctx, arena) } // Drop quantTypes from the result — tensors are now either packed @@ -75,37 +107,64 @@ public object DecoderGgufMemSegConverter { return weights.copy(tensors = newTensors, quantTypes = emptyMap()) } + private fun logicalShapeFor( + name: String, + dim: Int, + kvDim: Int, + ffnDim: Int, + vocab: Int, + ): Shape? = when { + name == LlamaTensorNames.TOKEN_EMBEDDINGS -> Shape(vocab, dim) + name == LlamaTensorNames.OUTPUT_WEIGHT -> Shape(vocab, dim) + name.endsWith(".attn_q.weight") -> Shape(dim, dim) + name.endsWith(".attn_k.weight") -> Shape(kvDim, dim) + name.endsWith(".attn_v.weight") -> Shape(kvDim, dim) + name.endsWith(".attn_output.weight") -> Shape(dim, dim) + name.endsWith(".ffn_gate.weight") -> Shape(ffnDim, dim) + name.endsWith(".ffn_up.weight") -> Shape(ffnDim, dim) + name.endsWith(".ffn_down.weight") -> Shape(dim, ffnDim) + else -> null + } + private fun convertOne( name: String, tensor: Tensor, quantType: GGMLQuantizationType, + logicalShape: Shape, ctx: ExecutionContext, arena: Arena, ): Tensor { val bytes = extractBytes(tensor.data) - val shape = tensor.shape + + // token_embd.weight (and tied output.weight, which holds the same + // bytes) is consumed by Embedding.gather, not matmul. Packed quant + // bytes can't be read by gather as floats, so dequantize. + if (name == LlamaTensorNames.TOKEN_EMBEDDINGS) { + val floats = DequantOps.dequantFromBytes(bytes, quantType, logicalShape.volume) + return ctx.fromFloatArray(logicalShape, FP32::class, floats) + } return when (quantType) { GGMLQuantizationType.Q4_0 -> { - val newData = Q4MemorySegmentTensorData.fromRawBytes(shape, bytes, arena) + val newData = Q4MemorySegmentTensorData.fromRawBytes(logicalShape, bytes, arena) @Suppress("UNCHECKED_CAST") ctx.fromData(newData as TensorData, FP32::class) } GGMLQuantizationType.Q8_0 -> { - val newData = Q8MemorySegmentTensorData.fromRawBytes(shape, bytes, arena) + val newData = Q8MemorySegmentTensorData.fromRawBytes(logicalShape, bytes, arena) @Suppress("UNCHECKED_CAST") ctx.fromData(newData as TensorData, FP32::class) } GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q5_K, GGMLQuantizationType.Q6_K -> { - val floats = DequantOps.dequantFromBytes(bytes, quantType, shape.volume) - ctx.fromFloatArray(shape, FP32::class, floats) + val floats = DequantOps.dequantFromBytes(bytes, quantType, logicalShape.volume) + ctx.fromFloatArray(logicalShape, FP32::class, floats) } else -> { println( "WARNING: DecoderGgufMemSegConverter: unsupported quant type $quantType for '$name'; " + - "passing through unchanged. Forward pass may fail at matmul." + "passing through unchanged. Forward pass may fail at matmul.", ) tensor } diff --git a/llm-inference/qwen/api/jvm/qwen.api b/llm-inference/qwen/api/jvm/qwen.api index a1a52d16..5c26e1b6 100644 --- a/llm-inference/qwen/api/jvm/qwen.api +++ b/llm-inference/qwen/api/jvm/qwen.api @@ -25,3 +25,83 @@ public final class sk/ainet/models/qwen/QwenHfTensorNameMapper { public final fun toCanonical (Ljava/lang/String;)Ljava/lang/String; } +public final class sk/ainet/models/qwen/QwenNetworkLoader { + public static final field Companion Lsk/ainet/models/qwen/QwenNetworkLoader$Companion; + public fun (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider;Z)V + public synthetic fun (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getDebug ()Z + public final fun getWeightsProvider ()Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoader$Companion { + public final fun fromGguf (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Z)Lsk/ainet/models/qwen/QwenNetworkLoader; + public static synthetic fun fromGguf$default (Lsk/ainet/models/qwen/QwenNetworkLoader$Companion;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader; + public final fun fromGgufRandomAccess (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Z)Lsk/ainet/models/qwen/QwenNetworkLoader; + public static synthetic fun fromGgufRandomAccess$default (Lsk/ainet/models/qwen/QwenNetworkLoader$Companion;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader; + public final fun fromSafeTensors (Lsk/ainet/models/llama/LlamaModelMetadata;Lkotlin/jvm/functions/Function0;ZZ)Lsk/ainet/models/qwen/QwenNetworkLoader; + public static synthetic fun fromSafeTensors$default (Lsk/ainet/models/qwen/QwenNetworkLoader$Companion;Lsk/ainet/models/llama/LlamaModelMetadata;Lkotlin/jvm/functions/Function0;ZZILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader; +} + +public abstract interface class sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider { +} + +public final class sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufRandomAccess : sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/io/model/QuantPolicy; + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufRandomAccess; + public static synthetic fun copy$default (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufRandomAccess;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufRandomAccess; + public fun equals (Ljava/lang/Object;)Z + public final fun getQuantPolicy ()Lsk/ainet/io/model/QuantPolicy; + public final fun getRandomAccessProvider ()Lkotlin/jvm/functions/Function0; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufSource : sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/io/model/QuantPolicy; + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufSource; + public static synthetic fun copy$default (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufSource;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$GgufSource; + public fun equals (Ljava/lang/Object;)Z + public final fun getQuantPolicy ()Lsk/ainet/io/model/QuantPolicy; + public final fun getSourceProvider ()Lkotlin/jvm/functions/Function0; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$Preloaded : sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider { + public fun (Lsk/ainet/models/llama/DecoderGgufWeights;)V + public final fun component1 ()Lsk/ainet/models/llama/DecoderGgufWeights; + public final fun copy (Lsk/ainet/models/llama/DecoderGgufWeights;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$Preloaded; + public static synthetic fun copy$default (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$Preloaded;Lsk/ainet/models/llama/DecoderGgufWeights;ILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$Preloaded; + public fun equals (Ljava/lang/Object;)Z + public final fun getWeights ()Lsk/ainet/models/llama/DecoderGgufWeights; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$SafeTensors : sk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun component3 ()Z + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;Z)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$SafeTensors; + public static synthetic fun copy$default (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$SafeTensors;Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;ZILjava/lang/Object;)Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider$SafeTensors; + public fun equals (Ljava/lang/Object;)Z + public final fun getMetadata ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getRandomAccessProvider ()Lkotlin/jvm/functions/Function0; + public final fun getTiedEmbeddings ()Z + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoaderJvmKt { + public static final fun fromGgufNative (Lsk/ainet/models/qwen/QwenNetworkLoader$Companion;Lkotlin/jvm/functions/Function0;Lsk/ainet/context/ExecutionContext;Ljava/lang/foreign/Arena;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class sk/ainet/models/qwen/QwenNetworkLoaderKt { + public static final fun getQWEN_ARCHITECTURES ()Ljava/util/Set; +} + diff --git a/llm-inference/voxtral/api/jvm/voxtral.api b/llm-inference/voxtral/api/jvm/voxtral.api new file mode 100644 index 00000000..5506154d --- /dev/null +++ b/llm-inference/voxtral/api/jvm/voxtral.api @@ -0,0 +1,378 @@ +public final class sk/ainet/models/voxtral/TekkenTokenizerAdapter : sk/ainet/apps/llm/Tokenizer { + public static final field Companion Lsk/ainet/models/voxtral/TekkenTokenizerAdapter$Companion; + public fun (Lsk/ainet/io/tokenizer/TekkenTokenizer;)V + public fun decode (I)Ljava/lang/String; + public fun decode ([I)Ljava/lang/String; + public fun encode (Ljava/lang/String;)[I + public fun getBosTokenId ()I + public fun getEosTokenId ()I + public fun getVocabSize ()I +} + +public final class sk/ainet/models/voxtral/TekkenTokenizerAdapter$Companion { + public final fun fromJson (Ljava/lang/String;)Lsk/ainet/apps/llm/Tokenizer; +} + +public final class sk/ainet/models/voxtral/VoxtralAcousticRuntime { + public fun (Ljava/util/Map;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIIIIIIFF)V + public synthetic fun (Ljava/util/Map;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIIIIIIFFILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun generate (Lsk/ainet/lang/tensor/Tensor;ILjava/lang/String;Lkotlin/random/Random;)[I + public static synthetic fun generate$default (Lsk/ainet/models/voxtral/VoxtralAcousticRuntime;Lsk/ainet/lang/tensor/Tensor;ILjava/lang/String;Lkotlin/random/Random;ILjava/lang/Object;)[I +} + +public final class sk/ainet/models/voxtral/VoxtralAudioConfig { + public fun ()V + public fun (IIIIIFLjava/lang/String;IILjava/lang/String;IIII)V + public synthetic fun (IIIIIFLjava/lang/String;IILjava/lang/String;IIIIILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()I + public final fun component10 ()Ljava/lang/String; + public final fun component11 ()I + public final fun component12 ()I + public final fun component13 ()I + public final fun component14 ()I + public final fun component2 ()I + public final fun component3 ()I + public final fun component4 ()I + public final fun component5 ()I + public final fun component6 ()F + public final fun component7 ()Ljava/lang/String; + public final fun component8 ()I + public final fun component9 ()I + public final fun copy (IIIIIFLjava/lang/String;IILjava/lang/String;IIII)Lsk/ainet/models/voxtral/VoxtralAudioConfig; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralAudioConfig;IIIIIFLjava/lang/String;IILjava/lang/String;IIIIILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralAudioConfig; + public fun equals (Ljava/lang/Object;)Z + public final fun getAcousticCodebookSize ()I + public final fun getAudioTokenId ()I + public final fun getBeginAudioTokenId ()I + public final fun getBosTokenId ()I + public final fun getCodebookPattern ()Ljava/lang/String; + public final fun getConditionDroppedTokenId ()I + public final fun getFrameRate ()F + public final fun getInputEmbeddingConcatType ()Ljava/lang/String; + public final fun getInterleaveAudioTokensPerSegment ()I + public final fun getInterleaveTextTokensPerSegment ()I + public final fun getNAcousticCodebooks ()I + public final fun getNumCodebooks ()I + public final fun getSamplingRate ()I + public final fun getSemanticCodebookSize ()I + public final fun getTotalCodebooks ()I + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralBackboneRuntime { + public fun (Lsk/ainet/lang/nn/Module;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;ILkotlin/random/Random;)V + public synthetic fun (Lsk/ainet/lang/nn/Module;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;ILkotlin/random/Random;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun forward (I)Lsk/ainet/lang/tensor/Tensor; + public final fun forwardEmbedding ([F)Lsk/ainet/lang/tensor/Tensor; + public final fun generate ([IIFLsk/ainet/models/voxtral/VoxtralVoice;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun generate$default (Lsk/ainet/models/voxtral/VoxtralBackboneRuntime;[IIFLsk/ainet/models/voxtral/VoxtralVoice;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public final fun getBos ()I + public final fun getPosition ()I + public final fun lastHiddenStates ()Lsk/ainet/lang/tensor/Tensor; + public final fun reset ()V + public final fun sample (Lsk/ainet/lang/tensor/Tensor;F)I +} + +public final class sk/ainet/models/voxtral/VoxtralCodecMetadata { + public fun ()V + public fun (IIIIIIIIIIIIIIZZFFFZLjava/util/List;Ljava/util/List;Ljava/util/List;Ljava/util/List;)V + public synthetic fun (IIIIIIIIIIIIIIZZFFFZLjava/util/List;Ljava/util/List;Ljava/util/List;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()I + public final fun component10 ()I + public final fun component11 ()I + public final fun component12 ()I + public final fun component13 ()I + public final fun component14 ()I + public final fun component15 ()Z + public final fun component16 ()Z + public final fun component17 ()F + public final fun component18 ()F + public final fun component19 ()F + public final fun component2 ()I + public final fun component20 ()Z + public final fun component21 ()Ljava/util/List; + public final fun component22 ()Ljava/util/List; + public final fun component23 ()Ljava/util/List; + public final fun component24 ()Ljava/util/List; + public final fun component3 ()I + public final fun component4 ()I + public final fun component5 ()I + public final fun component6 ()I + public final fun component7 ()I + public final fun component8 ()I + public final fun component9 ()I + public final fun copy (IIIIIIIIIIIIIIZZFFFZLjava/util/List;Ljava/util/List;Ljava/util/List;Ljava/util/List;)Lsk/ainet/models/voxtral/VoxtralCodecMetadata; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralCodecMetadata;IIIIIIIIIIIIIIZZFFFZLjava/util/List;Ljava/util/List;Ljava/util/List;Ljava/util/List;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralCodecMetadata; + public fun equals (Ljava/lang/Object;)Z + public final fun getAcousticCodebookSize ()I + public final fun getAcousticDim ()I + public final fun getCausal ()Z + public final fun getChannels ()I + public final fun getConvWeightNorm ()Z + public final fun getDecoderConvsKernels ()Ljava/util/List; + public final fun getDecoderConvsStrides ()Ljava/util/List; + public final fun getDecoderTransformerLengths ()Ljava/util/List; + public final fun getDecoderWindowSizes ()Ljava/util/List; + public final fun getDim ()I + public final fun getHeadDim ()I + public final fun getHiddenDim ()I + public final fun getInputDim ()I + public final fun getLayerScaleInit ()F + public final fun getNHeads ()I + public final fun getNKVHeads ()I + public final fun getNormEps ()F + public final fun getPatchProjKernelSize ()I + public final fun getPretransformPatchSize ()I + public final fun getQkNorm ()Z + public final fun getQkNormEps ()F + public final fun getSamplingRate ()I + public final fun getSemanticCodebookSize ()I + public final fun getSemanticDim ()I + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralCodecRuntime { + public fun (Ljava/util/Map;Lsk/ainet/models/voxtral/VoxtralCodecMetadata;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;)V + public final fun decode ([I[I)[F +} + +public final class sk/ainet/models/voxtral/VoxtralConfigParser { + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralConfigParser; + public final fun isTiedEmbeddings (Ljava/lang/String;)Z + public final fun parse (Ljava/lang/String;)Lsk/ainet/models/voxtral/VoxtralModelMetadata; + public final fun parseBackbone (Ljava/util/Map;)Lsk/ainet/models/llama/LlamaModelMetadata; +} + +public final class sk/ainet/models/voxtral/VoxtralDefaults { + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralDefaults; + public final fun getACOUSTIC_MODEL ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getAUDIO ()Lsk/ainet/models/voxtral/VoxtralAudioConfig; + public final fun getBACKBONE ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getCODEC ()Lsk/ainet/models/voxtral/VoxtralCodecMetadata; + public final fun getDEFAULT ()Lsk/ainet/models/voxtral/VoxtralModelMetadata; +} + +public final class sk/ainet/models/voxtral/VoxtralFlowMatching { + public fun ()V + public fun (FF)V + public synthetic fun (FFILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun quantizeFSQ (Lsk/ainet/lang/tensor/Tensor;II)[I + public static synthetic fun quantizeFSQ$default (Lsk/ainet/models/voxtral/VoxtralFlowMatching;Lsk/ainet/lang/tensor/Tensor;IIILjava/lang/Object;)[I + public final fun sampleEuler (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIILkotlin/jvm/functions/Function2;Lkotlin/random/Random;)Lsk/ainet/lang/tensor/Tensor; + public static synthetic fun sampleEuler$default (Lsk/ainet/models/voxtral/VoxtralFlowMatching;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIILkotlin/jvm/functions/Function2;Lkotlin/random/Random;ILjava/lang/Object;)Lsk/ainet/lang/tensor/Tensor; + public final fun sampleMidpoint (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIILkotlin/jvm/functions/Function2;Lkotlin/random/Random;)Lsk/ainet/lang/tensor/Tensor; + public static synthetic fun sampleMidpoint$default (Lsk/ainet/models/voxtral/VoxtralFlowMatching;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;IIILkotlin/jvm/functions/Function2;Lkotlin/random/Random;ILjava/lang/Object;)Lsk/ainet/lang/tensor/Tensor; +} + +public final class sk/ainet/models/voxtral/VoxtralGGUFNameResolver : sk/ainet/io/weights/WeightNameResolver { + public fun ()V + public fun resolve (Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralHfTensorNameMapper { + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralHfTensorNameMapper; + public final fun toCanonical (Ljava/lang/String;)Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralModelMetadata { + public fun (Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/voxtral/VoxtralCodecMetadata;Lsk/ainet/models/voxtral/VoxtralAudioConfig;)V + public final fun component1 ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun component2 ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun component3 ()Lsk/ainet/models/voxtral/VoxtralCodecMetadata; + public final fun component4 ()Lsk/ainet/models/voxtral/VoxtralAudioConfig; + public final fun copy (Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/voxtral/VoxtralCodecMetadata;Lsk/ainet/models/voxtral/VoxtralAudioConfig;)Lsk/ainet/models/voxtral/VoxtralModelMetadata; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralModelMetadata;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/models/voxtral/VoxtralCodecMetadata;Lsk/ainet/models/voxtral/VoxtralAudioConfig;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralModelMetadata; + public fun equals (Ljava/lang/Object;)Z + public final fun getAcousticModel ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getAudio ()Lsk/ainet/models/voxtral/VoxtralAudioConfig; + public final fun getBackbone ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getCodec ()Lsk/ainet/models/voxtral/VoxtralCodecMetadata; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader { + public static final field Companion Lsk/ainet/models/voxtral/VoxtralNetworkLoader$Companion; + public fun (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider;Z)V + public synthetic fun (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun buildAcousticRuntime (Lsk/ainet/models/llama/DecoderGgufWeights;Lsk/ainet/lang/nn/Module;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;II)Lsk/ainet/models/voxtral/VoxtralAcousticRuntime; + public final fun getDebug ()Z + public final fun getWeightsProvider ()Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider; +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$Companion { + public final fun fromGguf (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Z)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; + public static synthetic fun fromGguf$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$Companion;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; + public final fun fromGgufRandomAccess (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;Z)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; + public static synthetic fun fromGgufRandomAccess$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$Companion;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ZILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; + public final fun fromSafeTensors (Lsk/ainet/models/llama/LlamaModelMetadata;Lkotlin/jvm/functions/Function0;ZZ)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; + public static synthetic fun fromSafeTensors$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$Companion;Lsk/ainet/models/llama/LlamaModelMetadata;Lkotlin/jvm/functions/Function0;ZZILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; +} + +public abstract interface class sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider { +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufRandomAccess : sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/io/model/QuantPolicy; + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufRandomAccess; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufRandomAccess;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufRandomAccess; + public fun equals (Ljava/lang/Object;)Z + public final fun getQuantPolicy ()Lsk/ainet/io/model/QuantPolicy; + public final fun getRandomAccessProvider ()Lkotlin/jvm/functions/Function0; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufSource : sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/io/model/QuantPolicy; + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufSource; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufSource;Lkotlin/jvm/functions/Function0;Lsk/ainet/io/model/QuantPolicy;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$GgufSource; + public fun equals (Ljava/lang/Object;)Z + public final fun getQuantPolicy ()Lsk/ainet/io/model/QuantPolicy; + public final fun getSourceProvider ()Lkotlin/jvm/functions/Function0; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$Preloaded : sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider { + public fun (Lsk/ainet/models/llama/DecoderGgufWeights;)V + public final fun component1 ()Lsk/ainet/models/llama/DecoderGgufWeights; + public final fun copy (Lsk/ainet/models/llama/DecoderGgufWeights;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$Preloaded; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$Preloaded;Lsk/ainet/models/llama/DecoderGgufWeights;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$Preloaded; + public fun equals (Ljava/lang/Object;)Z + public final fun getWeights ()Lsk/ainet/models/llama/DecoderGgufWeights; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$SafeTensors : sk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider { + public fun (Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V + public final fun component1 ()Lkotlin/jvm/functions/Function0; + public final fun component2 ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun component3 ()Z + public final fun copy (Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;Z)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$SafeTensors; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$SafeTensors;Lkotlin/jvm/functions/Function0;Lsk/ainet/models/llama/LlamaModelMetadata;ZILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider$SafeTensors; + public fun equals (Ljava/lang/Object;)Z + public final fun getMetadata ()Lsk/ainet/models/llama/LlamaModelMetadata; + public final fun getRandomAccessProvider ()Lkotlin/jvm/functions/Function0; + public final fun getTiedEmbeddings ()Z + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralSafeTensorsLoader { + public fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun loadAll (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/voxtral/VoxtralSafeTensorsLoader$VoxtralWeights; +} + +public final class sk/ainet/models/voxtral/VoxtralSafeTensorsLoader$VoxtralWeights { + public fun (Lsk/ainet/models/llama/DecoderGgufWeights;Ljava/util/Map;)V + public final fun component1 ()Lsk/ainet/models/llama/DecoderGgufWeights; + public final fun component2 ()Ljava/util/Map; + public final fun copy (Lsk/ainet/models/llama/DecoderGgufWeights;Ljava/util/Map;)Lsk/ainet/models/voxtral/VoxtralSafeTensorsLoader$VoxtralWeights; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralSafeTensorsLoader$VoxtralWeights;Lsk/ainet/models/llama/DecoderGgufWeights;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralSafeTensorsLoader$VoxtralWeights; + public fun equals (Ljava/lang/Object;)Z + public final fun getAllTensors ()Ljava/util/Map; + public final fun getBackbone ()Lsk/ainet/models/llama/DecoderGgufWeights; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralTensorNames { + public static final field ACOUSTIC_INPUT_PROJ Ljava/lang/String; + public static final field ACOUSTIC_INPUT_PROJ_BIAS Ljava/lang/String; + public static final field ACOUSTIC_LLM_PROJ Ljava/lang/String; + public static final field ACOUSTIC_NORM Ljava/lang/String; + public static final field ACOUSTIC_OUTPUT_PROJ Ljava/lang/String; + public static final field ACOUSTIC_OUTPUT_PROJ_BIAS Ljava/lang/String; + public static final field ACOUSTIC_SEMANTIC_OUTPUT Ljava/lang/String; + public static final field ACOUSTIC_TIME_PROJ Ljava/lang/String; + public static final field CODEC_OUTPUT_PROJ Ljava/lang/String; + public static final field CODEC_OUTPUT_PROJ_BIAS Ljava/lang/String; + public static final field CODEC_OUTPUT_PROJ_G Ljava/lang/String; + public static final field CODEC_OUTPUT_PROJ_V Ljava/lang/String; + public static final field CODEC_SEMANTIC_CODEBOOK Ljava/lang/String; + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralTensorNames; + public static final field OUTPUT_NORM Ljava/lang/String; + public static final field OUTPUT_WEIGHT Ljava/lang/String; + public static final field TOKEN_EMBEDDINGS Ljava/lang/String; + public final fun acousticAttnK (I)Ljava/lang/String; + public final fun acousticAttnNorm (I)Ljava/lang/String; + public final fun acousticAttnOut (I)Ljava/lang/String; + public final fun acousticAttnQ (I)Ljava/lang/String; + public final fun acousticAttnV (I)Ljava/lang/String; + public final fun acousticFfnDown (I)Ljava/lang/String; + public final fun acousticFfnGate (I)Ljava/lang/String; + public final fun acousticFfnNorm (I)Ljava/lang/String; + public final fun acousticFfnUp (I)Ljava/lang/String; + public final fun attnK (I)Ljava/lang/String; + public final fun attnNorm (I)Ljava/lang/String; + public final fun attnOut (I)Ljava/lang/String; + public final fun attnQ (I)Ljava/lang/String; + public final fun attnV (I)Ljava/lang/String; + public final fun codecBlockConvBias (I)Ljava/lang/String; + public final fun codecBlockConvG (I)Ljava/lang/String; + public final fun codecBlockConvV (I)Ljava/lang/String; + public final fun codecBlockConvWeight (I)Ljava/lang/String; + public final fun codecTransformerAttnK (II)Ljava/lang/String; + public final fun codecTransformerAttnNorm (II)Ljava/lang/String; + public final fun codecTransformerAttnOut (II)Ljava/lang/String; + public final fun codecTransformerAttnQ (II)Ljava/lang/String; + public final fun codecTransformerAttnScale (II)Ljava/lang/String; + public final fun codecTransformerAttnV (II)Ljava/lang/String; + public final fun codecTransformerFfnDown (II)Ljava/lang/String; + public final fun codecTransformerFfnGate (II)Ljava/lang/String; + public final fun codecTransformerFfnNorm (II)Ljava/lang/String; + public final fun codecTransformerFfnScale (II)Ljava/lang/String; + public final fun codecTransformerFfnUp (II)Ljava/lang/String; + public final fun codecTransformerKNorm (II)Ljava/lang/String; + public final fun codecTransformerQNorm (II)Ljava/lang/String; + public final fun ffnDown (I)Ljava/lang/String; + public final fun ffnGate (I)Ljava/lang/String; + public final fun ffnNorm (I)Ljava/lang/String; + public final fun ffnUp (I)Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralVoice { + public fun (Ljava/lang/String;[FII)V + public synthetic fun (Ljava/lang/String;[FIIILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()[F + public final fun component3 ()I + public final fun component4 ()I + public final fun copy (Ljava/lang/String;[FII)Lsk/ainet/models/voxtral/VoxtralVoice; + public static synthetic fun copy$default (Lsk/ainet/models/voxtral/VoxtralVoice;Ljava/lang/String;[FIIILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralVoice; + public fun equals (Ljava/lang/Object;)Z + public final fun frameEmbedding (I)[F + public final fun getDim ()I + public final fun getEmbeddings ()[F + public final fun getName ()Ljava/lang/String; + public final fun getNumFrames ()I + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/models/voxtral/VoxtralVoiceLoader { + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralVoiceLoader; + public final fun listAvailable (Ljava/nio/file/Path;)Ljava/util/List; + public final fun load (Ljava/nio/file/Path;I)Lsk/ainet/models/voxtral/VoxtralVoice; + public static synthetic fun load$default (Lsk/ainet/models/voxtral/VoxtralVoiceLoader;Ljava/nio/file/Path;IILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralVoice; + public final fun loadFromDir (Ljava/nio/file/Path;Ljava/lang/String;I)Lsk/ainet/models/voxtral/VoxtralVoice; + public static synthetic fun loadFromDir$default (Lsk/ainet/models/voxtral/VoxtralVoiceLoader;Ljava/nio/file/Path;Ljava/lang/String;IILjava/lang/Object;)Lsk/ainet/models/voxtral/VoxtralVoice; +} + +public final class sk/ainet/models/voxtral/VoxtralVoices { + public static final field DEFAULT Ljava/lang/String; + public static final field INSTANCE Lsk/ainet/models/voxtral/VoxtralVoices; + public final fun filename (Ljava/lang/String;)Ljava/lang/String; + public final fun getPRESETS ()Ljava/util/Map; + public final fun list ()Ljava/util/List; +} + diff --git a/llm-providers/api/llm-providers.api b/llm-providers/api/llm-providers.api index 511b63bd..36861f7c 100644 --- a/llm-providers/api/llm-providers.api +++ b/llm-providers/api/llm-providers.api @@ -1,6 +1,6 @@ public final class sk/ainet/llm/providers/SkaiNetChatModel : sk/ainet/llm/api/StreamingChatModel { - public fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ChatTemplate;Lsk/ainet/llm/api/ChatOptions;IILkotlin/random/Random;Ljava/lang/String;)V - public synthetic fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ChatTemplate;Lsk/ainet/llm/api/ChatOptions;IILkotlin/random/Random;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ChatTemplate;Lsk/ainet/llm/api/ChatOptions;ILjava/util/Set;Lkotlin/random/Random;Ljava/lang/String;)V + public synthetic fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Lsk/ainet/apps/kllama/chat/ChatTemplate;Lsk/ainet/llm/api/ChatOptions;ILjava/util/Set;Lkotlin/random/Random;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun call (Lsk/ainet/llm/api/ChatRequest;)Lsk/ainet/llm/api/ChatResponse; public fun close ()V public fun getDefaultOptions ()Lsk/ainet/llm/api/ChatOptions; diff --git a/llm-runtime/kgemma/api/jvm/kgemma.api b/llm-runtime/kgemma/api/jvm/kgemma.api index 2aacce80..955e8e53 100644 --- a/llm-runtime/kgemma/api/jvm/kgemma.api +++ b/llm-runtime/kgemma/api/jvm/kgemma.api @@ -24,6 +24,44 @@ public final class sk/ainet/apps/kgemma/Gemma3nLoadConfig { public fun toString ()Ljava/lang/String; } +public final class sk/ainet/apps/kgemma/Gemma4ChatModel { + public static final field INSTANCE Lsk/ainet/apps/kgemma/Gemma4ChatModel; + public final fun fromSafeTensors (Ljava/lang/String;Lsk/ainet/context/ExecutionContext;Lsk/ainet/llm/api/ChatOptions;Ljava/lang/String;Z)Lsk/ainet/llm/api/StreamingChatModel; + public static synthetic fun fromSafeTensors$default (Lsk/ainet/apps/kgemma/Gemma4ChatModel;Ljava/lang/String;Lsk/ainet/context/ExecutionContext;Lsk/ainet/llm/api/ChatOptions;Ljava/lang/String;ZILjava/lang/Object;)Lsk/ainet/llm/api/StreamingChatModel; +} + +public final class sk/ainet/apps/kgemma/Gemma4Ingestion { + public fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/apps/kgemma/Gemma4LoadConfig;)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/apps/kgemma/Gemma4LoadConfig;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun buildDslRuntime (Lsk/ainet/models/gemma/Gemma4Weights;)Lsk/ainet/apps/llm/InferenceRuntime; + public final fun load (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadDslRuntime (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadDslRuntimeFromSafeTensors (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadDslRuntimeStreaming (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadFromSafeTensors (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun loadStreaming (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class sk/ainet/apps/kgemma/Gemma4IngestionJvmKt { + public static final fun loadDslRuntimeNative (Lsk/ainet/apps/kgemma/Gemma4Ingestion;Lkotlin/jvm/functions/Function0;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Ljava/lang/foreign/Arena;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun loadDslRuntimeNativeStreaming (Lsk/ainet/apps/kgemma/Gemma4Ingestion;Lkotlin/jvm/functions/Function0;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Ljava/lang/foreign/Arena;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class sk/ainet/apps/kgemma/Gemma4LoadConfig { + public fun ()V + public fun (Lsk/ainet/io/model/QuantPolicy;Z)V + public synthetic fun (Lsk/ainet/io/model/QuantPolicy;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Lsk/ainet/io/model/QuantPolicy; + public final fun component2 ()Z + public final fun copy (Lsk/ainet/io/model/QuantPolicy;Z)Lsk/ainet/apps/kgemma/Gemma4LoadConfig; + public static synthetic fun copy$default (Lsk/ainet/apps/kgemma/Gemma4LoadConfig;Lsk/ainet/io/model/QuantPolicy;ZILjava/lang/Object;)Lsk/ainet/apps/kgemma/Gemma4LoadConfig; + public fun equals (Ljava/lang/Object;)Z + public final fun getAllowQuantized ()Z + public final fun getQuantPolicy ()Lsk/ainet/io/model/QuantPolicy; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + public final class sk/ainet/apps/kgemma/cli/MainKt { public static final fun main ([Ljava/lang/String;)V } diff --git a/llm-runtime/kllama/api/jvm/kllama.api b/llm-runtime/kllama/api/jvm/kllama.api index da2f02d8..b35cb002 100644 --- a/llm-runtime/kllama/api/jvm/kllama.api +++ b/llm-runtime/kllama/api/jvm/kllama.api @@ -1,6 +1,6 @@ public final class sk/ainet/apps/kllama/CpuAttentionBackend : sk/ainet/models/llama/AttentionBackend { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;FLjava/lang/Integer;)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;FLjava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;FLjava/lang/Integer;Lsk/ainet/apps/llm/RopeType;)V + public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;Lsk/ainet/apps/llm/KvCache;FLjava/lang/Integer;Lsk/ainet/apps/llm/RopeType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun attention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; public fun batchAttention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; public fun reset ()V @@ -23,20 +23,6 @@ public final class sk/ainet/apps/kllama/FusedQKVAccelerator : sk/ainet/models/ll public fun runQKV (ILsk/ainet/lang/tensor/Tensor;)Lsk/ainet/models/llama/GraphAccelerator$QKVResult; } -public final class sk/ainet/apps/kllama/GpuAttentionBackend : sk/ainet/models/llama/AttentionBackend { - public fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/apps/kllama/GpuTensorBridge;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;F)V - public synthetic fun (Lsk/ainet/context/ExecutionContext;Lsk/ainet/apps/kllama/GpuTensorBridge;Lsk/ainet/models/llama/LlamaRuntimeWeights;Lkotlin/reflect/KClass;FILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun attention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public fun batchAttention (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor; - public fun reset ()V -} - -public abstract interface class sk/ainet/apps/kllama/GpuTensorBridge { - public abstract fun concat (Ljava/util/List;I)Lsk/ainet/lang/tensor/Tensor; - public abstract fun slice (Lsk/ainet/lang/tensor/Tensor;[I[I[I)Lsk/ainet/lang/tensor/Tensor; - public abstract fun sliceUpdate (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;[I[I[I)Lsk/ainet/lang/tensor/Tensor; -} - public final class sk/ainet/apps/kllama/KvCacheJvmKt { public static final fun createOptimalKvCache (III)Lsk/ainet/apps/llm/KvCache; } @@ -183,6 +169,11 @@ public abstract interface class sk/ainet/apps/kllama/chat/java/JavaTool { public abstract fun getDefinition ()Lsk/ainet/apps/kllama/chat/ToolDefinition; } +public final class sk/ainet/apps/kllama/chat/java/JavaTools { + public static final field INSTANCE Lsk/ainet/apps/kllama/chat/java/JavaTools; + public static final fun definition (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Lsk/ainet/apps/kllama/chat/ToolDefinition; +} + public final class sk/ainet/apps/kllama/cli/AgentCli { public fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Ljava/lang/String;Lsk/ainet/apps/kllama/chat/ModelMetadata;)V public synthetic fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;Ljava/lang/String;Lsk/ainet/apps/kllama/chat/ModelMetadata;ILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -251,8 +242,8 @@ public final class sk/ainet/apps/kllama/java/KLlamaJava { } public final class sk/ainet/apps/kllama/java/KLlamaSession : java/lang/AutoCloseable { - public fun (Lsk/ainet/models/llama/LlamaRuntime;Lsk/ainet/apps/llm/Tokenizer;ILjava/lang/String;Ljava/lang/Runnable;)V - public synthetic fun (Lsk/ainet/models/llama/LlamaRuntime;Lsk/ainet/apps/llm/Tokenizer;ILjava/lang/String;Ljava/lang/Runnable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;ILjava/lang/String;Ljava/lang/Runnable;)V + public synthetic fun (Lsk/ainet/apps/llm/InferenceRuntime;Lsk/ainet/apps/llm/Tokenizer;ILjava/lang/String;Ljava/lang/Runnable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close ()V public final fun generate (Ljava/lang/String;)Ljava/lang/String; public final fun generate (Ljava/lang/String;Lsk/ainet/apps/kllama/java/GenerationConfig;)Ljava/lang/String; @@ -264,8 +255,3 @@ public final class sk/ainet/apps/kllama/java/KLlamaSession : java/lang/AutoClose public final fun getSystemPrompt ()Ljava/lang/String; } -public final class sk/ainet/apps/kllama/java/LlamaIngestionBlocking { - public static final fun loadBlocking (Lsk/ainet/apps/kllama/LlamaIngestion;Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaRuntimeWeights; - public static final fun loadStreamingBlocking (Lsk/ainet/apps/kllama/LlamaIngestion;Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaRuntimeWeights; -} - diff --git a/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt b/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt index d3aa4ed3..58fe21bc 100644 --- a/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt +++ b/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt @@ -1,15 +1,16 @@ package sk.ainet.apps.kllama.cli import sk.ainet.apps.kllama.GGUFTokenizer +import sk.ainet.apps.kllama.LlamaIngestion +import sk.ainet.apps.kllama.LlamaLoadConfig import sk.ainet.models.llama.LlamaConfigParser import sk.ainet.apps.llm.OptimizedLLMMode import sk.ainet.apps.llm.OptimizedLLMRuntime import sk.ainet.apps.llm.Tokenizer import sk.ainet.apps.llm.tokenizer.TokenizerFactory import sk.ainet.models.llama.DecoderGgufMemSegConverter -import sk.ainet.models.llama.DecoderSafeTensorsLoader -import sk.ainet.models.llama.LlamaNetworkLoader import sk.ainet.models.llama.LlamaRuntime +import sk.ainet.models.llama.MemSegWeightConverter import sk.ainet.apps.kllama.CpuAttentionBackend import sk.ainet.apps.kllama.Llama2DotCWeightLoader import sk.ainet.models.qwen.QwenNetworkLoader @@ -402,86 +403,69 @@ fun main(args: Array) { bos = convertedWeights.metadata.bosTokenId, ) eosTokenId = convertedWeights.metadata.eosTokenId - } else if (format == ModelFormat.GGUF) { - // --- Llama / Mistral GGUF: DSL path. Mirrors the Qwen branch - // above. LlamaNetworkLoader builds llamaNetwork() (RoPE - // INTERLEAVED, no QK-norm — the LLaMA family default). - val loader = DecoderGgufWeightLoader( - randomAccessProvider = { JvmRandomAccessSource.open(modelPath.toString()) }, - quantPolicy = QuantPolicy.NATIVE_OPTIMIZED, - acceptedArchitectures = LLAMA_COMPATIBLE_ARCHITECTURES, - ) - println("Loading GGUF model from $modelPath (Llama, DSL streaming mode)...") - val rawWeights = loader.loadToMapStreaming(ctx) - - val convertedWeights = if (rawWeights.quantTypes.isNotEmpty()) { - println("Converting ${rawWeights.quantTypes.size} quantized tensors to MemorySegment-backed SIMD format...") - DecoderGgufMemSegConverter.convert(rawWeights, ctx, quantArena) - } else { - rawWeights - } - - if (cliArgs.contextLength != null) { - println("Context length capped to ${cliArgs.contextLength} (model default: ${convertedWeights.metadata.contextLength})") - } - val llamaModel = LlamaNetworkLoader.fromWeights(convertedWeights) - runtime = OptimizedLLMRuntime( - model = llamaModel, - ctx = ctx, - mode = OptimizedLLMMode.DIRECT, - dtype = FP32::class, - bos = convertedWeights.metadata.bosTokenId, - ) - eosTokenId = convertedWeights.metadata.eosTokenId - binVocabSize = convertedWeights.metadata.vocabSize - } else if (format == ModelFormat.SAFETENSORS) { - // --- Llama SafeTensors: DSL path via DecoderSafeTensorsLoader - // (HF tensor names → GGUF-canonical names, BF16/F16 → FP32). - val modelDir = resolveModelDir(modelPath) - val safetensorsFile = if (modelPath.isDirectory()) { - modelDir.resolve("model.safetensors") - } else { - modelPath - } - val configFile = modelDir.resolve("config.json") - if (!configFile.exists()) error("config.json not found in $modelDir") - - println("Loading SafeTensors model from $safetensorsFile...") - val configJson = configFile.readText() - val safeMetadata = LlamaConfigParser.parse(configJson) - val tiedEmbeddings = LlamaConfigParser.isTiedEmbeddings(configJson) - println(" Architecture: ${safeMetadata.architecture}, layers=${safeMetadata.blockCount}, " + - "dim=${safeMetadata.embeddingLength}, heads=${safeMetadata.headCount}, " + - "kvHeads=${safeMetadata.kvHeadCount}, vocab=${safeMetadata.vocabSize}") - if (tiedEmbeddings) println(" Tied word embeddings: output.weight = embed_tokens.weight") - - val safeLoader = DecoderSafeTensorsLoader(ctx, FP32::class, safeMetadata, tiedEmbeddings) - val safeWeights = safeLoader.loadToMap { - JvmRandomAccessSource.open(safetensorsFile.toString()) - } - - if (cliArgs.contextLength != null) { - println("Context length capped to ${cliArgs.contextLength} (model default: ${safeWeights.metadata.contextLength})") - } - val llamaModel = LlamaNetworkLoader.fromWeights(safeWeights) - runtime = OptimizedLLMRuntime( - model = llamaModel, - ctx = ctx, - mode = OptimizedLLMMode.DIRECT, - dtype = FP32::class, - bos = safeWeights.metadata.bosTokenId, - ) - eosTokenId = safeWeights.metadata.eosTokenId - binVocabSize = safeWeights.metadata.vocabSize } else { - // --- BIN (Karpathy llama2.c format): legacy LlamaRuntime path. - // The .bin loader returns LlamaRuntimeWeights directly; the DSL - // path requires DecoderGgufWeights, so this format stays on - // legacy until either Llama2DotCWeightLoader is migrated or - // .bin support is dropped. - println("Loading Karpathy .bin model from $modelPath...") - val runtimeWeights = modelPath.inputStream().use { input -> - Llama2DotCWeightLoader.load(ctx, input.asSource().buffered()) + // --- Llama / SafeTensors / BIN: legacy LlamaRuntime path. + // The DSL path is functionally correct but ~8x slower for Q8/Q4 + // GGUFs because every linearProject forward calls ops.transpose + // on packed quant weights through a generic dispatch (the DSL + // doesn't yet have first-class Q4/Q8 DTypes). Until that lands, + // run Llama through the legacy LlamaRuntime + CpuAttentionBackend + // + MemSegWeightConverter path that previously hit ~2 t/s. + // Qwen GGUF stays on the DSL branch above. + val runtimeWeights = when (format) { + ModelFormat.GGUF -> { + val ingestion = LlamaIngestion( + ctx = ctx, + dtype = FP32::class, + config = LlamaLoadConfig( + quantPolicy = QuantPolicy.NATIVE_OPTIMIZED, + allowQuantized = true, + acceptedArchitectures = LLAMA_COMPATIBLE_ARCHITECTURES, + ), + ) + println("Loading GGUF model from $modelPath (Llama, eager streaming mode)...") + val rawWeights = ingestion.loadStreaming { + JvmRandomAccessSource.open(modelPath.toString()) + } + if (rawWeights.quantTypes.isNotEmpty()) { + println("Converting ${rawWeights.quantTypes.size} quantized tensors to MemorySegment-backed SIMD format...") + MemSegWeightConverter.convert(rawWeights, ctx, quantArena) + } else { + rawWeights + } + } + ModelFormat.SAFETENSORS -> { + val modelDir = resolveModelDir(modelPath) + val safetensorsFile = if (modelPath.isDirectory()) { + modelDir.resolve("model.safetensors") + } else { + modelPath + } + val configFile = modelDir.resolve("config.json") + if (!configFile.exists()) error("config.json not found in $modelDir") + + println("Loading SafeTensors model from $safetensorsFile...") + val configJson = configFile.readText() + val safeMetadata = LlamaConfigParser.parse(configJson) + val tiedEmbeddings = LlamaConfigParser.isTiedEmbeddings(configJson) + println(" Architecture: ${safeMetadata.architecture}, layers=${safeMetadata.blockCount}, " + + "dim=${safeMetadata.embeddingLength}, heads=${safeMetadata.headCount}, " + + "kvHeads=${safeMetadata.kvHeadCount}, vocab=${safeMetadata.vocabSize}") + if (tiedEmbeddings) println(" Tied word embeddings: output.weight = embed_tokens.weight") + + val ingestion = LlamaIngestion(ctx = ctx, dtype = FP32::class) + ingestion.loadSafeTensors( + randomAccessProvider = { JvmRandomAccessSource.open(safetensorsFile.toString()) }, + metadata = safeMetadata, + tiedEmbeddings = tiedEmbeddings, + ) + } + ModelFormat.BIN -> { + println("Loading Karpathy .bin model from $modelPath...") + modelPath.inputStream().use { input -> + Llama2DotCWeightLoader.load(ctx, input.asSource().buffered()) + } + } } if (cliArgs.contextLength != null) { diff --git a/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/ToolCallingDemo.kt b/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/ToolCallingDemo.kt index 89b0aa11..2167729c 100644 --- a/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/ToolCallingDemo.kt +++ b/llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/ToolCallingDemo.kt @@ -67,6 +67,13 @@ public class ToolCallingDemo( println("Prompt: \"$prompt\"") println("---") + println("[Tools] (${registry.definitions().size})") + for (def in registry.definitions()) { + println(" - ${def.name}: ${def.description}") + println(" schema: ${def.parameters}") + } + println("---") + val agentLoop = session.createAgentLoop(registry, maxTokens, temperature) val systemPrompt = """You are a helpful assistant with access to tools. @@ -79,12 +86,32 @@ Always use a tool when one is relevant — do not guess file listings.""" ChatMessage(role = ChatRole.USER, content = prompt) ) + // Render and dump the exact prompt the model will see on round 1, so the + // smoke-test log shows the formatted system+tools+user payload. + val renderedPrompt = session.chatTemplate.apply( + messages = messages, + tools = registry.definitions(), + addGenerationPrompt = true + ) + println("[Prompt → Round 1] (${renderedPrompt.length} chars)") + println("┌──────────────────────────────────────────────────────────────────────┐") + renderedPrompt.lineSequence().forEach { println("│ $it") } + println("└──────────────────────────────────────────────────────────────────────┘") + + var roundIdx = 0 val listener = object : AgentListener { override fun onToken(token: String) { print(token) System.out.flush() } - override fun onAssistantMessage(text: String) { println() } + override fun onAssistantMessage(text: String) { + println() + roundIdx += 1 + println("[Raw Assistant → Round $roundIdx] (${text.length} chars)") + println("┌──────────────────────────────────────────────────────────────────────┐") + text.lineSequence().forEach { println("│ $it") } + println("└──────────────────────────────────────────────────────────────────────┘") + } override fun onToolCalls(calls: List) { for (call in calls) println("[Tool Call] ${call.name}(${call.arguments})") } @@ -93,6 +120,9 @@ Always use a tool when one is relevant — do not guess file listings.""" print("Assistant: ") System.out.flush() } + override fun onToolCallValidationFailed(call: ToolCall, reason: String) { + println("[Tool Call Invalid] ${call.name}(${call.arguments}) -> $reason") + } override fun onComplete(finalResponse: String) {} } @@ -105,6 +135,20 @@ Always use a tool when one is relevant — do not guess file listings.""" listener = listener ) println() + + // Dump the final conversation — exposes the prompts for any + // post-round-1 generation by showing how the message list grew. + println("[Final Conversation] (${messages.size} messages)") + for ((i, msg) in messages.withIndex()) { + val tag = msg.role.name.lowercase() + val calls = msg.toolCalls + val toolSuffix = if (!calls.isNullOrEmpty()) { + " toolCalls=${calls.joinToString { "${it.name}(${it.arguments})" }}" + } else "" + val toolIdSuffix = msg.toolCallId?.let { " toolCallId=$it" } ?: "" + val body = msg.content.replace("\n", "\\n").take(400) + println(" [$i] $tag$toolIdSuffix$toolSuffix: $body") + } } /** diff --git a/tests/smoke/smoke-models.json b/tests/smoke/smoke-models.json index fb6f2e02..ff87a7ef 100644 --- a/tests/smoke/smoke-models.json +++ b/tests/smoke/smoke-models.json @@ -11,6 +11,18 @@ "model": "tinyllama-1.1b-chat-v1.0.Q8_0.gguf", "format": "gguf" }, + { + "name": "Llama-3.2-1B-Instruct-Q8", + "runner": "kllama", + "model": "~/.cache/standapp/models/Llama-3.2-1B-Instruct-Q8_0.gguf", + "format": "gguf", + "instruct": true, + "prompt": "What is the capital of France?", + "toolCalling": { + "prompt": "What is 2 + 2?", + "steps": 256 + } + }, { "name": "Qwen3-1.7B-Q8", "runner": "kllama", @@ -49,6 +61,14 @@ "format": "gguf", "prompt": "The quick brown fox jumps over the lazy dog", "doc": "A pangram is a sentence that contains every letter of the alphabet." + }, + { + "name": "MongoDB-mdbr-leaf-ir", + "runner": "kbert", + "model": "~/.cache/huggingface/hub/models--MongoDB--mdbr-leaf-ir/snapshots/1bb4fc387c49dee1c10c2b22f59db758be87dcaa", + "format": "safetensors", + "prompt": "MongoDB is a NoSQL database", + "doc": "MongoDB stores data in BSON documents" } ] } diff --git a/tests/smoke/smoke-test.sh b/tests/smoke/smoke-test.sh index 218f09f3..7d114473 100755 --- a/tests/smoke/smoke-test.sh +++ b/tests/smoke/smoke-test.sh @@ -238,10 +238,13 @@ print(f'M_INSTRUCT={repr(m.get(\"instruct\", False))}') fail=$((fail + 1)) results+=("FAIL|$M_NAME|$M_RUNNER|$model_size|-|${wall_sec}s") else - tps=$(grep -oE 'tok/s: [0-9.]+' "$output_file" | grep -oE '[0-9.]+' | tail -1) + # `set -euo pipefail` makes a no-match grep fatal; embedding models + # (kbert) don't print tok/s, so allow the substitution to come back + # empty and fall through to "?". + tps=$(grep -oE 'tok/s: [0-9.]+' "$output_file" 2>/dev/null | grep -oE '[0-9.]+' | tail -1 || true) tps=${tps:-"?"} echo -e " ${GREEN}OK${RESET} tok/s: ${CYAN}${tps}${RESET} wall: ${wall_sec}s" - sed -n '/^---$/,/^---$/p' "$output_file" | grep -v '^---$' | head -3 | sed 's/^/ │ /' + sed -n '/^---$/,/^---$/p' "$output_file" | grep -v '^---$' | head -3 | sed 's/^/ │ /' || true pass=$((pass + 1)) results+=("OK|$M_NAME|$M_RUNNER|$model_size|$tps|${wall_sec}s") fi