diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4a6bc331..060f2f82 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -56,6 +56,25 @@ jobs: # Cross-compile jobs (Docker / dockcross) — produce release artifacts, no testing # --------------------------------------------------------------------------- + code-style: + name: Code style (spotless) + package graph + needs: startgate + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-java@v5 + with: + java-version: '21' + distribution: temurin + - name: Spotless check (fail fast on format violations) + run: mvn -B --no-transfer-progress spotless:check + - name: Print internal package dependency graph (jdeps, informational) + continue-on-error: true + run: | + mvn -B --no-transfer-progress -DskipTests -Denforcer.skip=true compile + echo "=== internal package dependency graph (jdeps, bytecode) ===" + jdeps -verbose:package target/classes | grep 'net.ladenthin.llama' || true + crosscompile-linux-x86_64-cuda: name: Cross-Compile manylinux_2_28 x86_64 (CUDA) needs: startgate @@ -794,7 +813,7 @@ jobs: format: jacoco continue-on-error: true - name: Codecov - uses: codecov/codecov-action@v6 + uses: codecov/codecov-action@v7 with: token: ${{ secrets.CODECOV_TOKEN }} files: target/site/jacoco/jacoco.xml @@ -822,7 +841,7 @@ jobs: publish-snapshot: name: Publish Snapshot to Central - needs: [check-snapshot, crosscompile-linux-x86_64-cuda, crosscompile-android-aarch64-opencl] + needs: [check-snapshot, crosscompile-linux-x86_64-cuda, crosscompile-android-aarch64-opencl, code-style] if: needs.check-snapshot.result == 'success' runs-on: ubuntu-latest environment: maven-central @@ -898,7 +917,7 @@ jobs: publish-release: name: Publish Release to Central if: needs.check-tag.result == 'success' - needs: [check-tag, crosscompile-linux-x86_64-cuda, crosscompile-android-aarch64-opencl] + needs: [check-tag, crosscompile-linux-x86_64-cuda, crosscompile-android-aarch64-opencl, code-style] runs-on: ubuntu-latest environment: maven-central permissions: diff --git a/CLAUDE.md b/CLAUDE.md index 4899a5f9..56680cb2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,7 +6,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co Java bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) via JNI, providing a high-level API for LLM inference in Java. The Java layer communicates with a native C++ library through JNI. -Current llama.cpp pinned version: **b9543** +Current llama.cpp pinned version: **b9549** ## Upgrading CUDA Version @@ -303,6 +303,49 @@ be exercised either in CI (via `.github/workflows/publish.yml`) or on a developer machine with HF access; pre-staged models can also be uploaded into `models/` out-of-band. +**Verifying the native library *loads* without models (model-free smoke).** +Even with HuggingFace blocked you can still do the one piece of *real native* +verification that does not need a GGUF: confirm the library loads and its +`JNI_OnLoad` resolves every Java class it looks up by name. The model-gated +tests cannot do this in a restricted sandbox — they self-skip via +`Assume.assumeTrue(model present)` **before** the lib is ever loaded, so a plain +`mvn test` is silent on load-time breakage. The full local recipe: + +```bash +# 1. Build the native lib locally (FetchContent pulls llama.cpp from GitHub, +# which is reachable even when huggingface.co is not): +mvn -q compile +cmake -B build -DBUILD_TESTING=ON +cmake --build build --config Release -j$(nproc) # -> src/main/resources/...///libjllama.so +# 2. Force LlamaModel. (System.load -> JNI_OnLoad) with no model: +mvn test -Dtest=NativeLibraryLoadSmokeTest +``` + +`NativeLibraryLoadSmokeTest` (in the `loader` package) calls +`Class.forName("net.ladenthin.llama.LlamaModel")`, which runs +`LlamaLoader.initialize() -> System.load() -> JNI_OnLoad`, which in turn calls +`FindClass(...)` for every JNI-referenced Java class. It **passes** when the lib +loads cleanly, **fails** if the native-resource path in `LlamaLoader` is wrong +(lib not found) or a `FindClass`/field-signature FQN in +`src/main/cpp/jllama.cpp` is stale after a Java package move (lib loads but +`JNI_OnLoad` throws `NoClassDefFoundError: net/ladenthin/llama/...`), and +**self-skips** when `libjllama` is not on the classpath (pure-Java checkout, no +CMake build) so it never breaks a build-less `mvn test`. + +Both of those failure modes shipped on a branch once — the layered-package +restructure left (a) `LlamaLoader.getNativeResourcePath()` deriving the resource +root from the loader's own package (which moved to `…loader`) and (b) +`jllama.cpp` still `FindClass`-ing the old flat paths — and neither was visible +to a local `mvn test` (model tests skipped) or to the pure-Java unit tests. +**When you move a Java class the JNI layer references by name** (`LlamaModel` +[root], `exception.LlamaException`, `value.LogLevel`, `args.LogFormat`, +`callback.LoadProgressCallback`), update the matching `FindClass` / `"L…;"` +signature string in `src/main/cpp/jllama.cpp` and keep the native-resource root +anchored at `net/ladenthin/llama/` in `LlamaLoader.NATIVE_RESOURCE_BASE` (it must +not track the loader's own Java package). This is the same +"FQN/path not updated after a package move" class as the stale +`spotbugs-exclude.xml`, PIT `targetClasses`, and `CMakeLists.txt` OSInfo repairs. + ### Code Formatting ```bash clang-format -i src/main/cpp/*.cpp src/main/cpp/*.hpp # Format C++ code diff --git a/CMakeLists.txt b/CMakeLists.txt index 605f47a3..0391d119 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,7 +114,7 @@ set(LLAMA_BUILD_APP OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b9543 + GIT_TAG b9549 ) FetchContent_MakeAvailable(llama.cpp) @@ -159,7 +159,7 @@ if(NOT DEFINED OS_NAME) find_package(Java REQUIRED) find_program(JAVA_EXECUTABLE NAMES java) execute_process( - COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes net.ladenthin.llama.OSInfo --os + COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes net.ladenthin.llama.loader.OSInfo --os OUTPUT_VARIABLE OS_NAME OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -177,7 +177,7 @@ if(NOT DEFINED OS_ARCH) find_package(Java REQUIRED) find_program(JAVA_EXECUTABLE NAMES java) execute_process( - COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes net.ladenthin.llama.OSInfo --arch + COMMAND ${JAVA_EXECUTABLE} -cp ${CMAKE_SOURCE_DIR}/target/classes net.ladenthin.llama.loader.OSInfo --arch OUTPUT_VARIABLE OS_ARCH OUTPUT_STRIP_TRAILING_WHITESPACE ) diff --git a/README.md b/README.md index 912de2c1..12ee7bdd 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ **Build:** ![Java 8+](https://img.shields.io/badge/Java-8%2B-informational) ![Platform](https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows%20%7C%20Android-lightgrey) -[![llama.cpp b9543](https://img.shields.io/badge/llama.cpp-%23b9543-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b9543) +[![llama.cpp b9549](https://img.shields.io/badge/llama.cpp-%23b9549-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b9549) [![JPMS](https://img.shields.io/badge/JPMS-modular%20JAR-25A162)](https://openjdk.org/projects/jigsaw/) ![JUnit](https://img.shields.io/badge/tested%20with-JUnit6-25A162) [![JSpecify](https://img.shields.io/badge/JSpecify-1.0.0%20%40NullMarked-25A162)](https://jspecify.dev) diff --git a/TODO.md b/TODO.md index 3e691551..4cb2f7be 100644 --- a/TODO.md +++ b/TODO.md @@ -69,12 +69,41 @@ These are JNI plumbing items for upstream API additions. Policy: add only after (`07109cc`): 25 sites. The same rule is suppressed in BAF (`52c8c95`) for identical reasons. -- **Additional ArchUnit rules to consider** — layered-architecture rules (`layeredArchitecture().consideringAllDependencies()`), per-module banned-imports lists, public-API-surface constraints (no public mutable static state, etc.). Partial progress: `7b6667d` covers the "no public field that is not final" sub-rule. +- **Additional ArchUnit rules to consider** — the full **`layeredArchitecture()`** rule and a **per-module banned-import** rule (`jacksonBannedFromContractsAndLoader` — Jackson kept out of `args`/`callback`/`exception`/`loader`) are now DONE. Still open: more per-module banned-imports if useful, public-API-surface constraints (no public mutable static state, etc.). Partial progress: `7b6667d` covers the "no public field that is not final" sub-rule. - **Cross-repo code-quality TODOs** — see [`../workspace/policies/code-quality-todos.md`](../workspace/policies/code-quality-todos.md) for the canonical `@VisibleForTesting` design-fit review, package hierarchy review, and class/method naming review. This repo has no `@VisibleForTesting` usages today; package and naming reviews remain open. ## Done (kept for history) +### Layered package restructure (flat root package → layered hierarchy) + +The flat `net.ladenthin.llama` root package was split (via `git mv`, history +preserved) into layered packages so boundaries align with the layers, enforced +by a new `layeredArchitecture()` ArchUnit rule (Api → Loader → Marshalling → +Foundation): + +- **Foundation**: `value` (18 DTOs: ChatMessage, ContentPart, Pair, LlamaOutput, + …), `callback` (CancellationToken, LoadProgressCallback, ToolHandler), + `exception` (LlamaException, ModelUnavailableException), `args` (existing leaf). +- **Marshalling**: `json` (response parsers + `TimingsLogger`, its only consumer), + `parameters` (Inference/Model/Json/Cli parameters + `ParameterJsonSerializer` + + `ChatRequest`). +- **Loader** (internal, NOT exported): `loader` (LlamaLoader, OSInfo, + ProcessRunner, NativeLibraryPermissionSetter, Java8CompatibilityHelper, + SkipDownloadFailureTranslator, LlamaSystemProperties). +- **Api** (root): LlamaModel, Session, LlamaIterable, LlamaIterator. + +Cycle-breaking moves: `TimingsLogger` root→`json`, `ParameterJsonSerializer` +`json`→`parameters`, `ChatRequest` root→`parameters` (it carries an +`InferenceParameters` customizer). Test classes mirrored into their subjects' +packages; cross-layer members promoted to `public`. Cross-package Javadoc +`{@link}` references fully-qualified (palantir's `removeUnusedImports` strips +javadoc-only imports). `module-info` exports the new public-API packages and +keeps `loader` internal. All 11 ArchUnit rules green; `javadoc:jar` clean. + +**Breaking change**: public-API FQNs changed (e.g. `net.ladenthin.llama.ChatMessage` +→ `net.ladenthin.llama.value.ChatMessage`) — ship under a major version bump. + - **Reactive `LlamaPublisher` removed in favour of consumer-side adapters.** The hand-rolled `LlamaPublisher` + `LlamaModel.streamPublisher` / `streamChatPublisher` (shipped in PR #188 as §2.3 of the Kotlin SDK @@ -95,7 +124,7 @@ These are JNI plumbing items for upstream API additions. Policy: add only after - **`javac -Werror` + `-Xlint:all,-serial,-options,-classfile,-processing`** — `3e2efbb`. ~20 EP warnings addressed first (EqualsGetClass on `Pair` via instanceof; MissingOverride on `PoolingType` / `RopeScalingType`; JdkObsolete `LinkedList` → `ArrayList` in `LlamaLoader`; StringSplitter inline-suppressed; 3× StringCaseLocaleUsage `Locale.ROOT` in `OSInfo`; EmptyCatch in `OSInfo.isAlpineLinux`; FutureReturnValueIgnored in `LlamaModel.completeAsync`; Finalize on `LlamaModel.finalize`; MixedMutabilityReturnType in 4 parser methods; EnumOrdinal in `InferenceParameters.setMiroStat`; EscapedEntity in `InferenceParameters` javadoc; 4× TypeParameterUnusedInFormals; AnnotateFormatMethod on `Java8CompatibilityHelper.formatted`; SafeVarargs + varargs on `Java8CompatibilityHelper.listOf`). - **`-parameters` javac arg** — `4350cf2`. - **`--release N`** — `4350cf2` (`8`). -- **Mutation-testing threshold enforcement (PIT)** — `62f8a00` + `bb93a8f` (docs) + `3bfa51f` (README badge). "Single class, full plumbing" pattern: PIT runs every CI build with `100`, `` narrowed to `net.ladenthin.llama.Pair`. +- **Mutation-testing threshold enforcement (PIT)** — `62f8a00` + `bb93a8f` (docs) + `3bfa51f` (README badge). Runs every CI build with `100`. **Scope expanded 2026-06-07** from the original single `Pair` target (which was stale after the restructure — `llama.Pair`→`value.Pair` matched nothing) to `value.*` + `exception.*` + `args.*` + `json.TimingsLogger` = 27 classes / 163 mutations, all killed. Still open (optional): `json.ChatResponseParser` / `CompletionResponseParser` private-helper survivors (`RerankResponseParser` is excluded — equivalent empty-list mutant). - **Checker Framework as a second static-nullness pass** — `c63870b`. The original `@PolyNull` on `JsonParameters.toJsonString` was simplified to plain `@Nullable` (the only `@PolyNull` site in production; eliminated in a later cleanup). diff --git a/docs/history/llama-cpp-breaking-changes.md b/docs/history/llama-cpp-breaking-changes.md index 1d370964..0700c8db 100644 --- a/docs/history/llama-cpp-breaking-changes.md +++ b/docs/history/llama-cpp-breaking-changes.md @@ -312,3 +312,12 @@ Used during `llama.cpp` version bumps: when upgrading, scan this file from the r | ~b9495–b9543 | `ggml/src/ggml-cuda/mmvq.cu` + `ggml/src/ggml-cpu/arch/{riscv,wasm}/quants.c` + `ggml/src/ggml-metal/ggml-metal-device.m` + `ggml/src/ggml-opencl/*` + `ggml/src/ggml-sycl/*` + `ggml/src/ggml-vulkan/*` + `ggml/src/ggml-webgpu/*` + `ggml/src/ggml-cpu/kleidiai/kleidiai.cpp` | Per-backend numerical & performance work: (1) CUDA `mul_mat_vec_q_moe` switched to `GGML_CUDA_RESTRICT` aliasing + PDL launch params for Hopper. (2) RISC-V Vector quants: dispatch-by-VL refactor (`vl128` / `vl256` / `vl512` / `vl1024` separate kernels for Q2_K, Q3_K, Q4_K, Q6_K, IQ1_S, IQ1_M, IQ2_S, IQ2_XS, IQ3_S, IQ3_XXS, IQ4_XS, TQ1_0, TQ2_0). (3) WebAssembly SIMD path for Q4_1. (4) Metal residency-set keep-alive polling interval tightened to 5 ms (was 500 ms). (5) OpenCL Adreno: faster `concat`/`cpy`/`get_rows` packed kernels for narrow tensors (`<32` cols); Q6_K mat-vec rewritten with vec4 weight gather. (6) SYCL: multi-column MMVQ paths added for all quant types (ncols=2..8) used by speculative decoding's draft verification batches; `should_reorder_tensor` gate widened from `ne[1]==1` to `ne[1]<=8`. (7) Vulkan: NV cooperative-matrix2 feature detection now requires every `coopmat2_features.*` bit; FWHT shader gains shmem fallback (Intel Windows driver bug workaround). (8) WebGPU: flash-attention split into vector / tile / subgroup-matrix variants with K/V quantization-aware staging (`U32_DEQUANT_HELPERS`); GRANITE_SPEECH bumped to multi-projector. (9) KleidiAI: env vars `GGML_KLEIDIAI_CHUNK_MULTIPLIER` & `GGML_KLEIDIAI_SME` thread-cap auto-detect; SME + non-SME hybrid scheduling. All purely backend-internal; project compiles backends through FetchContent with no API surface change visible to `jllama.cpp`. No project source changes required | | ~b9495–b9543 | `conversion/__init__.py` + `conversion/granite.py` + `conversion/gemma.py` + `convert_lora_to_gguf.py` + `gguf-py/gguf/{constants,tensor_mapping,gguf_writer}.py` | Python-side: new `Granite4VisionMmprojModel` (vision-projector for Granite4 with QFormer-window deepstack + per-projector spatial offsets + image-grid pinpoints); Gemma4 unified vision/audio conversion fix-ups for newer HF checkpoints (`hidden_size` falls back to `audio_embed_dim`; `model_patch_size` falls back to `patch_size * pooling_kernel_size`). `convert_lora_to_gguf.py` gained `--trust-remote-code`. New `LLM_KV_DEEPSTACK_MAPPING` writer (`add_deepstack_mapping`) and new clip-vision keys (`KEY_PROJ_SAMPLE_QUERY_SIDE`, `KEY_PROJ_SAMPLE_WINDOW_SIDE`, `KEY_PROJ_SPATIAL_OFFSETS`, `KEY_FEATURE_LAYERS`, `KEY_IMAGE_GRID_PINPOINTS`) for the Granite4 vision projector. Python-side only; no impact on the Java/JNI build. No project source changes required | | ~b9495–b9543 | upstream build / verification | Local build pending: the b9495 → b9543 bump is expected to compile cleanly given the audit above (zero `grep` matches in `src/main/cpp/` for any of the renamed or removed symbols: `hparams.n_layer`, `nextn_predict_layers`, `n_layer_nextn`, `n_layer_all`, `LLAMA_STATE_SEQ_FLAGS_ON_DEVICE`, `clip_image_u8`/`clip_image_f32` field access, `clip_build_img_from_pixels`, `clip_get_newline_tensor`, `clip_image_u8_get_data`, `clip_embd_nbytes`, `clip_embd_nbytes_by_img`, `clip_encode_float_image`, `clip_image_f32_batch_add_mel`, `mtmd_helper_bitmap_init_from_file`, `mtmd_helper_bitmap_init_from_buf`, `common_imatrix_load`). The only project-visible signature change — `process_mtmd_prompt()`'s new `bool is_placeholder` parameter — is defaulted, so existing call sites inside the project compile unchanged. All breaking changes in this range are absorbed inside upstream-compiled translation units; no project source edits required for the version bump itself | +| ~b9543–b9549 | `include/llama.h` + `src/llama-context.{h,cpp}` + `src/llama-cparams.h` + `src/llama-ext.h` | New `llama_context_params::ctx_other` field (a source/target/parent `llama_context *`, default `nullptr`) used to share results or `llama_memory` between two contexts; mirrored by new `cparams.ctx_other` and the new staging API `llama_get_ctx_other()` (`llama-ext.h`). `llama_get_memory()` was moved earlier in `llama-context.cpp` and made null-safe (returns `nullptr` for a null ctx). `llama_context_default_params()` initializes `ctx_other = nullptr`. Project does not aggregate-init `llama_context_params` (it goes through `llama_context_default_params()` inside upstream `server-context.cpp`) and never includes `llama-ext.h` — verified via `grep -rn "llama_context_params\|ctx_other\|llama-ext.h\|llama_get_ctx_other\|llama_get_memory" src/main/cpp/` returns zero matches. No project source changes required | +| ~b9543–b9549 | `src/llama-kv-cache.{h,cpp}` + `llama-kv-cache-iswa.{h,cpp}` + `llama-kv-cache-dsa.cpp` + `llama-memory.h` + `llama-memory-hybrid{,-iswa}.cpp` | KV-cache constructors gained two new parameters: `llama_memory_t mem_other` and `layer_share_cb share` (`std::function` returning the source layer index to share cells from, or negative to skip). Enables one context's KV cache to share cells with another's (used by the new Gemma4-assistant MTP head). `llama_memory_params` gained a `mem_other` field. All call sites (iswa/dsa/hybrid wrappers, `llama_model::create_memory`) updated upstream; the project never constructs a `llama_kv_cache*` or `llama_memory_*` directly. No project source changes required | +| ~b9543–b9549 | `src/llama-arch.{h,cpp}` + new `src/models/gemma4-assistant.cpp` + `src/models/models.h` + `src/llama-model.{h,cpp}` + `src/llama-hparams.{h,cpp}` + `src/llama-graph.{h,cpp}` + `gguf-py/` + `conversion/gemma.py` | **New model architecture `LLM_ARCH_GEMMA4_ASSISTANT` ("gemma4-assistant")** — a NextN/MTP draft "assistant" head that shares the target Gemma4's KV cache and reads its post-final-norm hidden state. New tensors `LLM_TENSOR_NEXTN_PROJ_PRE`/`NEXTN_PROJ_POST` (`nextn.pre_projection`/`post_projection`) plus model-level `nextn_proj_pre`/`nextn_proj_post`; new hparams `n_embd_inp_impl` (input-embedding dim override, honoured by `n_embd_inp()`) and graph field `n_layer_nextn`. Python conversion registers `Gemma4AssistantForCausalLM`/`Gemma4UnifiedAssistantForCausalLM`. This is the headline new feature; it is a speculative-decoding / **MTP** mechanism, which this project tracks as deferred-by-policy (see Open TODOs / `spec-draft-backend-sampling` + MTP). Consumed entirely inside upstream-compiled TUs — loading a non-assistant GGUF is unaffected. No project source changes required to build; exposing MTP through the Java API remains the existing deferred TODO | +| ~b9543–b9549 | `common/chat.cpp` + new `models/templates/LFM2.5-8B-A1B.jinja` | LFM2 chat-template handling: prior-turn `reasoning_content` is now copied into the template's `thinking` field, and `` reasoning extraction is gated on the template source actually containing `` (and no longer on `enable_thinking`). New `LFM2.5-8B-A1B` template + parser test consolidation. Routing happens inside upstream-compiled `chat.cpp`; the project calls no `common_chat_params_init_lfm2*` symbol. Handled automatically when such a model is loaded; no project source or Java API changes required | +| ~b9543–b9549 | `common/arg.cpp` + `common/speculative.cpp` + `src/llama-graph.cpp` | `common_params_handle_models()` mmproj auto-download now also requires `params.mmproj.path.empty() && params.mmproj.url.empty()` (an explicitly-specified mmproj is no longer re-downloaded). `speculative.cpp` MTP path adds a shared-memory fast path (`is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt`) that skips the catch-up decode and reuses the target position for draft tokens (Gemma4 assistant), and switched to `llama_model_n_embd_out()` for the MTP row width. `llama-graph.cpp` moved the `set_input_kq_mask` / `can_reuse_kq_mask` calls out of the k-idxs-buffer guard (iswa/hybrid-iswa mask bugfix). All inside upstream-compiled TUs; no project source changes required | +| ~b9543–b9549 | `tools/server/server-context.cpp` (project-linked) | The one project-linked server TU changed: now `#include`s `ggml-cpp.h` and `../../src/llama-ext.h`; sets `cparams.ctx_other = ctx_tgt` for MTP draft/MTP contexts; moved the `ctx_dft_seq_rm_type = common_context_can_seq_rm(...)` assignment to after context init (guarded by `if (ctx_dft)`); downgraded the spec memory-measure failure log from `SRV_ERR` to `SRV_WRN`; and gated the mtmd draft-processing block on `llama_get_ctx_other(ctx_dft) != ctx_tgt`. All changes are internal to the TU and the new includes resolve against the FetchContent'd `src/` and `ggml` headers. Compiles into `jllama` unchanged from the project's side. No project source changes required | +| ~b9543–b9549 | `.github/workflows/docker.yml` (upstream CI) | Upstream's `cuda13` Docker image bumped from CUDA `13.1.1` to `13.3.0`. Upstream's own CI only; this project ships its own `publish.yml` and pins CUDA 13.2 via `.github/build_cuda_linux.sh` (see CLAUDE.md "Upgrading CUDA Version"). No impact | +| ~b9543–b9549 | project `CMakeLists.txt` (pre-existing latent bug, fixed in this bump) | **Not an upstream change** — surfaced while build-testing this bump locally. The OS/arch detection block invoked `net.ladenthin.llama.OSInfo`, but the class had moved to `net.ladenthin.llama.loader.OSInfo` in the earlier layered-package restructure, so `cmake -B build` failed with "Could not determine OS name" on any host that does not pass `-DOS_NAME`/`-DOS_ARCH` explicitly (CI does, which is why it went unnoticed). Fixed both `execute_process` invocations (`--os` and `--arch`) to the `loader.OSInfo` FQN. Same stale-FQN-after-restructure class as the earlier `spotbugs-exclude.xml` / PIT-`targetClasses` repairs — the standing reminder to re-validate every FQN-bearing config after a package move now also covers `CMakeLists.txt` | +| ~b9543–b9549 | upstream build / verification | Local build with `GIT_TAG b9549` verified clean on Linux x86_64: `cmake -B build -DBUILD_TESTING=ON` configures cleanly (after the `loader.OSInfo` FQN fix above), `cmake --build build --config Release -j$(nproc)` links `libjllama.so` + `jllama_test` with zero warnings on any project translation unit (incl. the changed `server-context.cpp`), and `ctest --test-dir build --output-on-failure` reports 435/435 tests passing. All upstream breaking changes in this range are absorbed inside upstream-compiled translation units; no project C++ source edits were required for the version bump itself | diff --git a/pom.xml b/pom.xml index c1ab3c89..a7cfb882 100644 --- a/pom.xml +++ b/pom.xml @@ -61,6 +61,7 @@ SPDX-License-Identifier: MIT 1.5.34 1.27 6.1.0 + 3.0 1.37 0.16 3.6 @@ -113,6 +114,12 @@ SPDX-License-Identifier: MIT ${junit.version} test + + org.hamcrest + hamcrest + ${hamcrest.version} + test + net.jqwik jqwik @@ -632,12 +639,12 @@ SPDX-License-Identifier: MIT org.pitest pitest-maven @@ -650,10 +657,19 @@ SPDX-License-Identifier: MIT - net.ladenthin.llama.Pair + net.ladenthin.llama.value.* + net.ladenthin.llama.exception.* + net.ladenthin.llama.args.* + net.ladenthin.llama.json.TimingsLogger + net.ladenthin.llama.json.RerankResponseParser + net.ladenthin.llama.json.ChatResponseParser + net.ladenthin.llama.json.CompletionResponseParser - net.ladenthin.llama.PairTest + net.ladenthin.llama.value.* + net.ladenthin.llama.exception.* + net.ladenthin.llama.args.* + net.ladenthin.llama.json.* 100 30000 diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index 98b9eb9e..09d420ff 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -16,7 +16,7 @@ SPDX-License-Identifier: MIT upstream fixes should land in xerial/sqlite-jdbc rather than be patched here. --> - + - + @@ -75,7 +75,7 @@ SPDX-License-Identifier: MIT suppressions above. --> - + @@ -97,7 +97,7 @@ SPDX-License-Identifier: MIT emit a nonsense JSON value the native code would reject. --> - + @@ -116,8 +116,8 @@ SPDX-License-Identifier: MIT --> - - + + @@ -143,7 +143,7 @@ SPDX-License-Identifier: MIT there is no meaningful "allowed root" to validate against. --> - + @@ -261,7 +261,7 @@ SPDX-License-Identifier: MIT the wrapping is verified by tests, so the finding is a false positive. --> - + @@ -297,8 +297,8 @@ SPDX-License-Identifier: MIT --> - - + + @@ -313,7 +313,7 @@ SPDX-License-Identifier: MIT mismatch is the public contract. --> - + @@ -326,7 +326,7 @@ SPDX-License-Identifier: MIT format argument; the wrapper is the documented escape hatch. --> - + @@ -342,7 +342,7 @@ SPDX-License-Identifier: MIT no behavioural benefit. --> - + @@ -355,7 +355,7 @@ SPDX-License-Identifier: MIT there is no additional state-dependent context to add at this guard. --> - + diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 9e5b4c2b..0836ea32 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -479,8 +479,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { c_integer = env->FindClass("java/lang/Integer"); c_float = env->FindClass("java/lang/Float"); c_biconsumer = env->FindClass("java/util/function/BiConsumer"); - c_llama_error = env->FindClass("net/ladenthin/llama/LlamaException"); - c_log_level = env->FindClass("net/ladenthin/llama/LogLevel"); + c_llama_error = env->FindClass("net/ladenthin/llama/exception/LlamaException"); + c_log_level = env->FindClass("net/ladenthin/llama/value/LogLevel"); c_log_format = env->FindClass("net/ladenthin/llama/args/LogFormat"); c_error_oom = env->FindClass("java/lang/OutOfMemoryError"); @@ -527,10 +527,10 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { // find fields f_model_pointer = env->GetFieldID(c_llama_model, "ctx", "J"); f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); - f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lnet/ladenthin/llama/LogLevel;"); - f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lnet/ladenthin/llama/LogLevel;"); - f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lnet/ladenthin/llama/LogLevel;"); - f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lnet/ladenthin/llama/LogLevel;"); + f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lnet/ladenthin/llama/value/LogLevel;"); + f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lnet/ladenthin/llama/value/LogLevel;"); + f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lnet/ladenthin/llama/value/LogLevel;"); + f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lnet/ladenthin/llama/value/LogLevel;"); f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lnet/ladenthin/llama/args/LogFormat;"); f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lnet/ladenthin/llama/args/LogFormat;"); diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java index 6860292e..af93f9af 100644 --- a/src/main/java/module-info.java +++ b/src/main/java/module-info.java @@ -10,7 +10,7 @@ * ({@code net.ladenthin.llama}, {@code net.ladenthin.llama.args}, * {@code net.ladenthin.llama.json}). The native libraries shipped under * {@code /net/ladenthin/llama/{OS}/{ARCH}/} are loaded by - * {@link net.ladenthin.llama.LlamaLoader} via + * {@link net.ladenthin.llama.loader.LlamaLoader} via * {@link Class#getResourceAsStream(String)} on its own class object, so the resources * are looked up in this module and do not need to be {@code opens}'d.

* @@ -48,5 +48,11 @@ exports net.ladenthin.llama; exports net.ladenthin.llama.args; + exports net.ladenthin.llama.callback; + exports net.ladenthin.llama.exception; exports net.ladenthin.llama.json; + exports net.ladenthin.llama.parameters; + exports net.ladenthin.llama.value; +// net.ladenthin.llama.loader is intentionally NOT exported: native-library loading, +// OS detection and process/system-property infrastructure are internal to the module. } diff --git a/src/main/java/net/ladenthin/llama/LlamaIterable.java b/src/main/java/net/ladenthin/llama/LlamaIterable.java index 1e1ade6a..08a75314 100644 --- a/src/main/java/net/ladenthin/llama/LlamaIterable.java +++ b/src/main/java/net/ladenthin/llama/LlamaIterable.java @@ -6,6 +6,8 @@ package net.ladenthin.llama; import lombok.ToString; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.value.LlamaOutput; /** * An {@link Iterable} wrapper around {@link LlamaIterator} returned by diff --git a/src/main/java/net/ladenthin/llama/LlamaIterator.java b/src/main/java/net/ladenthin/llama/LlamaIterator.java index 2fb0c86e..4200c34b 100644 --- a/src/main/java/net/ladenthin/llama/LlamaIterator.java +++ b/src/main/java/net/ladenthin/llama/LlamaIterator.java @@ -9,6 +9,8 @@ import java.util.NoSuchElementException; import lombok.ToString; import net.ladenthin.llama.json.CompletionResponseParser; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.value.LlamaOutput; /** * This iterator is used by {@link LlamaModel#generate(InferenceParameters)} and diff --git a/src/main/java/net/ladenthin/llama/LlamaModel.java b/src/main/java/net/ladenthin/llama/LlamaModel.java index 695c2b68..eacd589d 100644 --- a/src/main/java/net/ladenthin/llama/LlamaModel.java +++ b/src/main/java/net/ladenthin/llama/LlamaModel.java @@ -15,9 +15,28 @@ import java.util.function.BiConsumer; import lombok.ToString; import net.ladenthin.llama.args.LogFormat; +import net.ladenthin.llama.callback.CancellationToken; +import net.ladenthin.llama.callback.LoadProgressCallback; +import net.ladenthin.llama.callback.ToolHandler; +import net.ladenthin.llama.exception.LlamaException; import net.ladenthin.llama.json.ChatResponseParser; import net.ladenthin.llama.json.CompletionResponseParser; import net.ladenthin.llama.json.RerankResponseParser; +import net.ladenthin.llama.loader.LlamaLoader; +import net.ladenthin.llama.loader.SkipDownloadFailureTranslator; +import net.ladenthin.llama.parameters.ChatRequest; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.CompletionResult; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.LogLevel; +import net.ladenthin.llama.value.ModelMeta; +import net.ladenthin.llama.value.Pair; +import net.ladenthin.llama.value.ServerMetrics; +import net.ladenthin.llama.value.StopReason; +import net.ladenthin.llama.value.ToolCall; import org.jspecify.annotations.Nullable; /** @@ -29,7 +48,7 @@ *
    *
  • Streaming answers (and probabilities) via {@link #generate(InferenceParameters)}
  • *
  • Creating whole responses to prompts via {@link #complete(InferenceParameters)}
  • - *
  • Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#enableEmbedding()}
  • + *
  • Creating embeddings via {@link #embed(String)} (make sure to configure {@link net.ladenthin.llama.parameters.ModelParameters#enableEmbedding()}
  • *
  • Accessing the tokenizer via {@link #encode(String)} and {@link #decode(int[])}
  • *
* @@ -57,19 +76,19 @@ public class LlamaModel implements AutoCloseable { private final RerankResponseParser rerankParser = new RerankResponseParser(); /** - * Load with the given {@link ModelParameters}. Make sure to either set + * Load with the given {@link net.ladenthin.llama.parameters.ModelParameters}. Make sure to either set *
    - *
  • {@link ModelParameters#setModel(String)}
  • - *
  • {@link ModelParameters#setModelUrl(String)}
  • - *
  • {@link ModelParameters#setHfRepo(String)}, {@link ModelParameters#setHfFile(String)}
  • + *
  • {@link net.ladenthin.llama.parameters.ModelParameters#setModel(String)}
  • + *
  • {@link net.ladenthin.llama.parameters.ModelParameters#setModelUrl(String)}
  • + *
  • {@link net.ladenthin.llama.parameters.ModelParameters#setHfRepo(String)}, {@link net.ladenthin.llama.parameters.ModelParameters#setHfFile(String)}
  • *
* * @param parameters the set of options - * @throws ModelUnavailableException if {@link ModelParameters#setSkipDownload(boolean) + * @throws net.ladenthin.llama.exception.ModelUnavailableException if {@link net.ladenthin.llama.parameters.ModelParameters#setSkipDownload(boolean) * setSkipDownload(true)} (or * {@link net.ladenthin.llama.args.ModelFlag#SKIP_DOWNLOAD}) * is set and the configured model file is missing or invalid - * @throws LlamaException for any other load failure + * @throws net.ladenthin.llama.exception.LlamaException for any other load failure */ // loadModel is a native method; it does not call back into Java with this, // so the @UnderInitialization receiver warning is a CF false positive. @@ -86,11 +105,11 @@ public LlamaModel(ModelParameters parameters) { * Load the model and forward progress updates to {@code progress}. The callback is * invoked synchronously on the constructor thread by the native loader and may * return {@code false} to abort the load (in which case this constructor throws - * {@link LlamaException}). + * {@link net.ladenthin.llama.exception.LlamaException}). * * @param parameters the set of options * @param progress load progress sink; {@code null} disables the callback - * @throws LlamaException if loading fails or the callback aborts + * @throws net.ladenthin.llama.exception.LlamaException if loading fails or the callback aborts */ // loadModel / loadModelWithProgress are native methods; they do not call back // into Java with this, so the @UnderInitialization receiver warning is a CF @@ -124,13 +143,13 @@ public String complete(InferenceParameters parameters) { /** * Typed variant of {@link #complete(InferenceParameters)} that surfaces per-completion - * {@link Usage}, {@link Timings}, {@link TokenLogprob} entries, and {@link StopReason}. + * {@link net.ladenthin.llama.value.Usage}, {@link net.ladenthin.llama.value.Timings}, {@link net.ladenthin.llama.value.TokenLogprob} entries, and {@link net.ladenthin.llama.value.StopReason}. *

- * Logprobs are populated only when {@link InferenceParameters#withNProbs(int)} is > 0. - * The raw native JSON is preserved on {@link CompletionResult#getRawJson()}. + * Logprobs are populated only when {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} is > 0. + * The raw native JSON is preserved on {@link net.ladenthin.llama.value.CompletionResult#getRawJson()}. * * @param parameters the inference configuration - * @return a populated {@link CompletionResult} + * @return a populated {@link net.ladenthin.llama.value.CompletionResult} */ public CompletionResult completeWithStats(InferenceParameters parameters) { InferenceParameters nonStreaming = parameters.withStream(false); @@ -141,7 +160,7 @@ public CompletionResult completeWithStats(InferenceParameters parameters) { /** * Cancellable variant of {@link #complete(InferenceParameters)}. Runs in streaming mode - * internally so the inference loop can observe a {@link CancellationToken#cancel()} call + * internally so the inference loop can observe a {@link net.ladenthin.llama.callback.CancellationToken#cancel()} call * from another thread and return early with whatever text was accumulated so far. *

* The token is rebound to this call (any prior {@code cancel} state is cleared on entry). @@ -149,13 +168,13 @@ public CompletionResult completeWithStats(InferenceParameters parameters) { *

* * @param parameters the inference configuration (its {@code stream} flag will be set to true) - * @param token cancellation handle; {@link CancellationToken#cancel()} aborts the loop + * @param token cancellation handle; {@link net.ladenthin.llama.callback.CancellationToken#cancel()} aborts the loop * @return the text generated up to the point of stop or cancellation */ /** * Dispatch a list of completion requests in parallel and return the generated texts * in the same order. Each request is sent immediately; the native scheduler dispatches - * tasks across whatever slot count {@link ModelParameters#setParallel(int)} was + * tasks across whatever slot count {@link net.ladenthin.llama.parameters.ModelParameters#setParallel(int)} was * configured with. With a default single-slot model the requests still run, but * sequentially. * @@ -177,7 +196,7 @@ public java.util.List completeBatch(java.util.Collection completeBatchWithStats(java.util.Collect /** * Dispatch a list of typed chat requests in parallel and return the parsed responses - * in the same order. Requires {@link ModelParameters#setParallel(int)} > 1 for + * in the same order. Requires {@link net.ladenthin.llama.parameters.ModelParameters#setParallel(int)} > 1 for * actual parallelism; otherwise the calls run sequentially on the single slot. * * @param requests the typed chat requests (must be distinct instances) @@ -216,7 +235,6 @@ public java.util.List chatBatch(java.util.Collection return out; } - /** * Asynchronous variant of {@link #complete(InferenceParameters)}. Runs the inference on * the common {@link java.util.concurrent.ForkJoinPool} so it does not block the calling @@ -232,8 +250,8 @@ public CompletableFuture completeAsync(InferenceParameters parameters) { /** * Cancellable async variant. The returned future is wired to the supplied - * {@link CancellationToken}: calling {@code future.cancel(true)} also invokes - * {@link CancellationToken#cancel()} so the inference loop returns early. + * {@link net.ladenthin.llama.callback.CancellationToken}: calling {@code future.cancel(true)} also invokes + * {@link net.ladenthin.llama.callback.CancellationToken#cancel()} so the inference loop returns early. * * @param parameters the inference configuration * @param token cancellation handle bound to the underlying inference loop @@ -278,7 +296,7 @@ public CompletableFuture chatCompleteTextAsync(InferenceParameters param /** * Cancellable variant of {@link #complete(InferenceParameters)}. Runs in streaming mode - * internally so the inference loop can observe a {@link CancellationToken#cancel()} call + * internally so the inference loop can observe a {@link net.ladenthin.llama.callback.CancellationToken#cancel()} call * from another thread between token boundaries and return early with whatever text was * accumulated so far. * @@ -341,7 +359,7 @@ public LlamaIterable generate(InferenceParameters parameters) { * * @param prompt the string to embed * @return an embedding float array - * @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters#enableEmbedding()}) + * @throws IllegalStateException if embedding mode was not activated (see {@link net.ladenthin.llama.parameters.ModelParameters#enableEmbedding()}) */ public native float[] embed(String prompt); @@ -422,7 +440,7 @@ protected final void finalize() { private static native byte[] jsonSchemaToGrammarBytes(String schema); /** - * Converts a JSON schema to a grammar string usable by {@link ModelParameters#setGrammar(String)}. + * Converts a JSON schema to a grammar string usable by {@link net.ladenthin.llama.parameters.ModelParameters#setGrammar(String)}. * * @param schema the JSON schema as a string * @return the converted grammar string @@ -449,7 +467,7 @@ public List> rerank(boolean reRank, String query, String... } /** - * Rerank the given documents against the query, returning a {@link LlamaOutput} with scored documents + * Rerank the given documents against the query, returning a {@link net.ladenthin.llama.value.LlamaOutput} with scored documents * in the probabilities map. * * @param query the query string @@ -505,7 +523,7 @@ public String applyTemplate(InferenceParameters parameters) { * * @param parameters the inference parameters including messages * @return the model's response as a JSON string containing the completion result - * @throws LlamaException if the model was loaded in embedding mode or if inference fails + * @throws net.ladenthin.llama.exception.LlamaException if the model was loaded in embedding mode or if inference fails */ public String chatComplete(InferenceParameters parameters) { InferenceParameters nonStreaming = parameters.withStream(false); @@ -520,23 +538,22 @@ public String chatComplete(InferenceParameters parameters) { * * @param parameters the inference parameters including messages * @return the assistant's reply text (extracted from {@code choices[0].message.content}) - * @throws LlamaException if the model was loaded in embedding mode or if inference fails + * @throws net.ladenthin.llama.exception.LlamaException if the model was loaded in embedding mode or if inference fails */ public String chatCompleteText(InferenceParameters parameters) { return chatParser.extractChoiceContent(chatComplete(parameters)); } /** - * Typed chat completion: serialize a {@link ChatRequest} (with optional tools), call - * the native chat endpoint, and return a parsed {@link ChatResponse} carrying typed - * {@link Usage}, {@link Timings}, and {@link ChatChoice} list. + * Typed chat completion: serialize a {@link net.ladenthin.llama.parameters.ChatRequest} (with optional tools), call + * the native chat endpoint, and return a parsed {@link net.ladenthin.llama.value.ChatResponse} carrying typed + * {@link net.ladenthin.llama.value.Usage}, {@link net.ladenthin.llama.value.Timings}, and {@link net.ladenthin.llama.value.ChatChoice} list. * * @param request the typed request (messages + optional tools) * @return the parsed typed response */ public ChatResponse chat(ChatRequest request) { - InferenceParameters params = InferenceParameters.empty() - .withMessagesJson(request.buildMessagesJson()); + InferenceParameters params = InferenceParameters.empty().withMessagesJson(request.buildMessagesJson()); Optional toolsJsonOpt = request.buildToolsJson(); if (toolsJsonOpt.isPresent()) { params = params.withToolsJson(toolsJsonOpt.get()).withUseChatTemplate(true); @@ -552,10 +569,10 @@ public ChatResponse chat(ChatRequest request) { /** * Tool-calling agent loop. Repeatedly calls {@link #chat(ChatRequest)}; on each - * response that includes {@code tool_calls}, invokes the matching {@link ToolHandler} + * response that includes {@code tool_calls}, invokes the matching {@link net.ladenthin.llama.callback.ToolHandler} * for every call, appends the assistant turn and tool-result turns to the request's * message list, and loops until either the model responds without tool calls or the - * round cap from {@link ChatRequest#getMaxToolRounds()} is reached. + * round cap from {@link net.ladenthin.llama.parameters.ChatRequest#getMaxToolRounds()} is reached. *

* Handler exceptions are caught and reported back to the model as * {@code {"error":"..."}} tool results so the loop can continue. Unknown tool names @@ -564,7 +581,7 @@ public ChatResponse chat(ChatRequest request) { * * @param request the typed request; must declare tools that the model can call * @param handlers map from tool name to handler - * @return the final {@link ChatResponse} when the model stops issuing tool calls + * @return the final {@link net.ladenthin.llama.value.ChatResponse} when the model stops issuing tool calls * (or the last response when the round cap is hit) */ public ChatResponse chatWithTools(ChatRequest request, java.util.Map handlers) { @@ -626,7 +643,7 @@ public ChatResponse chatWithTools(ChatRequest request, java.util.Map * Callers are responsible for producing a JSON Schema that matches the target type; @@ -716,7 +733,7 @@ public String getMetrics() { * @param parameters inference parameters (a new derivation with the schema set is used) * @param target type * @return parsed POJO of type {@code T} - * @throws LlamaException when the response is not valid JSON for the target type + * @throws net.ladenthin.llama.exception.LlamaException when the response is not valid JSON for the target type */ public T completeAsJson(Class type, String schema, InferenceParameters parameters) { return completeAsJson(type, parameters.withJsonSchema(schema)); @@ -725,15 +742,15 @@ public T completeAsJson(Class type, String schema, InferenceParameters pa /** * Run {@link #complete(InferenceParameters)} and deserialize the result as JSON into * {@code type}. The {@code parameters} object should already have a JSON Schema set - * via {@link InferenceParameters#withJsonSchema(String)} or a grammar via - * {@link InferenceParameters#withGrammar(String)} — otherwise the model output is + * via {@link net.ladenthin.llama.parameters.InferenceParameters#withJsonSchema(String)} or a grammar via + * {@link net.ladenthin.llama.parameters.InferenceParameters#withGrammar(String)} — otherwise the model output is * unlikely to parse. * * @param type the target POJO class for Jackson deserialization * @param parameters inference parameters (schema/grammar already set by the caller) * @param target type * @return parsed POJO of type {@code T} - * @throws LlamaException when the response is not valid JSON for the target type + * @throws net.ladenthin.llama.exception.LlamaException when the response is not valid JSON for the target type */ public T completeAsJson(Class type, InferenceParameters parameters) { String raw = complete(parameters); @@ -747,11 +764,11 @@ public T completeAsJson(Class type, InferenceParameters parameters) { /** * Typed accessor for {@link #getMetrics()}. Parses the raw JSON into a - * {@link ServerMetrics} view that exposes cumulative {@link Usage} and - * {@link Timings}, slot counts, and a passthrough to the underlying JSON. + * {@link net.ladenthin.llama.value.ServerMetrics} view that exposes cumulative {@link net.ladenthin.llama.value.Usage} and + * {@link net.ladenthin.llama.value.Timings}, slot counts, and a passthrough to the underlying JSON. * - * @return parsed {@link ServerMetrics} - * @throws LlamaException if the native call fails or the response cannot be parsed + * @return parsed {@link net.ladenthin.llama.value.ServerMetrics} + * @throws net.ladenthin.llama.exception.LlamaException if the native call fails or the response cannot be parsed */ public ServerMetrics getMetricsTyped() { try { @@ -765,13 +782,13 @@ public ServerMetrics getMetricsTyped() { * Returns model metadata with typed accessors for vocab, context, embedding, * parameter count, size, and modality support flags (vision, audio). *

- * The returned {@link ModelMeta} wraps the raw JSON from the native layer. - * Call {@link ModelMeta#toString()} to re-serialize to compact JSON for use + * The returned {@link net.ladenthin.llama.value.ModelMeta} wraps the raw JSON from the native layer. + * Call {@link net.ladenthin.llama.value.ModelMeta#toString()} to re-serialize to compact JSON for use * in {@code assertEquals}. *

* - * @return {@link ModelMeta} parsed from the native {@code model_meta()} response - * @throws LlamaException if the native call fails or the response cannot be parsed + * @return {@link net.ladenthin.llama.value.ModelMeta} parsed from the native {@code model_meta()} response + * @throws net.ladenthin.llama.exception.LlamaException if the native call fails or the response cannot be parsed */ public ModelMeta getModelMeta() { try { diff --git a/src/main/java/net/ladenthin/llama/Session.java b/src/main/java/net/ladenthin/llama/Session.java index 13b3140a..fe8654ac 100644 --- a/src/main/java/net/ladenthin/llama/Session.java +++ b/src/main/java/net/ladenthin/llama/Session.java @@ -7,11 +7,14 @@ import java.util.List; import java.util.function.UnaryOperator; import lombok.ToString; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatTranscript; import org.jspecify.annotations.Nullable; /** * Thin multi-turn conversation wrapper over a {@link LlamaModel} slot. Maintains an - * accumulating list of {@link ChatMessage} turns and forwards each {@link #send(String)} + * accumulating list of {@link net.ladenthin.llama.value.ChatMessage} turns and forwards each {@link #send(String)} * to the underlying chat-completion API with the full transcript so far. KV-cache state * for the bound slot can be persisted via {@link #save(String)} and restored with * {@link #restore(String)}, which delegate to {@link LlamaModel#saveSlot(int, String)} @@ -46,7 +49,7 @@ public final class Session implements AutoCloseable { /** * Append-only transcript with two-phase commit semantics. See the - * {@link ChatTranscript} class Javadoc for the full invariant statement + * {@link net.ladenthin.llama.value.ChatTranscript} class Javadoc for the full invariant statement * and the {@code ChatTranscriptTest} class for the running-documentation * tests that pin the contract. */ @@ -76,8 +79,8 @@ public Session(LlamaModel model, int slotId, @Nullable String systemMessage) { /** * Create a session with a customizer that transforms the - * {@link InferenceParameters} for every call (e.g. {@code p -> p.withTemperature(0.7f).withNPredict(64)}). - * Because {@link InferenceParameters} is immutable, the customiser must return + * {@link net.ladenthin.llama.parameters.InferenceParameters} for every call (e.g. {@code p -> p.withTemperature(0.7f).withNPredict(64)}). + * Because {@link net.ladenthin.llama.parameters.InferenceParameters} is immutable, the customiser must return * the transformed instance — it cannot mutate the input. * * @param model the underlying model @@ -105,10 +108,9 @@ public Session( public String send(String userMessage) { synchronized (lock) { if (streamingActive) { - throw new IllegalStateException( - "stream in progress on slot " + slotId - + " (transcript=" + transcript.size() + " turns)" - + "; call commitStreamedReply(...) before send(...)"); + throw new IllegalStateException("stream in progress on slot " + slotId + + " (transcript=" + transcript.size() + " turns)" + + "; call commitStreamedReply(...) before send(...)"); } // Two-phase commit: build the wire-format with the pending user turn // outside the transcript via messagesWithPendingUserTurn(...). On @@ -134,10 +136,9 @@ public String send(String userMessage) { public LlamaIterable stream(String userMessage) { synchronized (lock) { if (streamingActive) { - throw new IllegalStateException( - "stream in progress on slot " + slotId - + " (transcript=" + transcript.size() + " turns)" - + "; call commitStreamedReply(...) before stream(...)"); + throw new IllegalStateException("stream in progress on slot " + slotId + + " (transcript=" + transcript.size() + " turns)" + + "; call commitStreamedReply(...) before stream(...)"); } // Two-phase commit: see send(). The user turn is committed only after // generateChat successfully returns the iterable; the assistant turn is @@ -158,10 +159,9 @@ public LlamaIterable stream(String userMessage) { public void commitStreamedReply(String assistantText) { synchronized (lock) { if (!streamingActive) { - throw new IllegalStateException( - "no stream in progress on slot " + slotId - + " (transcript=" + transcript.size() + " turns)" - + "; call stream(...) first"); + throw new IllegalStateException("no stream in progress on slot " + slotId + + " (transcript=" + transcript.size() + " turns)" + + "; call stream(...) first"); } transcript.appendAssistantTurn(assistantText); streamingActive = false; @@ -177,10 +177,9 @@ public void commitStreamedReply(String assistantText) { public String save(String filepath) { synchronized (lock) { if (streamingActive) { - throw new IllegalStateException( - "stream in progress on slot " + slotId - + " (transcript=" + transcript.size() + " turns)" - + "; call commitStreamedReply(...) before save(...)"); + throw new IllegalStateException("stream in progress on slot " + slotId + + " (transcript=" + transcript.size() + " turns)" + + "; call commitStreamedReply(...) before save(...)"); } return model.saveSlot(slotId, filepath); } @@ -195,10 +194,9 @@ public String save(String filepath) { public String restore(String filepath) { synchronized (lock) { if (streamingActive) { - throw new IllegalStateException( - "stream in progress on slot " + slotId - + " (transcript=" + transcript.size() + " turns)" - + "; call commitStreamedReply(...) before restore(...)"); + throw new IllegalStateException("stream in progress on slot " + slotId + + " (transcript=" + transcript.size() + " turns)" + + "; call commitStreamedReply(...) before restore(...)"); } return model.restoreSlot(slotId, filepath); } @@ -224,10 +222,10 @@ public void close() { /** * Build inference parameters with a pending user turn appended to the existing - * transcript — without mutating the underlying {@link ChatTranscript}. The + * transcript — without mutating the underlying {@link net.ladenthin.llama.value.ChatTranscript}. The * actual transcript mutation happens AFTER the model call returns successfully, - * either via {@link ChatTranscript#appendRound(String, String)} (send path) - * or {@link ChatTranscript#appendUserTurn(String)} (stream path). + * either via {@link net.ladenthin.llama.value.ChatTranscript#appendRound(String, String)} (send path) + * or {@link net.ladenthin.llama.value.ChatTranscript#appendUserTurn(String)} (stream path). * * @param pendingUserMessage the user turn to include in the wire format * @return inference parameters carrying transcript + pending user turn @@ -235,8 +233,7 @@ public void close() { private InferenceParameters buildParamsWithPendingUserTurn(String pendingUserMessage) { InferenceParameters params = InferenceParameters.empty() .withMessages( - transcript.getSystemMessage(), - transcript.messagesWithPendingUserTurn(pendingUserMessage)); + transcript.getSystemMessage(), transcript.messagesWithPendingUserTurn(pendingUserMessage)); return paramsCustomizer == null ? params : paramsCustomizer.apply(params); } } diff --git a/src/main/java/net/ladenthin/llama/args/ModelFlag.java b/src/main/java/net/ladenthin/llama/args/ModelFlag.java index af5807d5..93bf0d77 100644 --- a/src/main/java/net/ladenthin/llama/args/ModelFlag.java +++ b/src/main/java/net/ladenthin/llama/args/ModelFlag.java @@ -6,13 +6,13 @@ package net.ladenthin.llama.args; /** - * Boolean CLI flags for {@link net.ladenthin.llama.ModelParameters}. + * Boolean CLI flags for {@link net.ladenthin.llama.parameters.ModelParameters}. * *

Each constant maps to a single CLI argument that takes no value — its presence * alone enables the behaviour. Pass to - * {@link net.ladenthin.llama.ModelParameters#setFlag(ModelFlag)} / - * {@link net.ladenthin.llama.ModelParameters#clearFlag(ModelFlag)} for programmatic control, - * or use the named convenience methods (e.g. {@link net.ladenthin.llama.ModelParameters#enableFlashAttn()}). + * {@link net.ladenthin.llama.parameters.ModelParameters#setFlag(ModelFlag)} / + * {@link net.ladenthin.llama.parameters.ModelParameters#clearFlag(ModelFlag)} for programmatic control, + * or use the named convenience methods (e.g. {@link net.ladenthin.llama.parameters.ModelParameters#enableFlashAttn()}). */ public enum ModelFlag { @@ -117,7 +117,7 @@ public enum ModelFlag { * mismatch), upstream throws {@code common_skip_download_exception} during arg parsing, * which is caught inside {@code common_params_parse_ex} and surfaces as a {@code false} * return; the Java layer translates that combined signal into a typed - * {@link net.ladenthin.llama.ModelUnavailableException}.

+ * {@link net.ladenthin.llama.exception.ModelUnavailableException}.

*/ SKIP_DOWNLOAD("--skip-download"); diff --git a/src/main/java/net/ladenthin/llama/args/PoolingType.java b/src/main/java/net/ladenthin/llama/args/PoolingType.java index ce948029..8d78dd6c 100644 --- a/src/main/java/net/ladenthin/llama/args/PoolingType.java +++ b/src/main/java/net/ladenthin/llama/args/PoolingType.java @@ -6,7 +6,7 @@ package net.ladenthin.llama.args; /** - * Pooling strategy applied to token embeddings when {@link net.ladenthin.llama.ModelParameters#enableEmbedding()} + * Pooling strategy applied to token embeddings when {@link net.ladenthin.llama.parameters.ModelParameters#enableEmbedding()} * is active. * *

The string constants stored in each enum constant are the exact values accepted by the @@ -29,7 +29,7 @@ public enum PoolingType implements CliArg { * *

Maps to {@code LLAMA_POOLING_TYPE_UNSPECIFIED = -1} in {@code include/llama.h}. * This value has no corresponding CLI string; passing it to - * {@link net.ladenthin.llama.ModelParameters#setPoolingType(PoolingType)} intentionally + * {@link net.ladenthin.llama.parameters.ModelParameters#setPoolingType(PoolingType)} intentionally * omits the {@code --pooling} flag so llama.cpp chooses the pooling strategy itself. */ UNSPECIFIED("unspecified"), @@ -68,7 +68,7 @@ public enum PoolingType implements CliArg { /** * Rank pooling – used by re-ranking models to produce a relevance score. - * Requires a model loaded with {@link net.ladenthin.llama.ModelParameters#enableReranking()}; + * Requires a model loaded with {@link net.ladenthin.llama.parameters.ModelParameters#enableReranking()}; * not applicable to plain embedding models. * *

CLI string: {@code "rank"} — maps to {@code LLAMA_POOLING_TYPE_RANK = 4}. diff --git a/src/main/java/net/ladenthin/llama/args/ReasoningFormat.java b/src/main/java/net/ladenthin/llama/args/ReasoningFormat.java index 84d2fba3..d46791fd 100644 --- a/src/main/java/net/ladenthin/llama/args/ReasoningFormat.java +++ b/src/main/java/net/ladenthin/llama/args/ReasoningFormat.java @@ -11,7 +11,7 @@ * *

Passed as {@code "reasoning_format"} in inference requests. Only meaningful when the model * uses a thinking tag (e.g. {@code ...}) and chat-template rendering is active - * ({@link net.ladenthin.llama.InferenceParameters#withUseChatTemplate(boolean)}). + * ({@link net.ladenthin.llama.parameters.InferenceParameters#withUseChatTemplate(boolean)}). */ public enum ReasoningFormat implements CliArg { diff --git a/src/main/java/net/ladenthin/llama/args/package-info.java b/src/main/java/net/ladenthin/llama/args/package-info.java index 18542d5e..a8438ff5 100644 --- a/src/main/java/net/ladenthin/llama/args/package-info.java +++ b/src/main/java/net/ladenthin/llama/args/package-info.java @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT /** - * Typed enums for CLI-arg-valued options consumed by {@link net.ladenthin.llama.CliParameters}. + * Typed enums for CLI-arg-valued options consumed by {@link net.ladenthin.llama.parameters.CliParameters}. * *

JSpecify {@code @NullMarked} is declared at module level in * {@code module-info.java} and applies to this package transitively. diff --git a/src/main/java/net/ladenthin/llama/CancellationToken.java b/src/main/java/net/ladenthin/llama/callback/CancellationToken.java similarity index 89% rename from src/main/java/net/ladenthin/llama/CancellationToken.java rename to src/main/java/net/ladenthin/llama/callback/CancellationToken.java index 5cf25929..70365a62 100644 --- a/src/main/java/net/ladenthin/llama/CancellationToken.java +++ b/src/main/java/net/ladenthin/llama/callback/CancellationToken.java @@ -2,13 +2,14 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; import lombok.ToString; +import net.ladenthin.llama.parameters.InferenceParameters; /** - * Cancellation handle for a blocking {@link LlamaModel} call. Pass an instance to - * {@link LlamaModel#complete(InferenceParameters, CancellationToken)} and invoke + * Cancellation handle for a blocking {@link net.ladenthin.llama.LlamaModel} call. Pass an instance to + * {@link net.ladenthin.llama.LlamaModel#complete(InferenceParameters, CancellationToken)} and invoke * {@link #cancel()} from another thread to abort the inference loop. *

* Cancellation is cooperative: {@link #cancel()} only sets a flag, and the inference @@ -66,7 +67,7 @@ public void cancel() { } /** Clear the cancelled flag so the token can be reused. Package-private. */ - void reset() { + public void reset() { cancelled = false; } } diff --git a/src/main/java/net/ladenthin/llama/LoadProgressCallback.java b/src/main/java/net/ladenthin/llama/callback/LoadProgressCallback.java similarity index 71% rename from src/main/java/net/ladenthin/llama/LoadProgressCallback.java rename to src/main/java/net/ladenthin/llama/callback/LoadProgressCallback.java index 15e02900..8eae1611 100644 --- a/src/main/java/net/ladenthin/llama/LoadProgressCallback.java +++ b/src/main/java/net/ladenthin/llama/callback/LoadProgressCallback.java @@ -2,19 +2,21 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; + +import net.ladenthin.llama.parameters.ModelParameters; /** * Receives model-load progress updates from the native loader. *

- * Pass an instance to {@link LlamaModel#LlamaModel(ModelParameters, LoadProgressCallback)} + * Pass an instance to {@link net.ladenthin.llama.LlamaModel#LlamaModel(ModelParameters, LoadProgressCallback)} * to observe the {@code llama_model_params.progress_callback} hook from llama.cpp. The * callback is invoked synchronously on the loader thread (the same thread that called * the constructor) with a value in {@code [0.0, 1.0]}. *

*

* Return {@code false} to abort the load. When {@code false} is returned, the constructor - * throws {@link LlamaException} because the native loader aborts and reports failure. + * throws {@link net.ladenthin.llama.exception.LlamaException} because the native loader aborts and reports failure. *

*/ @FunctionalInterface diff --git a/src/main/java/net/ladenthin/llama/ToolHandler.java b/src/main/java/net/ladenthin/llama/callback/ToolHandler.java similarity index 84% rename from src/main/java/net/ladenthin/llama/ToolHandler.java rename to src/main/java/net/ladenthin/llama/callback/ToolHandler.java index a0850484..9546b9a0 100644 --- a/src/main/java/net/ladenthin/llama/ToolHandler.java +++ b/src/main/java/net/ladenthin/llama/callback/ToolHandler.java @@ -2,11 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; + +import net.ladenthin.llama.parameters.ChatRequest; /** * Invocation contract for a tool registered with - * {@link LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. + * {@link net.ladenthin.llama.LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. *

* The handler receives the model-supplied arguments as a JSON string and returns the * tool's output as a JSON string (an unwrapped string literal also works). Exceptions diff --git a/src/main/java/net/ladenthin/llama/callback/package-info.java b/src/main/java/net/ladenthin/llama/callback/package-info.java new file mode 100644 index 00000000..03aef2fb --- /dev/null +++ b/src/main/java/net/ladenthin/llama/callback/package-info.java @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +/** + * Functional-interface callbacks and cancellation tokens implemented by callers. + * + *

JSpecify {@code @NullMarked} is applied module-wide; everything is non-null + * unless annotated {@code @Nullable}. + */ +package net.ladenthin.llama.callback; diff --git a/src/main/java/net/ladenthin/llama/LlamaException.java b/src/main/java/net/ladenthin/llama/exception/LlamaException.java similarity index 96% rename from src/main/java/net/ladenthin/llama/LlamaException.java rename to src/main/java/net/ladenthin/llama/exception/LlamaException.java index ebc3c864..457765e6 100644 --- a/src/main/java/net/ladenthin/llama/LlamaException.java +++ b/src/main/java/net/ladenthin/llama/exception/LlamaException.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.exception; /** * Base unchecked exception raised by the JNI layer when a llama.cpp operation diff --git a/src/main/java/net/ladenthin/llama/ModelUnavailableException.java b/src/main/java/net/ladenthin/llama/exception/ModelUnavailableException.java similarity index 81% rename from src/main/java/net/ladenthin/llama/ModelUnavailableException.java rename to src/main/java/net/ladenthin/llama/exception/ModelUnavailableException.java index dfa57fad..ab5a70d2 100644 --- a/src/main/java/net/ladenthin/llama/ModelUnavailableException.java +++ b/src/main/java/net/ladenthin/llama/exception/ModelUnavailableException.java @@ -2,13 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.exception; -import net.ladenthin.llama.args.ModelFlag; +import net.ladenthin.llama.parameters.ModelParameters; /** - * Thrown by {@link LlamaModel#LlamaModel(ModelParameters)} when - * {@link ModelFlag#SKIP_DOWNLOAD} (or {@link ModelParameters#setSkipDownload(boolean) + * Thrown by {@link net.ladenthin.llama.LlamaModel#LlamaModel(ModelParameters)} when + * {@link net.ladenthin.llama.args.ModelFlag#SKIP_DOWNLOAD} (or {@link net.ladenthin.llama.parameters.ModelParameters#setSkipDownload(boolean) * setSkipDownload(true)}) is set and the configured model file is missing or * invalid — i.e. the loader would have had to download a replacement but is * forbidden to. diff --git a/src/main/java/net/ladenthin/llama/exception/package-info.java b/src/main/java/net/ladenthin/llama/exception/package-info.java new file mode 100644 index 00000000..a01a5663 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/exception/package-info.java @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +/** + * Library exception hierarchy. + * + *

JSpecify {@code @NullMarked} is applied module-wide; everything is non-null + * unless annotated {@code @Nullable}. + */ +package net.ladenthin.llama.exception; diff --git a/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java b/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java index 8508d349..72d2dd44 100644 --- a/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java +++ b/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java @@ -11,13 +11,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import net.ladenthin.llama.ChatChoice; -import net.ladenthin.llama.ChatMessage; -import net.ladenthin.llama.ChatResponse; -import net.ladenthin.llama.Timings; -import net.ladenthin.llama.TimingsLogger; -import net.ladenthin.llama.ToolCall; -import net.ladenthin.llama.Usage; +import net.ladenthin.llama.value.ChatChoice; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.Timings; +import net.ladenthin.llama.value.ToolCall; +import net.ladenthin.llama.value.Usage; /** * Pure JSON transforms for OAI-compatible chat completion responses. @@ -139,12 +138,12 @@ public int countChoices(JsonNode node) { } /** - * Parse a full OAI chat completion JSON string into a typed {@link ChatResponse}. - * Carries the {@code id}, choices, {@link Usage}, and {@link Timings}. The original - * JSON is preserved on {@link ChatResponse#getRawJson()}. + * Parse a full OAI chat completion JSON string into a typed {@link net.ladenthin.llama.value.ChatResponse}. + * Carries the {@code id}, choices, {@link net.ladenthin.llama.value.Usage}, and {@link net.ladenthin.llama.value.Timings}. The original + * JSON is preserved on {@link net.ladenthin.llama.value.ChatResponse#getRawJson()}. * * @param json the OAI-compatible chat completion JSON string - * @return a parsed {@link ChatResponse} (empty choices on malformed input) + * @return a parsed {@link net.ladenthin.llama.value.ChatResponse} (empty choices on malformed input) */ public ChatResponse parseResponse(String json) { try { @@ -164,36 +163,40 @@ public ChatResponse parseResponse(String json) { } private List parseChoices(JsonNode arr) { - // Mutable ArrayList on both branches keeps the return-type contract consistent - // (Error Prone MixedMutabilityReturnType). - if (!arr.isArray() || arr.size() == 0) return new ArrayList<>(); - List out = new ArrayList(arr.size()); - for (JsonNode c : arr) { - int index = c.path("index").asInt(0); - JsonNode msg = c.path("message"); - String role = msg.path("role").asText("assistant"); - String content = msg.path("content").asText(""); - List toolCalls = parseToolCalls(msg.path("tool_calls")); - ChatMessage message = toolCalls.isEmpty() - ? new ChatMessage(role, content) - : ChatMessage.assistantToolCalls(content, toolCalls); - String finishReason = c.path("finish_reason").asText(""); - out.add(new ChatChoice(index, message, finishReason)); + // Single mutable-ArrayList return: an empty (or non-array) input falls + // through the loop and returns the same empty ArrayList, keeping the + // return-type contract consistent (Error Prone MixedMutabilityReturnType) + // and leaving no equivalent empty-branch mutant for PIT to flag. + List out = new ArrayList<>(); + if (arr.isArray()) { + for (JsonNode c : arr) { + int index = c.path("index").asInt(0); + JsonNode msg = c.path("message"); + String role = msg.path("role").asText("assistant"); + String content = msg.path("content").asText(""); + List toolCalls = parseToolCalls(msg.path("tool_calls")); + ChatMessage message = toolCalls.isEmpty() + ? new ChatMessage(role, content) + : ChatMessage.assistantToolCalls(content, toolCalls); + String finishReason = c.path("finish_reason").asText(""); + out.add(new ChatChoice(index, message, finishReason)); + } } return out; } private List parseToolCalls(JsonNode arr) { - if (!arr.isArray() || arr.size() == 0) return new ArrayList<>(); - List out = new ArrayList(arr.size()); - for (JsonNode tc : arr) { - String id = tc.path("id").asText(""); - JsonNode fn = tc.path("function"); - String name = fn.path("name").asText(""); - JsonNode argsNode = fn.path("arguments"); - // OAI emits arguments as a string; some shapes emit a nested object. - String args = argsNode.isTextual() ? argsNode.asText("") : argsNode.toString(); - out.add(new ToolCall(id, name, args)); + List out = new ArrayList<>(); + if (arr.isArray()) { + for (JsonNode tc : arr) { + String id = tc.path("id").asText(""); + JsonNode fn = tc.path("function"); + String name = fn.path("name").asText(""); + JsonNode argsNode = fn.path("arguments"); + // OAI emits arguments as a string; some shapes emit a nested object. + String args = argsNode.isTextual() ? argsNode.asText("") : argsNode.toString(); + out.add(new ToolCall(id, name, args)); + } } return out; } diff --git a/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java b/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java index c6027375..b0ce96b0 100644 --- a/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java +++ b/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java @@ -13,14 +13,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import net.ladenthin.llama.CompletionResult; -import net.ladenthin.llama.InferenceParameters; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.StopReason; -import net.ladenthin.llama.Timings; -import net.ladenthin.llama.TimingsLogger; -import net.ladenthin.llama.TokenLogprob; -import net.ladenthin.llama.Usage; +import net.ladenthin.llama.value.CompletionResult; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.StopReason; +import net.ladenthin.llama.value.Timings; +import net.ladenthin.llama.value.TokenLogprob; +import net.ladenthin.llama.value.Usage; /** * Pure JSON transforms for native completion/streaming responses. @@ -39,7 +37,7 @@ * } * } * - *

When inference is configured with {@link InferenceParameters#withNProbs(int)} > 0, + *

When inference is configured with {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} > 0, * each chunk additionally carries a {@code completion_probabilities} array: *

{@code
  * {
@@ -63,12 +61,12 @@ public CompletionResponseParser() {}
     public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
 
     /**
-     * Parse a {@link LlamaOutput} from a raw JSON string returned by the native
+     * Parse a {@link net.ladenthin.llama.value.LlamaOutput} from a raw JSON string returned by the native
      * {@code receiveCompletionJson} method. Delegates to {@link #parse(JsonNode)} after
      * a single {@code readTree} call so the string is parsed only once.
      *
      * @param json raw JSON string from the native completion response
-     * @return parsed {@link LlamaOutput}; empty output on parse failure
+     * @return parsed {@link net.ladenthin.llama.value.LlamaOutput}; empty output on parse failure
      */
     public LlamaOutput parse(String json) {
         try {
@@ -84,11 +82,11 @@ public LlamaOutput parse(String json) {
     }
 
     /**
-     * Parse a {@link LlamaOutput} from a pre-parsed {@link JsonNode}.
+     * Parse a {@link net.ladenthin.llama.value.LlamaOutput} from a pre-parsed {@link JsonNode}.
      * Callers that already hold a parsed node should prefer this overload to avoid re-parsing.
      *
      * @param node pre-parsed completion response node
-     * @return parsed {@link LlamaOutput}
+     * @return parsed {@link net.ladenthin.llama.value.LlamaOutput}
      */
     public LlamaOutput parse(JsonNode node) {
         String content = extractContent(node);
@@ -148,41 +146,42 @@ public Map parseProbabilities(JsonNode root) {
     }
 
     /**
-     * Parse the {@code completion_probabilities} array into a list of typed {@link TokenLogprob}
+     * Parse the {@code completion_probabilities} array into a list of typed {@link net.ladenthin.llama.value.TokenLogprob}
      * entries, preserving order, token ids, and the nested alternatives array
      * ({@code top_probs} for post-sampling mode or {@code top_logprobs} for pre-sampling).
      *
      * 

Returns an empty list when the field is absent or empty. Requires - * {@link InferenceParameters#withNProbs(int)} to be configured. + * {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} to be configured. * * @param root the top-level completion response node - * @return list of {@link TokenLogprob}; empty when no probability data is present + * @return list of {@link net.ladenthin.llama.value.TokenLogprob}; empty when no probability data is present */ public List parseLogprobs(JsonNode root) { JsonNode array = root.path("completion_probabilities"); - if (!array.isArray() || array.size() == 0) { - // Return a mutable empty ArrayList to keep the return type consistent - // with the non-empty branch below (Error Prone MixedMutabilityReturnType). - return new ArrayList<>(); - } - List result = new ArrayList(array.size()); - for (JsonNode entry : array) { - result.add(parseLogprobEntry(entry)); + // Single mutable-ArrayList return: an empty (or absent) array falls + // through the loop and returns the same empty ArrayList, keeping the + // return type consistent (Error Prone MixedMutabilityReturnType) and + // leaving no equivalent empty-branch mutant for PIT to flag. + List result = new ArrayList<>(); + if (array.isArray()) { + for (JsonNode entry : array) { + result.add(parseLogprobEntry(entry)); + } } return result; } /** - * Parse a {@link CompletionResult} from the non-streaming, non-OAI completion JSON + * Parse a {@link net.ladenthin.llama.value.CompletionResult} from the non-streaming, non-OAI completion JSON * emitted by {@code server_task_result_cmpl_final::to_json_non_oaicompat}. *

* Maps {@code content} → text, {@code tokens_evaluated}/{@code tokens_predicted} → - * {@link Usage}, the {@code timings} sub-object → {@link Timings}, - * {@code completion_probabilities} → {@link TokenLogprob} list, and - * {@code stop_type} → {@link StopReason}. + * {@link net.ladenthin.llama.value.Usage}, the {@code timings} sub-object → {@link net.ladenthin.llama.value.Timings}, + * {@code completion_probabilities} → {@link net.ladenthin.llama.value.TokenLogprob} list, and + * {@code stop_type} → {@link net.ladenthin.llama.value.StopReason}. * * @param json raw JSON string from the native completion response - * @return a populated {@link CompletionResult}; fields default to empty/zero on parse failure + * @return a populated {@link net.ladenthin.llama.value.CompletionResult}; fields default to empty/zero on parse failure */ public CompletionResult parseCompletionResult(String json) { try { @@ -221,14 +220,14 @@ private TokenLogprob parseLogprobEntry(JsonNode entry) { if (!top.isArray()) { top = entry.path("top_logprobs"); } - List topLogprobs; - if (top.isArray() && top.size() > 0) { - topLogprobs = new ArrayList(top.size()); + // Single mutable-ArrayList accumulation: a missing or empty nested array + // skips the loop and yields an empty ArrayList, so there is no equivalent + // empty-branch mutant (the prior emptyList()/ArrayList ternary left one). + List topLogprobs = new ArrayList<>(); + if (top.isArray()) { for (JsonNode t : top) { topLogprobs.add(parseLogprobEntry(t)); } - } else { - topLogprobs = Collections.emptyList(); } return new TokenLogprob(token, tokenId, logprob, topLogprobs); } diff --git a/src/main/java/net/ladenthin/llama/json/RerankResponseParser.java b/src/main/java/net/ladenthin/llama/json/RerankResponseParser.java index 346e4c5b..40d0754b 100644 --- a/src/main/java/net/ladenthin/llama/json/RerankResponseParser.java +++ b/src/main/java/net/ladenthin/llama/json/RerankResponseParser.java @@ -11,7 +11,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import net.ladenthin.llama.Pair; +import net.ladenthin.llama.value.Pair; /** * Pure JSON transforms for native rerank responses. diff --git a/src/main/java/net/ladenthin/llama/TimingsLogger.java b/src/main/java/net/ladenthin/llama/json/TimingsLogger.java similarity index 95% rename from src/main/java/net/ladenthin/llama/TimingsLogger.java rename to src/main/java/net/ladenthin/llama/json/TimingsLogger.java index ad34b6a4..8327d64a 100644 --- a/src/main/java/net/ladenthin/llama/TimingsLogger.java +++ b/src/main/java/net/ladenthin/llama/json/TimingsLogger.java @@ -1,9 +1,10 @@ // SPDX-FileCopyrightText: 2026 Bernard Ladenthin // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.json; import java.util.Locale; +import net.ladenthin.llama.value.Timings; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -18,7 +19,7 @@ *

* *

Speculative-decoding runs append a {@code | draft: N (M accepted)} segment. - * Empty {@link Timings} (both {@code promptN} and {@code predictedN} zero) are + * Empty {@link net.ladenthin.llama.value.Timings} (both {@code promptN} and {@code predictedN} zero) are * skipped — logging the all-zero fallback on a parse failure or on early * cancellation is pure noise.

* diff --git a/src/main/java/net/ladenthin/llama/Java8CompatibilityHelper.java b/src/main/java/net/ladenthin/llama/loader/Java8CompatibilityHelper.java similarity index 99% rename from src/main/java/net/ladenthin/llama/Java8CompatibilityHelper.java rename to src/main/java/net/ladenthin/llama/loader/Java8CompatibilityHelper.java index 9a8dfba5..1fee4e0c 100644 --- a/src/main/java/net/ladenthin/llama/Java8CompatibilityHelper.java +++ b/src/main/java/net/ladenthin/llama/loader/Java8CompatibilityHelper.java @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2026 Bernard Ladenthin // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import java.io.ByteArrayOutputStream; import java.io.IOException; diff --git a/src/main/java/net/ladenthin/llama/LlamaLoader.java b/src/main/java/net/ladenthin/llama/loader/LlamaLoader.java similarity index 93% rename from src/main/java/net/ladenthin/llama/LlamaLoader.java rename to src/main/java/net/ladenthin/llama/loader/LlamaLoader.java index 2c96b0e2..30882ba9 100644 --- a/src/main/java/net/ladenthin/llama/LlamaLoader.java +++ b/src/main/java/net/ladenthin/llama/loader/LlamaLoader.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import java.io.BufferedInputStream; import java.io.File; @@ -42,16 +42,28 @@ */ @SuppressWarnings("UseOfSystemOutOrSystemErr") @ToString -class LlamaLoader { +public class LlamaLoader { private static boolean extracted = false; private static final LlamaSystemProperties systemProperties = new LlamaSystemProperties(); private static final NativeLibraryPermissionSetter permissionSetter = new NativeLibraryPermissionSetter(System.err); + /** + * Canonical classpath root for the bundled native libraries. Fixed by + * {@code CMakeLists.txt} and the publish workflow (both emit to + * {@code resources/net/ladenthin/llama///}); it must NOT be + * derived from this loader's own Java package, which moved to + * {@code net.ladenthin.llama.loader} during the layered restructure. + */ + private static final String NATIVE_RESOURCE_BASE = "/net/ladenthin/llama"; + + /** Static utility holder; not instantiable. */ + private LlamaLoader() {} + /** * Loads the llama and jllama shared libraries */ - static synchronized void initialize() { + public static synchronized void initialize() { // only cleanup before the first extract if (!extracted) { cleanup(); @@ -264,15 +276,7 @@ static File getTempDir() { } static String getNativeResourcePath() { - final Package pkg = LlamaLoader.class.getPackage(); - // LlamaLoader is in a named package, so Class.getPackage() is never null here. - if (pkg == null) { - throw new IllegalStateException( - "LlamaLoader.class.getPackage() returned null (classLoader=" - + LlamaLoader.class.getClassLoader() + ")"); - } - String packagePath = pkg.getName().replace('.', '/'); - return String.format("/%s/%s", packagePath, OSInfo.getNativeLibFolderPathForCurrentOS()); + return String.format("%s/%s", NATIVE_RESOURCE_BASE, OSInfo.getNativeLibFolderPathForCurrentOS()); } private static boolean hasNativeLib(String path, String libraryName) { diff --git a/src/main/java/net/ladenthin/llama/LlamaSystemProperties.java b/src/main/java/net/ladenthin/llama/loader/LlamaSystemProperties.java similarity index 98% rename from src/main/java/net/ladenthin/llama/LlamaSystemProperties.java rename to src/main/java/net/ladenthin/llama/loader/LlamaSystemProperties.java index 30123ab6..52067d88 100644 --- a/src/main/java/net/ladenthin/llama/LlamaSystemProperties.java +++ b/src/main/java/net/ladenthin/llama/loader/LlamaSystemProperties.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import lombok.ToString; import org.jspecify.annotations.Nullable; diff --git a/src/main/java/net/ladenthin/llama/NativeLibraryPermissionSetter.java b/src/main/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetter.java similarity index 97% rename from src/main/java/net/ladenthin/llama/NativeLibraryPermissionSetter.java rename to src/main/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetter.java index db73268a..e2c6e1bb 100644 --- a/src/main/java/net/ladenthin/llama/NativeLibraryPermissionSetter.java +++ b/src/main/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetter.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import java.io.File; import java.io.PrintStream; diff --git a/src/main/java/net/ladenthin/llama/OSInfo.java b/src/main/java/net/ladenthin/llama/loader/OSInfo.java similarity index 99% rename from src/main/java/net/ladenthin/llama/OSInfo.java rename to src/main/java/net/ladenthin/llama/loader/OSInfo.java index b0c3d83e..645bdc21 100644 --- a/src/main/java/net/ladenthin/llama/OSInfo.java +++ b/src/main/java/net/ladenthin/llama/loader/OSInfo.java @@ -70,7 +70,7 @@ // $URL$ // $Author$ // -------------------------------------- -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import java.io.File; import java.io.IOException; diff --git a/src/main/java/net/ladenthin/llama/ProcessRunner.java b/src/main/java/net/ladenthin/llama/loader/ProcessRunner.java similarity index 98% rename from src/main/java/net/ladenthin/llama/ProcessRunner.java rename to src/main/java/net/ladenthin/llama/loader/ProcessRunner.java index 1f783b81..ea12d050 100644 --- a/src/main/java/net/ladenthin/llama/ProcessRunner.java +++ b/src/main/java/net/ladenthin/llama/loader/ProcessRunner.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import java.io.ByteArrayOutputStream; import java.io.IOException; diff --git a/src/main/java/net/ladenthin/llama/SkipDownloadFailureTranslator.java b/src/main/java/net/ladenthin/llama/loader/SkipDownloadFailureTranslator.java similarity index 68% rename from src/main/java/net/ladenthin/llama/SkipDownloadFailureTranslator.java rename to src/main/java/net/ladenthin/llama/loader/SkipDownloadFailureTranslator.java index 3c6ec985..48702b10 100644 --- a/src/main/java/net/ladenthin/llama/SkipDownloadFailureTranslator.java +++ b/src/main/java/net/ladenthin/llama/loader/SkipDownloadFailureTranslator.java @@ -2,17 +2,20 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import net.ladenthin.llama.args.ModelFlag; +import net.ladenthin.llama.exception.LlamaException; +import net.ladenthin.llama.exception.ModelUnavailableException; +import net.ladenthin.llama.parameters.ModelParameters; /** - * Pure-Java translator from the generic {@link LlamaException} raised by the JNI - * loader to the typed {@link ModelUnavailableException} when - * {@link ModelFlag#SKIP_DOWNLOAD} is set and the load failed because the + * Pure-Java translator from the generic {@link net.ladenthin.llama.exception.LlamaException} raised by the JNI + * loader to the typed {@link net.ladenthin.llama.exception.ModelUnavailableException} when + * {@link net.ladenthin.llama.args.ModelFlag#SKIP_DOWNLOAD} is set and the load failed because the * configured model file was missing or invalid. * - *

Lives outside {@link LlamaModel} so that unit tests can exercise the + *

Lives outside {@link net.ladenthin.llama.LlamaModel} so that unit tests can exercise the * translation heuristic without triggering {@code LlamaModel}'s * {@link LlamaLoader} static initializer (which loads the JNI library and is * not available in CPU-only / non-native test environments).

@@ -25,12 +28,12 @@ * INSIDE upstream's own {@code common_params_parse_ex} (at * {@code common/arg.cpp:476}) and surfaces only as a {@code false} return * from {@code common_params_parse}. The JNI layer reports the {@code false} - * return as a generic {@link LlamaException} with the message + * return as a generic {@link net.ladenthin.llama.exception.LlamaException} with the message * {@value #LOAD_PARSE_FAILED_MESSAGE}. The Java layer therefore cannot catch * the C++ exception directly and instead recognises the combined signal: * {@code SKIP_DOWNLOAD} flag set + JNI message matches.

*/ -final class SkipDownloadFailureTranslator { +public final class SkipDownloadFailureTranslator { /** * Substring used by the JNI bridge when {@code common_params_parse} returns @@ -45,17 +48,17 @@ private SkipDownloadFailureTranslator() { /** * Translates a generic load failure into a typed - * {@link ModelUnavailableException} when the user opted into - * {@link ModelFlag#SKIP_DOWNLOAD} and the JNI surfaced the + * {@link net.ladenthin.llama.exception.ModelUnavailableException} when the user opted into + * {@link net.ladenthin.llama.args.ModelFlag#SKIP_DOWNLOAD} and the JNI surfaced the * {@value #LOAD_PARSE_FAILED_MESSAGE} message; otherwise returns the * original exception unchanged so the caller can re-throw it as-is. * * @param parameters the parameters passed to the failing constructor * @param original the original load failure to translate or pass through - * @return a {@link ModelUnavailableException} when the heuristic matches; + * @return a {@link net.ladenthin.llama.exception.ModelUnavailableException} when the heuristic matches; * otherwise the original {@code LlamaException} */ - static LlamaException translate(ModelParameters parameters, LlamaException original) { + public static LlamaException translate(ModelParameters parameters, LlamaException original) { if (parameters.hasFlag(ModelFlag.SKIP_DOWNLOAD) && original.getMessage() != null && original.getMessage().contains(LOAD_PARSE_FAILED_MESSAGE)) { diff --git a/src/main/java/net/ladenthin/llama/loader/package-info.java b/src/main/java/net/ladenthin/llama/loader/package-info.java new file mode 100644 index 00000000..3a3c93b1 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/loader/package-info.java @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +/** + * Native-library loading, OS/architecture detection and process/system-property infrastructure. + * + *

JSpecify {@code @NullMarked} is applied module-wide; everything is non-null + * unless annotated {@code @Nullable}. + */ +package net.ladenthin.llama.loader; diff --git a/src/main/java/net/ladenthin/llama/ChatRequest.java b/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java similarity index 93% rename from src/main/java/net/ladenthin/llama/ChatRequest.java rename to src/main/java/net/ladenthin/llama/parameters/ChatRequest.java index 0d1cce7d..23173bc3 100644 --- a/src/main/java/net/ladenthin/llama/ChatRequest.java +++ b/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; @@ -15,6 +15,9 @@ import java.util.function.UnaryOperator; import lombok.EqualsAndHashCode; import lombok.ToString; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ToolCall; +import net.ladenthin.llama.value.ToolDefinition; import org.jspecify.annotations.Nullable; /** @@ -30,8 +33,8 @@ * {@link UnaryOperator} that takes a parameter set and returns the transformed * one — callers chain {@code withX(...)} calls on the input and return the * resulting instance. The type is consumed by - * {@link LlamaModel#chat(ChatRequest)} and - * {@link LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. + * {@link net.ladenthin.llama.LlamaModel#chat(ChatRequest)} and + * {@link net.ladenthin.llama.LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. * *

All instances are immutable: every field is {@code final} and the * stored lists are wrapped with {@link Collections#unmodifiableList(List)}. @@ -142,17 +145,12 @@ public ChatRequest appendMessage(ChatMessage message) { List next = new ArrayList(messages.size() + 1); next.addAll(messages); next.add(message); - return new ChatRequest( - Collections.unmodifiableList(next), - tools, - toolChoice, - maxToolRounds, - paramsCustomizer); + return new ChatRequest(Collections.unmodifiableList(next), tools, toolChoice, maxToolRounds, paramsCustomizer); } /** * Convenience for {@link #appendMessage(ChatMessage)} that wraps a role + - * content pair into a new {@link ChatMessage} and appends it. + * content pair into a new {@link net.ladenthin.llama.value.ChatMessage} and appends it. * * @param role the role (e.g. {@code "system"}, {@code "user"}, {@code "assistant"}) * @param content the message content @@ -173,11 +171,7 @@ public ChatRequest appendTool(ToolDefinition tool) { next.addAll(tools); next.add(tool); return new ChatRequest( - messages, - Collections.unmodifiableList(next), - toolChoice, - maxToolRounds, - paramsCustomizer); + messages, Collections.unmodifiableList(next), toolChoice, maxToolRounds, paramsCustomizer); } // ----------------------------------------------------------------------- @@ -204,8 +198,7 @@ public ChatRequest withToolChoice(@Nullable String newToolChoice) { */ public ChatRequest withMaxToolRounds(int newMaxToolRounds) { if (newMaxToolRounds <= 0) { - throw new IllegalArgumentException( - "maxToolRounds must be > 0 but was " + newMaxToolRounds); + throw new IllegalArgumentException("maxToolRounds must be > 0 but was " + newMaxToolRounds); } return new ChatRequest(messages, tools, toolChoice, newMaxToolRounds, paramsCustomizer); } @@ -324,13 +317,13 @@ public Optional buildToolsJson() { /** * Apply the optional customiser to an {@link InferenceParameters} instance and - * return the transformed result. Package-private; called by {@link LlamaModel}. + * return the transformed result. Package-private; called by {@link net.ladenthin.llama.LlamaModel}. * When no customiser is set, returns {@code params} unchanged. * * @param params the parameters to transform * @return the (possibly new) parameters produced by the customiser, or {@code params} when no customiser is set */ - InferenceParameters applyCustomizer(InferenceParameters params) { + public InferenceParameters applyCustomizer(InferenceParameters params) { return paramsCustomizer == null ? params : paramsCustomizer.apply(params); } } diff --git a/src/main/java/net/ladenthin/llama/CliParameters.java b/src/main/java/net/ladenthin/llama/parameters/CliParameters.java similarity index 96% rename from src/main/java/net/ladenthin/llama/CliParameters.java rename to src/main/java/net/ladenthin/llama/parameters/CliParameters.java index 941e2c81..8940d881 100644 --- a/src/main/java/net/ladenthin/llama/CliParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/CliParameters.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import java.util.ArrayList; import java.util.HashMap; @@ -52,7 +52,7 @@ protected final T putScalar(String key, Object value) * return this builder typed as the concrete subtype. * * @param key the parameter key - * @param value the enum constant; must implement {@link CliArg} + * @param value the enum constant; must implement {@link net.ladenthin.llama.args.CliArg} * @param the concrete subtype of this builder * @return this builder */ diff --git a/src/main/java/net/ladenthin/llama/InferenceParameters.java b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java similarity index 94% rename from src/main/java/net/ladenthin/llama/InferenceParameters.java rename to src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java index 46a8d415..824965de 100644 --- a/src/main/java/net/ladenthin/llama/InferenceParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import java.util.Collection; import java.util.Collections; @@ -15,12 +15,14 @@ import net.ladenthin.llama.args.MiroStat; import net.ladenthin.llama.args.ReasoningFormat; import net.ladenthin.llama.args.Sampler; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.Pair; import org.jspecify.annotations.Nullable; /** - * Immutable typed parameters for {@link LlamaModel} inference calls - * ({@link LlamaModel#generate(InferenceParameters)}, - * {@link LlamaModel#complete(InferenceParameters)}, etc.), populated through a + * Immutable typed parameters for {@link net.ladenthin.llama.LlamaModel} inference calls + * ({@link net.ladenthin.llama.LlamaModel#generate(InferenceParameters)}, + * {@link net.ladenthin.llama.LlamaModel#complete(InferenceParameters)}, etc.), populated through a * functional {@code withX(...)} API. * *

Design

@@ -123,7 +125,7 @@ private static Map singletonPrompt(String prompt) { // Mirror the JSON-encoding path used by withOptionalJson so toString() output // is byte-identical between `new InferenceParameters(p)` and `of(p)`. Map m = new HashMap<>(); - m.put(PARAM_PROMPT, new net.ladenthin.llama.json.ParameterJsonSerializer().toJsonString(prompt)); + m.put(PARAM_PROMPT, new net.ladenthin.llama.parameters.ParameterJsonSerializer().toJsonString(prompt)); return Collections.unmodifiableMap(m); } @@ -481,7 +483,8 @@ public InferenceParameters withTokenIdBias(Map logitBias) { if (logitBias.isEmpty()) { return this; } - return withRaw(PARAM_LOGIT_BIAS, serializer.buildTokenIdBiasArray(logitBias).toString()); + return withRaw( + PARAM_LOGIT_BIAS, serializer.buildTokenIdBiasArray(logitBias).toString()); } /** @@ -496,7 +499,8 @@ public InferenceParameters withDisabledTokenIds(Collection tokenIds) { if (tokenIds.isEmpty()) { return this; } - return withRaw(PARAM_LOGIT_BIAS, serializer.buildDisableTokenIdArray(tokenIds).toString()); + return withRaw( + PARAM_LOGIT_BIAS, serializer.buildDisableTokenIdArray(tokenIds).toString()); } /** @@ -510,7 +514,9 @@ public InferenceParameters withTokenBias(Map logitBias) { if (logitBias.isEmpty()) { return this; } - return withRaw(PARAM_LOGIT_BIAS, serializer.buildTokenStringBiasArray(logitBias).toString()); + return withRaw( + PARAM_LOGIT_BIAS, + serializer.buildTokenStringBiasArray(logitBias).toString()); } /** @@ -525,7 +531,9 @@ public InferenceParameters withDisabledTokens(Collection tokens) { if (tokens.isEmpty()) { return this; } - return withRaw(PARAM_LOGIT_BIAS, serializer.buildDisableTokenStringArray(tokens).toString()); + return withRaw( + PARAM_LOGIT_BIAS, + serializer.buildDisableTokenStringArray(tokens).toString()); } /** @@ -582,7 +590,9 @@ public InferenceParameters withChatTemplate(@Nullable String chatTemplate) { * @return a new instance; this instance is unchanged */ public InferenceParameters withChatTemplateKwargs(Map kwargs) { - return withRaw(PARAM_CHAT_TEMPLATE_KWARGS, serializer.buildRawValueObject(kwargs).toString()); + return withRaw( + PARAM_CHAT_TEMPLATE_KWARGS, + serializer.buildRawValueObject(kwargs).toString()); } /** @@ -594,12 +604,14 @@ public InferenceParameters withChatTemplateKwargs(Map kwargs) { * @return a new instance; this instance is unchanged */ public InferenceParameters withMessages(@Nullable String systemMessage, List> messages) { - return withRaw(PARAM_MESSAGES, serializer.buildMessages(systemMessage, messages).toString()); + return withRaw( + PARAM_MESSAGES, + serializer.buildMessages(systemMessage, messages).toString()); } /** * Returns a new request with chat messages replaced (multimodal-capable variant). - * Messages with non-null {@link ChatMessage#getParts()} are serialized as OAI + * Messages with non-null {@link net.ladenthin.llama.value.ChatMessage#getParts()} are serialized as OAI * array-form content (text + image_url parts). * * @param messages ordered messages, including any {@code "system"} prelude @@ -694,13 +706,13 @@ public InferenceParameters withContinueFinalMessage(ContinuationMode mode) { /** * Package-private: returns a new request with the {@code stream} flag replaced. - * Used by {@link LlamaModel} and {@link LlamaIterator} to pin the streaming mode + * Used by {@link net.ladenthin.llama.LlamaModel} and {@link net.ladenthin.llama.LlamaIterator} to pin the streaming mode * for each request without mutating the caller's instance. * * @param stream whether to enable streaming * @return a new instance; this instance is unchanged */ - InferenceParameters withStream(boolean stream) { + public InferenceParameters withStream(boolean stream) { return withScalar(PARAM_STREAM, stream); } } diff --git a/src/main/java/net/ladenthin/llama/JsonParameters.java b/src/main/java/net/ladenthin/llama/parameters/JsonParameters.java similarity index 97% rename from src/main/java/net/ladenthin/llama/JsonParameters.java rename to src/main/java/net/ladenthin/llama/parameters/JsonParameters.java index cf3415ad..4c5d5404 100644 --- a/src/main/java/net/ladenthin/llama/JsonParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/JsonParameters.java @@ -3,14 +3,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import java.util.Collections; import java.util.HashMap; import java.util.Map; import lombok.EqualsAndHashCode; import net.ladenthin.llama.args.CliArg; -import net.ladenthin.llama.json.ParameterJsonSerializer; import org.jspecify.annotations.Nullable; /** @@ -148,7 +147,7 @@ protected final T withScalar(String key, Object value * of the given enum constant. * * @param key the parameter key - * @param value the enum constant; must implement {@link CliArg} + * @param value the enum constant; must implement {@link net.ladenthin.llama.args.CliArg} * @param the concrete subtype of this parameter set * @return a new instance with the entry inserted or replaced */ diff --git a/src/main/java/net/ladenthin/llama/ModelParameters.java b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java similarity index 97% rename from src/main/java/net/ladenthin/llama/ModelParameters.java rename to src/main/java/net/ladenthin/llama/parameters/ModelParameters.java index 3cb48c6f..464e7e36 100644 --- a/src/main/java/net/ladenthin/llama/ModelParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java @@ -3,14 +3,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import lombok.EqualsAndHashCode; import net.ladenthin.llama.args.*; -import net.ladenthin.llama.json.ParameterJsonSerializer; /*** - * Parameters used for initializing a {@link LlamaModel}. + * Parameters used for initializing a {@link net.ladenthin.llama.LlamaModel}. * *

{@code equals}/{@code hashCode} are generated by Lombok with {@code callSuper=true} * so the parent {@link CliParameters} parameters map participates in equality. The @@ -116,8 +115,7 @@ public ModelParameters setCpuStrict(int strictCpu) { public ModelParameters setPriority(int priority) { if (priority < 0 || priority > 3) { throw new IllegalArgumentException( - "Invalid value for priority: " + priority - + " (allowed: 0=normal, 1=medium, 2=high, 3=realtime)"); + "Invalid value for priority: " + priority + " (allowed: 0=normal, 1=medium, 2=high, 3=realtime)"); } return putScalar("--prio", priority); } @@ -172,9 +170,8 @@ public ModelParameters setCpuStrictBatch(int strictCpuBatch) { */ public ModelParameters setPriorityBatch(int priorityBatch) { if (priorityBatch < 0 || priorityBatch > 3) { - throw new IllegalArgumentException( - "Invalid value for priority batch: " + priorityBatch - + " (allowed: 0=normal, 1=medium, 2=high, 3=realtime)"); + throw new IllegalArgumentException("Invalid value for priority batch: " + priorityBatch + + " (allowed: 0=normal, 1=medium, 2=high, 3=realtime)"); } return putScalar("--prio-batch", priorityBatch); } @@ -430,8 +427,7 @@ public ModelParameters setTypical(float typP) { public ModelParameters setRepeatLastN(int repeatLastN) { if (repeatLastN < -1) { throw new IllegalArgumentException( - "Invalid repeat-last-n value: " + repeatLastN - + " (must be >= -1; -1 = ctx_size, 0 = disabled)"); + "Invalid repeat-last-n value: " + repeatLastN + " (must be >= -1; -1 = ctx_size, 0 = disabled)"); } return putScalar("--repeat-last-n", repeatLastN); } @@ -504,9 +500,8 @@ public ModelParameters setDryAllowedLength(int dryAllowedLength) { */ public ModelParameters setDryPenaltyLastN(int dryPenaltyLastN) { if (dryPenaltyLastN < -1) { - throw new IllegalArgumentException( - "Invalid dry-penalty-last-n value: " + dryPenaltyLastN - + " (must be >= -1; -1 = context size, 0 = disabled)"); + throw new IllegalArgumentException("Invalid dry-penalty-last-n value: " + dryPenaltyLastN + + " (must be >= -1; -1 = context size, 0 = disabled)"); } return putScalar("--dry-penalty-last-n", dryPenaltyLastN); } @@ -1350,8 +1345,8 @@ public ModelParameters enableJinja() { /** * Only load the vocabulary for tokenization, no weights (default: false). - * A model loaded with this option can only be used for {@link LlamaModel#encode(String)} - * and {@link LlamaModel#decode(int[])}. Inference, embedding, and reranking will not work. + * A model loaded with this option can only be used for {@link net.ladenthin.llama.LlamaModel#encode(String)} + * and {@link net.ladenthin.llama.LlamaModel#decode(int[])}. Inference, embedding, and reranking will not work. * * @return this builder */ @@ -1415,7 +1410,7 @@ public ModelParameters setClearIdle(boolean clearIdle) { /** * Enable the given flag, adding it to the active parameter set. * Equivalent to calling the specific named method (e.g. {@link #enableFlashAttn()} - * for {@link ModelFlag#FLASH_ATTN}). + * for {@link net.ladenthin.llama.args.ModelFlag#FLASH_ATTN}). * * @param flag the flag to enable * @return this builder @@ -1452,14 +1447,14 @@ public boolean hasFlag(ModelFlag flag) { * *

When enabled, the upstream loader will NOT attempt any outbound network call to * download the configured model. If the model file is missing or invalid (e.g. ETag - * mismatch), {@link LlamaModel#LlamaModel(ModelParameters)} throws a typed - * {@link ModelUnavailableException} so the caller can distinguish an air-gapped miss + * mismatch), {@link net.ladenthin.llama.LlamaModel#LlamaModel(ModelParameters)} throws a typed + * {@link net.ladenthin.llama.exception.ModelUnavailableException} so the caller can distinguish an air-gapped miss * from a genuine misconfiguration.

* *

Useful for air-gapped / pre-staged-model deployments where any outbound network * call is itself a failure mode.

* - * @param skip {@code true} to skip downloads (set {@link ModelFlag#SKIP_DOWNLOAD}), + * @param skip {@code true} to skip downloads (set {@link net.ladenthin.llama.args.ModelFlag#SKIP_DOWNLOAD}), * {@code false} to clear the flag and allow downloads * @return this builder */ diff --git a/src/main/java/net/ladenthin/llama/json/ParameterJsonSerializer.java b/src/main/java/net/ladenthin/llama/parameters/ParameterJsonSerializer.java similarity index 96% rename from src/main/java/net/ladenthin/llama/json/ParameterJsonSerializer.java rename to src/main/java/net/ladenthin/llama/parameters/ParameterJsonSerializer.java index e6df169d..b07ddc59 100644 --- a/src/main/java/net/ladenthin/llama/json/ParameterJsonSerializer.java +++ b/src/main/java/net/ladenthin/llama/parameters/ParameterJsonSerializer.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama.json; +package net.ladenthin.llama.parameters; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -14,10 +14,10 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import net.ladenthin.llama.ChatMessage; -import net.ladenthin.llama.ContentPart; -import net.ladenthin.llama.Pair; import net.ladenthin.llama.args.Sampler; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ContentPart; +import net.ladenthin.llama.value.Pair; import org.jspecify.annotations.Nullable; /** @@ -104,8 +104,8 @@ public ArrayNode buildMessages(@Nullable String systemMessage, List +// +// SPDX-License-Identifier: MIT + +/** + * Builder-style parameter objects and their JSON serialization. + * + *

JSpecify {@code @NullMarked} is applied module-wide; everything is non-null + * unless annotated {@code @Nullable}. + */ +package net.ladenthin.llama.parameters; diff --git a/src/main/java/net/ladenthin/llama/ChatChoice.java b/src/main/java/net/ladenthin/llama/value/ChatChoice.java similarity index 97% rename from src/main/java/net/ladenthin/llama/ChatChoice.java rename to src/main/java/net/ladenthin/llama/value/ChatChoice.java index 2ab3db5f..ad1c75f8 100644 --- a/src/main/java/net/ladenthin/llama/ChatChoice.java +++ b/src/main/java/net/ladenthin/llama/value/ChatChoice.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import lombok.EqualsAndHashCode; import lombok.ToString; diff --git a/src/main/java/net/ladenthin/llama/ChatMessage.java b/src/main/java/net/ladenthin/llama/value/ChatMessage.java similarity index 96% rename from src/main/java/net/ladenthin/llama/ChatMessage.java rename to src/main/java/net/ladenthin/llama/value/ChatMessage.java index 1a86eb43..220204c7 100644 --- a/src/main/java/net/ladenthin/llama/ChatMessage.java +++ b/src/main/java/net/ladenthin/llama/value/ChatMessage.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.Arrays; import java.util.Collections; @@ -13,8 +13,8 @@ /** * A single message in a chat conversation: a role ({@code "user"}, {@code "assistant"}, - * {@code "system"}, or {@code "tool"}) and its textual content. Used by {@link Session} - * to accumulate conversation turns and by {@link ChatRequest} / {@link ChatResponse} + * {@code "system"}, or {@code "tool"}) and its textual content. Used by {@link net.ladenthin.llama.Session} + * to accumulate conversation turns and by {@link net.ladenthin.llama.parameters.ChatRequest} / {@link ChatResponse} * for the typed chat API. *

* Tool-call turns have role {@code "assistant"}, possibly empty content, and a non-empty @@ -25,7 +25,7 @@ * Multimodal turns carry a non-null {@link #getParts()} list of {@link ContentPart}s * (text and image references). When parts are present they take precedence over * {@link #getContent()} during serialization; the upstream OAI chat path - * (see {@link InferenceParameters#withMessages(java.util.List)}) emits an array-form + * (see {@link net.ladenthin.llama.parameters.InferenceParameters#withMessages(java.util.List)}) emits an array-form * {@code content} field that the compiled-in {@code mtmd} pipeline understands. *

* diff --git a/src/main/java/net/ladenthin/llama/ChatResponse.java b/src/main/java/net/ladenthin/llama/value/ChatResponse.java similarity index 93% rename from src/main/java/net/ladenthin/llama/ChatResponse.java rename to src/main/java/net/ladenthin/llama/value/ChatResponse.java index e2e8a0fe..7b68d5fe 100644 --- a/src/main/java/net/ladenthin/llama/ChatResponse.java +++ b/src/main/java/net/ladenthin/llama/value/ChatResponse.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.Collections; import java.util.List; @@ -11,8 +11,8 @@ import lombok.ToString; /** - * Typed result of {@link LlamaModel#chat(ChatRequest)} and - * {@link LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. + * Typed result of {@link net.ladenthin.llama.LlamaModel#chat(ChatRequest)} and + * {@link net.ladenthin.llama.LlamaModel#chatWithTools(ChatRequest, java.util.Map)}. *

* Bundles the OpenAI-style {@code id} and {@code choices} array with the per-completion * {@link Usage} and {@link Timings} parsed from the response, plus a passthrough to the diff --git a/src/main/java/net/ladenthin/llama/ChatTranscript.java b/src/main/java/net/ladenthin/llama/value/ChatTranscript.java similarity index 83% rename from src/main/java/net/ladenthin/llama/ChatTranscript.java rename to src/main/java/net/ladenthin/llama/value/ChatTranscript.java index f5981ff9..2bd81bf2 100644 --- a/src/main/java/net/ladenthin/llama/ChatTranscript.java +++ b/src/main/java/net/ladenthin/llama/value/ChatTranscript.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.ArrayList; import java.util.Collections; @@ -12,9 +12,9 @@ /** * Append-only transcript of a multi-turn chat conversation, with an optional - * leading {@code system} message. Extracted from {@link Session} so the + * leading {@code system} message. Extracted from {@link net.ladenthin.llama.Session} so the * transcript invariants — especially the two-phase commit shape — are - * testable independently of {@link LlamaModel} and its native library. + * testable independently of {@link net.ladenthin.llama.LlamaModel} and its native library. * *

Two-phase commit invariant

* @@ -23,11 +23,11 @@ *
    *
  • {@link #appendRound(String, String)} appends a user turn AND an * assistant turn in one synchronised operation — used by - * {@link Session#send(String)} on the model-success path. There is no + * {@link net.ladenthin.llama.Session#send(String)} on the model-success path. There is no * way to commit only one half: if the model call throws, this method * is simply never called and the transcript is untouched.
  • *
  • {@link #appendUserTurn(String)} appends only the user turn — used - * by {@link Session#stream(String)} when the streaming iterable has + * by {@link net.ladenthin.llama.Session#stream(String)} when the streaming iterable has * been successfully created but the assistant reply is still being * accumulated. The matching assistant turn is appended later via * {@link #appendAssistantTurn(String)}.
  • @@ -41,7 +41,7 @@ * *

    Thread safety

    * - *

    This class is not internally synchronised. {@link Session} owns + *

    This class is not internally synchronised. {@link net.ladenthin.llama.Session} owns * the single instance and serialises access via its intrinsic lock, so the * transcript itself does not need additional synchronisation. Callers that * use {@code ChatTranscript} directly must provide their own synchronisation @@ -52,11 +52,11 @@ *

    Lombok-generated over the system message and turns list. The turns list * IS included because it is the operationally interesting state for log * traces. {@code equals}/{@code hashCode} are intentionally NOT generated: - * a transcript instance is identified by its lifecycle owner ({@link Session}), + * a transcript instance is identified by its lifecycle owner ({@link net.ladenthin.llama.Session}), * not by its accumulated content. */ @ToString -final class ChatTranscript { +public final class ChatTranscript { private final @Nullable String systemMessage; private final List> turns = new ArrayList>(); @@ -67,7 +67,7 @@ final class ChatTranscript { * @param systemMessage the system prompt to prepend to every wire-format * prompt; {@code null} or empty means "no system message" */ - ChatTranscript(@Nullable String systemMessage) { + public ChatTranscript(@Nullable String systemMessage) { this.systemMessage = systemMessage; } @@ -80,7 +80,7 @@ final class ChatTranscript { * @param userMessage the user turn * @param assistantMessage the assistant reply that completes the round */ - void appendRound(String userMessage, String assistantMessage) { + public void appendRound(String userMessage, String assistantMessage) { turns.add(new Pair("user", userMessage)); turns.add(new Pair("assistant", assistantMessage)); } @@ -92,7 +92,7 @@ void appendRound(String userMessage, String assistantMessage) { * * @param userMessage the user turn */ - void appendUserTurn(String userMessage) { + public void appendUserTurn(String userMessage) { turns.add(new Pair("user", userMessage)); } @@ -102,7 +102,7 @@ void appendUserTurn(String userMessage) { * * @param assistantMessage the assistant reply */ - void appendAssistantTurn(String assistantMessage) { + public void appendAssistantTurn(String assistantMessage) { turns.add(new Pair("assistant", assistantMessage)); } @@ -116,7 +116,7 @@ void appendAssistantTurn(String assistantMessage) { * @return a fresh list containing the committed turns followed by the * pending user turn */ - List> messagesWithPendingUserTurn(String pendingUserMessage) { + public List> messagesWithPendingUserTurn(String pendingUserMessage) { List> wire = new ArrayList>(turns.size() + 1); wire.addAll(turns); wire.add(new Pair("user", pendingUserMessage)); @@ -128,8 +128,7 @@ List> messagesWithPendingUserTurn(String pendingUserMessage * * @return the system prompt, or {@code null} */ - @Nullable - String getSystemMessage() { + public @Nullable String getSystemMessage() { return systemMessage; } @@ -139,7 +138,7 @@ String getSystemMessage() { * * @return the unmodifiable snapshot */ - List snapshot() { + public List snapshot() { List out = new ArrayList(turns.size() + 1); if (systemMessage != null && !systemMessage.isEmpty()) { out.add(new ChatMessage("system", systemMessage)); @@ -156,7 +155,7 @@ List snapshot() { * * @return the turn count */ - int size() { + public int size() { return turns.size(); } } diff --git a/src/main/java/net/ladenthin/llama/CompletionResult.java b/src/main/java/net/ladenthin/llama/value/CompletionResult.java similarity index 91% rename from src/main/java/net/ladenthin/llama/CompletionResult.java rename to src/main/java/net/ladenthin/llama/value/CompletionResult.java index 8fbd251f..934a56d8 100644 --- a/src/main/java/net/ladenthin/llama/CompletionResult.java +++ b/src/main/java/net/ladenthin/llama/value/CompletionResult.java @@ -2,18 +2,19 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.Collections; import java.util.List; import lombok.EqualsAndHashCode; +import net.ladenthin.llama.parameters.InferenceParameters; /** - * Typed result of {@link LlamaModel#completeWithStats(InferenceParameters)}. + * Typed result of {@link net.ladenthin.llama.LlamaModel#completeWithStats(InferenceParameters)}. *

    * Bundles the generated text with parsed {@link Usage}, {@link Timings}, * per-token {@link TokenLogprob} entries (populated only when - * {@link InferenceParameters#withNProbs(int)} > 0), and the {@link StopReason}. + * {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} > 0), and the {@link StopReason}. * The raw native JSON is exposed via {@link #getRawJson()} as an escape hatch. *

    * diff --git a/src/main/java/net/ladenthin/llama/ContentPart.java b/src/main/java/net/ladenthin/llama/value/ContentPart.java similarity index 95% rename from src/main/java/net/ladenthin/llama/ContentPart.java rename to src/main/java/net/ladenthin/llama/value/ContentPart.java index f73cb7e3..0eb4c0b9 100644 --- a/src/main/java/net/ladenthin/llama/ContentPart.java +++ b/src/main/java/net/ladenthin/llama/value/ContentPart.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.io.IOException; import java.nio.file.Files; @@ -28,7 +28,7 @@ * * and the upstream {@code oaicompat_chat_params_parse} routes it through the * compiled-in {@code mtmd} pipeline (requires - * {@link ModelParameters#setMmproj(String)} to be wired). + * {@link net.ladenthin.llama.parameters.ModelParameters#setMmproj(String)} to be wired). *

    * Instances are immutable and safe to share across threads. Use the static * factories — the constructor is private. @@ -92,8 +92,7 @@ public static ContentPart imageBytes(byte[] bytes, String mimeType) { Objects.requireNonNull(bytes, "bytes"); Objects.requireNonNull(mimeType, "mimeType"); if (mimeType.isEmpty()) { - throw new IllegalArgumentException( - "mimeType must not be empty (bytes.length=" + bytes.length + ")"); + throw new IllegalArgumentException("mimeType must not be empty (bytes.length=" + bytes.length + ")"); } String encoded = Base64.getEncoder().encodeToString(bytes); return new ContentPart(Type.IMAGE_URL, null, "data:" + mimeType + ";base64," + encoded); diff --git a/src/main/java/net/ladenthin/llama/LlamaOutput.java b/src/main/java/net/ladenthin/llama/value/LlamaOutput.java similarity index 88% rename from src/main/java/net/ladenthin/llama/LlamaOutput.java rename to src/main/java/net/ladenthin/llama/value/LlamaOutput.java index 9708e133..e9285375 100644 --- a/src/main/java/net/ladenthin/llama/LlamaOutput.java +++ b/src/main/java/net/ladenthin/llama/value/LlamaOutput.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.Collections; import java.util.List; @@ -12,7 +12,7 @@ /** * An output of the LLM providing access to the generated text and the associated probabilities. You have to configure - * {@link InferenceParameters#withNProbs(int)} in order for probabilities to be returned. + * {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} in order for probabilities to be returned. * *

    {@code equals}/{@code hashCode} are generated by Lombok over all fields. * {@code toString} is intentionally handwritten (not Lombok-generated): it returns @@ -34,13 +34,13 @@ public final class LlamaOutput { * raw {@code prob} or {@code logprob} from the native response. For richer per-token * detail (token id and the {@code top_logprobs} alternatives), use {@link #logprobs}. *

    - * Note, that you have to configure {@link InferenceParameters#withNProbs(int)} in order for probabilities to be returned. + * Note, that you have to configure {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} in order for probabilities to be returned. */ public final Map probabilities; /** * Typed per-token logprob entries with token id and {@code top_logprobs} alternatives. - * Empty when {@link InferenceParameters#withNProbs(int)} is not configured or the native + * Empty when {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} is not configured or the native * response did not include {@code completion_probabilities}. */ public final List logprobs; diff --git a/src/main/java/net/ladenthin/llama/LogLevel.java b/src/main/java/net/ladenthin/llama/value/LogLevel.java similarity index 92% rename from src/main/java/net/ladenthin/llama/LogLevel.java rename to src/main/java/net/ladenthin/llama/value/LogLevel.java index dde6f142..4ddb3303 100644 --- a/src/main/java/net/ladenthin/llama/LogLevel.java +++ b/src/main/java/net/ladenthin/llama/value/LogLevel.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; /** * This enum represents the native log levels of llama.cpp. diff --git a/src/main/java/net/ladenthin/llama/ModelMeta.java b/src/main/java/net/ladenthin/llama/value/ModelMeta.java similarity index 92% rename from src/main/java/net/ladenthin/llama/ModelMeta.java rename to src/main/java/net/ladenthin/llama/value/ModelMeta.java index ef90d331..d9e574ee 100644 --- a/src/main/java/net/ladenthin/llama/ModelMeta.java +++ b/src/main/java/net/ladenthin/llama/value/ModelMeta.java @@ -3,13 +3,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import com.fasterxml.jackson.databind.JsonNode; import lombok.EqualsAndHashCode; /** - * Model metadata returned by {@link LlamaModel#getModelMeta()}. + * Model metadata returned by {@link net.ladenthin.llama.LlamaModel#getModelMeta()}. *

    * Typed getters cover all fields currently returned by the native {@code model_meta()} * function. The underlying {@link JsonNode} is also exposed via {@link #asJson()} so @@ -27,7 +27,12 @@ public final class ModelMeta { private final JsonNode node; - ModelMeta(JsonNode node) { + /** + * Wraps the raw model-metadata JSON node returned by the native layer. + * + * @param node the JSON node holding the model metadata + */ + public ModelMeta(JsonNode node) { this.node = node; } diff --git a/src/main/java/net/ladenthin/llama/Pair.java b/src/main/java/net/ladenthin/llama/value/Pair.java similarity index 96% rename from src/main/java/net/ladenthin/llama/Pair.java rename to src/main/java/net/ladenthin/llama/value/Pair.java index 22074ac4..0c123e65 100644 --- a/src/main/java/net/ladenthin/llama/Pair.java +++ b/src/main/java/net/ladenthin/llama/value/Pair.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import lombok.EqualsAndHashCode; import lombok.ToString; diff --git a/src/main/java/net/ladenthin/llama/ServerMetrics.java b/src/main/java/net/ladenthin/llama/value/ServerMetrics.java similarity index 94% rename from src/main/java/net/ladenthin/llama/ServerMetrics.java rename to src/main/java/net/ladenthin/llama/value/ServerMetrics.java index 883ec3cc..e07afb6e 100644 --- a/src/main/java/net/ladenthin/llama/ServerMetrics.java +++ b/src/main/java/net/ladenthin/llama/value/ServerMetrics.java @@ -2,13 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import com.fasterxml.jackson.databind.JsonNode; import lombok.EqualsAndHashCode; /** - * Typed view over the JSON returned by {@link LlamaModel#getMetrics()}. + * Typed view over the JSON returned by {@link net.ladenthin.llama.LlamaModel#getMetrics()}. *

    * Wraps the underlying {@link JsonNode} so future fields added on the C++ side remain * accessible via {@link #asJson()} without code changes here. Mirrors the @@ -35,7 +35,12 @@ public final class ServerMetrics { private final JsonNode node; - ServerMetrics(JsonNode node) { + /** + * Wraps the raw server-metrics JSON node returned by the native layer. + * + * @param node the JSON node holding the server metrics + */ + public ServerMetrics(JsonNode node) { this.node = node; } diff --git a/src/main/java/net/ladenthin/llama/StopReason.java b/src/main/java/net/ladenthin/llama/value/StopReason.java similarity index 98% rename from src/main/java/net/ladenthin/llama/StopReason.java rename to src/main/java/net/ladenthin/llama/value/StopReason.java index c31c2809..ba7e21b3 100644 --- a/src/main/java/net/ladenthin/llama/StopReason.java +++ b/src/main/java/net/ladenthin/llama/value/StopReason.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import org.jspecify.annotations.Nullable; diff --git a/src/main/java/net/ladenthin/llama/Timings.java b/src/main/java/net/ladenthin/llama/value/Timings.java similarity index 99% rename from src/main/java/net/ladenthin/llama/Timings.java rename to src/main/java/net/ladenthin/llama/value/Timings.java index 57f58e21..090ff154 100644 --- a/src/main/java/net/ladenthin/llama/Timings.java +++ b/src/main/java/net/ladenthin/llama/value/Timings.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import com.fasterxml.jackson.databind.JsonNode; import lombok.EqualsAndHashCode; diff --git a/src/main/java/net/ladenthin/llama/TokenLogprob.java b/src/main/java/net/ladenthin/llama/value/TokenLogprob.java similarity index 95% rename from src/main/java/net/ladenthin/llama/TokenLogprob.java rename to src/main/java/net/ladenthin/llama/value/TokenLogprob.java index 8247d45f..645a0f96 100644 --- a/src/main/java/net/ladenthin/llama/TokenLogprob.java +++ b/src/main/java/net/ladenthin/llama/value/TokenLogprob.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import java.util.Collections; import java.util.List; @@ -12,7 +12,7 @@ /** * Per-token log-probability entry from the native {@code completion_probabilities} array. *

    - * Populated when {@link InferenceParameters#withNProbs(int)} is > 0. The native server + * Populated when {@link net.ladenthin.llama.parameters.InferenceParameters#withNProbs(int)} is > 0. The native server * emits one of two equivalent shapes depending on whether post-sampling probabilities are * enabled: *

    diff --git a/src/main/java/net/ladenthin/llama/ToolCall.java b/src/main/java/net/ladenthin/llama/value/ToolCall.java similarity index 92% rename from src/main/java/net/ladenthin/llama/ToolCall.java rename to src/main/java/net/ladenthin/llama/value/ToolCall.java index 288d7c5c..acce5a3f 100644 --- a/src/main/java/net/ladenthin/llama/ToolCall.java +++ b/src/main/java/net/ladenthin/llama/value/ToolCall.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import lombok.EqualsAndHashCode; @@ -11,7 +11,7 @@ * {@code tool_calls[i]} object: an id, a function name, and the arguments as a JSON string. *

    * Arguments are surfaced verbatim as the JSON string the model emitted; callers parse them - * with their preferred JSON library (or hand them to a {@link ToolHandler}). + * with their preferred JSON library (or hand them to a {@link net.ladenthin.llama.callback.ToolHandler}). *

    * *

    {@code equals}/{@code hashCode} are generated by Lombok over all fields. diff --git a/src/main/java/net/ladenthin/llama/ToolDefinition.java b/src/main/java/net/ladenthin/llama/value/ToolDefinition.java similarity index 98% rename from src/main/java/net/ladenthin/llama/ToolDefinition.java rename to src/main/java/net/ladenthin/llama/value/ToolDefinition.java index bb005b7d..1b8732b2 100644 --- a/src/main/java/net/ladenthin/llama/ToolDefinition.java +++ b/src/main/java/net/ladenthin/llama/value/ToolDefinition.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import lombok.EqualsAndHashCode; import lombok.ToString; diff --git a/src/main/java/net/ladenthin/llama/Usage.java b/src/main/java/net/ladenthin/llama/value/Usage.java similarity index 98% rename from src/main/java/net/ladenthin/llama/Usage.java rename to src/main/java/net/ladenthin/llama/value/Usage.java index 72d8db06..7921634b 100644 --- a/src/main/java/net/ladenthin/llama/Usage.java +++ b/src/main/java/net/ladenthin/llama/value/Usage.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import lombok.EqualsAndHashCode; import lombok.ToString; diff --git a/src/main/java/net/ladenthin/llama/value/package-info.java b/src/main/java/net/ladenthin/llama/value/package-info.java new file mode 100644 index 00000000..458d7ae7 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/value/package-info.java @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +/** + * Immutable value / DTO types (chat, completion, timing, usage and metric records). + * + *

    JSpecify {@code @NullMarked} is applied module-wide; everything is non-null + * unless annotated {@code @Nullable}. + */ +package net.ladenthin.llama.value; diff --git a/src/test/java/examples/ChatExample.java b/src/test/java/examples/ChatExample.java index 4a225eea..30d9322a 100644 --- a/src/test/java/examples/ChatExample.java +++ b/src/test/java/examples/ChatExample.java @@ -10,11 +10,11 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import net.ladenthin.llama.InferenceParameters; import net.ladenthin.llama.LlamaModel; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.ModelParameters; -import net.ladenthin.llama.Pair; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.Disabled; // Model file (models/codellama-7b.Q2_K.gguf) is not available in the models directory diff --git a/src/test/java/examples/GrammarExample.java b/src/test/java/examples/GrammarExample.java index 02b97134..7c06c53c 100644 --- a/src/test/java/examples/GrammarExample.java +++ b/src/test/java/examples/GrammarExample.java @@ -5,10 +5,10 @@ package examples; -import net.ladenthin.llama.InferenceParameters; import net.ladenthin.llama.LlamaModel; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.ModelParameters; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; public class GrammarExample { diff --git a/src/test/java/examples/InfillExample.java b/src/test/java/examples/InfillExample.java index 9ef9e1f5..b82479a4 100644 --- a/src/test/java/examples/InfillExample.java +++ b/src/test/java/examples/InfillExample.java @@ -5,10 +5,10 @@ package examples; -import net.ladenthin.llama.InferenceParameters; import net.ladenthin.llama.LlamaModel; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.ModelParameters; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; public class InfillExample { diff --git a/src/test/java/examples/MainExample.java b/src/test/java/examples/MainExample.java index c37c2d97..32bf331f 100644 --- a/src/test/java/examples/MainExample.java +++ b/src/test/java/examples/MainExample.java @@ -9,11 +9,11 @@ import java.io.IOException; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; -import net.ladenthin.llama.InferenceParameters; import net.ladenthin.llama.LlamaModel; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.ModelParameters; import net.ladenthin.llama.args.MiroStat; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; @SuppressWarnings("InfiniteLoopStatement") public class MainExample { diff --git a/src/test/java/net/ladenthin/llama/ChatAdvancedTest.java b/src/test/java/net/ladenthin/llama/ChatAdvancedTest.java index 6f07530f..d44f8f37 100644 --- a/src/test/java/net/ladenthin/llama/ChatAdvancedTest.java +++ b/src/test/java/net/ladenthin/llama/ChatAdvancedTest.java @@ -15,6 +15,10 @@ import net.ladenthin.llama.args.MiroStat; import net.ladenthin.llama.args.Sampler; import net.ladenthin.llama.json.CompletionResponseParser; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/ChatMessageTest.java b/src/test/java/net/ladenthin/llama/ChatMessageTest.java deleted file mode 100644 index c21aba18..00000000 --- a/src/test/java/net/ladenthin/llama/ChatMessageTest.java +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Bernard Ladenthin -// -// SPDX-License-Identifier: MIT - -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import org.junit.jupiter.api.Test; - -@ClaudeGenerated( - purpose = "Verify ChatMessage value class accessors and toString format used by Session.getMessages().") -public class ChatMessageTest { - - @Test - public void accessors() { - ChatMessage m = new ChatMessage("user", "hi"); - assertEquals("user", m.getRole()); - assertEquals("hi", m.getContent()); - } - - @Test - public void toStringFormat() { - assertEquals("assistant: hello", new ChatMessage("assistant", "hello").toString()); - } -} diff --git a/src/test/java/net/ladenthin/llama/ChatScenarioTest.java b/src/test/java/net/ladenthin/llama/ChatScenarioTest.java index 72f82952..cc712556 100644 --- a/src/test/java/net/ladenthin/llama/ChatScenarioTest.java +++ b/src/test/java/net/ladenthin/llama/ChatScenarioTest.java @@ -13,8 +13,13 @@ import java.util.ArrayList; import java.util.List; import net.ladenthin.llama.args.PoolingType; +import net.ladenthin.llama.exception.LlamaException; import net.ladenthin.llama.json.ChatResponseParser; import net.ladenthin.llama.json.CompletionResponseParser; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/ConfigureParallelInferenceTest.java b/src/test/java/net/ladenthin/llama/ConfigureParallelInferenceTest.java index 16facddd..7751d1af 100644 --- a/src/test/java/net/ladenthin/llama/ConfigureParallelInferenceTest.java +++ b/src/test/java/net/ladenthin/llama/ConfigureParallelInferenceTest.java @@ -8,6 +8,8 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.File; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/ErrorHandlingTest.java b/src/test/java/net/ladenthin/llama/ErrorHandlingTest.java index bc0caf1a..c021cd95 100644 --- a/src/test/java/net/ladenthin/llama/ErrorHandlingTest.java +++ b/src/test/java/net/ladenthin/llama/ErrorHandlingTest.java @@ -8,6 +8,8 @@ import static org.junit.jupiter.api.Assertions.*; import java.io.File; +import net.ladenthin.llama.exception.LlamaException; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java b/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java index 4c7010d9..17f21566 100644 --- a/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java +++ b/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java @@ -5,6 +5,7 @@ import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.fields; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.noClasses; +import static com.tngtech.archunit.library.Architectures.layeredArchitecture; import static com.tngtech.archunit.library.dependencies.SlicesRuleDefinition.slices; import com.tngtech.archunit.core.importer.ImportOption; @@ -63,22 +64,9 @@ public class LlamaArchitectureTest { /** * The {@code args} sub-package is a true leaf: pure enums / constants * ({@code Sampler}, {@code PoolingType}, {@code ModelFlag}, …). It must not - * import anything from elsewhere in the project — neither the root API - * package nor the {@code json} parser package. - * - *

    This pins the only stackable layer relationship in jllama. The - * traditional {@code layeredArchitecture()} 3-layer rule (Args → Json → Api) - * was attempted and rejected: {@code json} parsers/serializers genuinely - * depend on root-package DTOs ({@code Pair}, {@code ChatMessage}, - * {@code ContentPart}) AND the root API genuinely depends on {@code json} - * parsers — they are peers in the public API layer, not a - * stackable hierarchy. Splitting the DTOs into a dedicated - * {@code net.ladenthin.llama.value} package would enable real layering, - * but breaks the published public-API FQNs ({@code net.ladenthin.llama.Pair} - * etc.) and is out of scope for an ArchUnit rule. - * - *

    So the only real architectural invariant worth enforcing here is "args - * stays a leaf" — and that is what this rule does. + * import anything from elsewhere in the project. Subsumed by the + * {@link #layeredArchitecture} rule below (args is in the Foundation layer), + * but kept as a precise, fast-failing guard for this specific leaf. */ @ArchTest static final ArchRule argsPackageIsALeaf = noClasses() @@ -86,7 +74,58 @@ public class LlamaArchitectureTest { .resideInAPackage("net.ladenthin.llama.args..") .should() .dependOnClassesThat() - .resideInAnyPackage("net.ladenthin.llama", "net.ladenthin.llama.json.."); + .resideInAnyPackage( + "net.ladenthin.llama", + "net.ladenthin.llama.callback..", + "net.ladenthin.llama.exception..", + "net.ladenthin.llama.json..", + "net.ladenthin.llama.loader..", + "net.ladenthin.llama.parameters..", + "net.ladenthin.llama.value.."); + + /** + * Strict layered architecture — one layer per package. Each package's + * {@code mayOnlyBeAccessedByLayers} lists the EXACT set of packages that reference it today + * (verified against the compiled bytecode graph), so even intra-tier edges are governed: a + * new dependency between any two packages fails the build unless this rule is updated to + * intend it. Conceptual tiers (informational): {@code Api} (root) > {@code Loader} > + * {@code Json}/{@code Parameters} > {@code Value}/{@code Callback}/{@code Exception}/{@code Args}. + */ + @ArchTest + static final ArchRule layeredArchitecture = layeredArchitecture() + .consideringOnlyDependenciesInLayers() + .layer("Api") + .definedBy("net.ladenthin.llama") + .layer("Loader") + .definedBy("net.ladenthin.llama.loader..") + .layer("Json") + .definedBy("net.ladenthin.llama.json..") + .layer("Parameters") + .definedBy("net.ladenthin.llama.parameters..") + .layer("Value") + .definedBy("net.ladenthin.llama.value..") + .layer("Callback") + .definedBy("net.ladenthin.llama.callback..") + .layer("Exception") + .definedBy("net.ladenthin.llama.exception..") + .layer("Args") + .definedBy("net.ladenthin.llama.args..") + .whereLayer("Api") + .mayNotBeAccessedByAnyLayer() + .whereLayer("Loader") + .mayOnlyBeAccessedByLayers("Api") + .whereLayer("Json") + .mayOnlyBeAccessedByLayers("Api") + .whereLayer("Parameters") + .mayOnlyBeAccessedByLayers("Api", "Loader") + .whereLayer("Value") + .mayOnlyBeAccessedByLayers("Api", "Json", "Parameters") + .whereLayer("Callback") + .mayOnlyBeAccessedByLayers("Api") + .whereLayer("Exception") + .mayOnlyBeAccessedByLayers("Api", "Loader") + .whereLayer("Args") + .mayOnlyBeAccessedByLayers("Api", "Loader", "Parameters"); /** * Production code must not import unsupported / internal JDK packages. @@ -152,4 +191,24 @@ public class LlamaArchitectureTest { .orShould() .callMethod(Thread.class, "sleep", long.class, int.class) .allowEmptyShould(true); + + /** + * Per-module banned import: the foundation contracts ({@code args}, {@code callback}, + * {@code exception}) and the {@code loader} infrastructure must stay free of the Jackson + * JSON library ({@code com.fasterxml.jackson..}). JSON marshalling is the job of + * {@code value} / {@code json} / {@code parameters} (and the root {@code Api}, which drives + * them); these layers carry only plain typed data and native-loading logic. + */ + @ArchTest + static final ArchRule jacksonBannedFromContractsAndLoader = noClasses() + .that() + .resideInAnyPackage( + "net.ladenthin.llama.args..", + "net.ladenthin.llama.callback..", + "net.ladenthin.llama.exception..", + "net.ladenthin.llama.loader..") + .should() + .dependOnClassesThat() + .resideInAPackage("com.fasterxml.jackson..") + .allowEmptyShould(true); } diff --git a/src/test/java/net/ladenthin/llama/LlamaEmbeddingsTest.java b/src/test/java/net/ladenthin/llama/LlamaEmbeddingsTest.java index 0f50fbc9..69c52b73 100644 --- a/src/test/java/net/ladenthin/llama/LlamaEmbeddingsTest.java +++ b/src/test/java/net/ladenthin/llama/LlamaEmbeddingsTest.java @@ -9,6 +9,7 @@ import java.io.File; import net.ladenthin.llama.args.PoolingType; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/LlamaModelSkipDownloadTest.java b/src/test/java/net/ladenthin/llama/LlamaModelSkipDownloadTest.java index dcf4eae5..f305795d 100644 --- a/src/test/java/net/ladenthin/llama/LlamaModelSkipDownloadTest.java +++ b/src/test/java/net/ladenthin/llama/LlamaModelSkipDownloadTest.java @@ -11,6 +11,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import net.ladenthin.llama.args.ModelFlag; +import net.ladenthin.llama.exception.LlamaException; +import net.ladenthin.llama.exception.ModelUnavailableException; +import net.ladenthin.llama.loader.SkipDownloadFailureTranslator; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/LlamaModelTest.java b/src/test/java/net/ladenthin/llama/LlamaModelTest.java index daab1dc6..c89cb3fb 100644 --- a/src/test/java/net/ladenthin/llama/LlamaModelTest.java +++ b/src/test/java/net/ladenthin/llama/LlamaModelTest.java @@ -12,6 +12,20 @@ import java.util.*; import java.util.regex.Pattern; import net.ladenthin.llama.args.LogFormat; +import net.ladenthin.llama.callback.CancellationToken; +import net.ladenthin.llama.callback.ToolHandler; +import net.ladenthin.llama.exception.LlamaException; +import net.ladenthin.llama.parameters.ChatRequest; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.CompletionResult; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.LogLevel; +import net.ladenthin.llama.value.ModelMeta; +import net.ladenthin.llama.value.Pair; +import net.ladenthin.llama.value.ToolDefinition; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/LlamaParameterProperties.java b/src/test/java/net/ladenthin/llama/LlamaParameterProperties.java index 8d58a4a8..34369a30 100644 --- a/src/test/java/net/ladenthin/llama/LlamaParameterProperties.java +++ b/src/test/java/net/ladenthin/llama/LlamaParameterProperties.java @@ -6,6 +6,7 @@ import net.jqwik.api.ForAll; import net.jqwik.api.Property; import net.jqwik.api.constraints.FloatRange; +import net.ladenthin.llama.parameters.InferenceParameters; public class LlamaParameterProperties { diff --git a/src/test/java/net/ladenthin/llama/MemoryManagementTest.java b/src/test/java/net/ladenthin/llama/MemoryManagementTest.java index a846065f..1762b76a 100644 --- a/src/test/java/net/ladenthin/llama/MemoryManagementTest.java +++ b/src/test/java/net/ladenthin/llama/MemoryManagementTest.java @@ -12,6 +12,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/MultimodalIntegrationTest.java b/src/test/java/net/ladenthin/llama/MultimodalIntegrationTest.java index 8f4d4936..693246f8 100644 --- a/src/test/java/net/ladenthin/llama/MultimodalIntegrationTest.java +++ b/src/test/java/net/ladenthin/llama/MultimodalIntegrationTest.java @@ -12,6 +12,10 @@ import java.nio.file.Paths; import java.util.Collections; import java.util.concurrent.TimeUnit; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ContentPart; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/MultimodalMessagesTest.java b/src/test/java/net/ladenthin/llama/MultimodalMessagesTest.java index 9292f98a..5fe004df 100644 --- a/src/test/java/net/ladenthin/llama/MultimodalMessagesTest.java +++ b/src/test/java/net/ladenthin/llama/MultimodalMessagesTest.java @@ -4,10 +4,11 @@ package net.ladenthin.llama; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import com.fasterxml.jackson.databind.JsonNode; @@ -16,7 +17,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import net.ladenthin.llama.json.ParameterJsonSerializer; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ParameterJsonSerializer; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ContentPart; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -31,16 +35,16 @@ public class MultimodalMessagesTest { @Test public void hasPartsIsFalseForLegacyConstructor() { ChatMessage m = new ChatMessage("user", "hello"); - assertFalse(m.hasParts()); - assertTrue(m.getParts().isEmpty()); + assertThat(m.hasParts(), is(false)); + assertThat(m.getParts().isPresent(), is(false)); } @Test public void hasPartsIsTrueForPartsConstructor() { ChatMessage m = new ChatMessage( "user", Arrays.asList(ContentPart.text("hi"), ContentPart.imageUrl("data:image/png;base64,AAAA"))); - assertTrue(m.hasParts()); - assertEquals(2, m.getParts().orElseThrow().size()); + assertThat(m.hasParts(), is(true)); + assertThat(m.getParts().orElseThrow(), hasSize(2)); } @Test @@ -52,18 +56,18 @@ public void contentFieldConcatenatesTextPartsForLegacyReaders() { ContentPart.imageUrl("data:image/png;base64,X"), ContentPart.text("please"))); // Image parts contribute no text; text parts are newline-joined. - assertEquals("describe\nplease", m.getContent()); + assertThat(m.getContent(), is("describe\nplease")); } @Test public void userMultimodalFactoryBuildsUserMessage() { ChatMessage m = ChatMessage.userMultimodal( ContentPart.text("what is this?"), ContentPart.imageUrl("data:image/jpeg;base64,Y")); - assertEquals("user", m.getRole()); + assertThat(m.getRole(), is("user")); List parts = m.getParts().orElseThrow(); - assertEquals(2, parts.size()); - assertEquals(ContentPart.Type.TEXT, parts.get(0).getType()); - assertEquals(ContentPart.Type.IMAGE_URL, parts.get(1).getType()); + assertThat(parts, hasSize(2)); + assertThat(parts.get(0).getType(), is(ContentPart.Type.TEXT)); + assertThat(parts.get(1).getType(), is(ContentPart.Type.IMAGE_URL)); } @Test @@ -95,22 +99,21 @@ public void serializerEmitsArrayContentForPartsMessage() throws Exception { ContentPart.text("describe"), ContentPart.imageUrl("data:image/png;base64,ABCD")); ArrayNode arr = s.buildMessages(Collections.singletonList(user)); - assertEquals(1, arr.size()); + assertThat(arr.size(), is(1)); JsonNode msg = arr.get(0); - assertEquals("user", msg.get("role").asText()); + assertThat(msg.get("role").asText(), is("user")); JsonNode content = msg.get("content"); - assertTrue(content.isArray(), "content must be an array when parts are present"); - assertEquals(2, content.size()); + assertThat("content must be an array when parts are present", content.isArray(), is(true)); + assertThat(content.size(), is(2)); JsonNode p0 = content.get(0); - assertEquals("text", p0.get("type").asText()); - assertEquals("describe", p0.get("text").asText()); + assertThat(p0.get("type").asText(), is("text")); + assertThat(p0.get("text").asText(), is("describe")); JsonNode p1 = content.get(1); - assertEquals("image_url", p1.get("type").asText()); - assertEquals( - "data:image/png;base64,ABCD", p1.get("image_url").get("url").asText()); + assertThat(p1.get("type").asText(), is("image_url")); + assertThat(p1.get("image_url").get("url").asText(), is("data:image/png;base64,ABCD")); } @Test @@ -119,11 +122,14 @@ public void serializerEmitsStringContentForLegacyMessage() { ChatMessage user = new ChatMessage("user", "plain text"); ArrayNode arr = s.buildMessages(Collections.singletonList(user)); - assertEquals(1, arr.size()); + assertThat(arr.size(), is(1)); JsonNode msg = arr.get(0); - assertEquals("user", msg.get("role").asText()); - assertTrue(msg.get("content").isTextual(), "content must remain a string for legacy messages"); - assertEquals("plain text", msg.get("content").asText()); + assertThat(msg.get("role").asText(), is("user")); + assertThat( + "content must remain a string for legacy messages", + msg.get("content").isTextual(), + is(true)); + assertThat(msg.get("content").asText(), is("plain text")); } @Test @@ -135,23 +141,23 @@ public void serializerHandlesMixedMessages() { ContentPart.text("what's in here?"), ContentPart.imageUrl("data:image/png;base64,Z")), new ChatMessage("assistant", "a cat")); ArrayNode arr = s.buildMessages(messages); - assertEquals(3, arr.size()); - assertTrue(arr.get(0).get("content").isTextual()); - assertTrue(arr.get(1).get("content").isArray()); - assertTrue(arr.get(2).get("content").isTextual()); + assertThat(arr.size(), is(3)); + assertThat(arr.get(0).get("content").isTextual(), is(true)); + assertThat(arr.get(1).get("content").isArray(), is(true)); + assertThat(arr.get(2).get("content").isTextual(), is(true)); } @Test public void inferenceParametersAcceptsMultimodalMessages() { InferenceParameters params = new InferenceParameters("") - .withMessages(Collections.singletonList( - ChatMessage.userMultimodal(ContentPart.text("hi"), ContentPart.imageUrl("data:image/png;base64,QQ")))); + .withMessages(Collections.singletonList(ChatMessage.userMultimodal( + ContentPart.text("hi"), ContentPart.imageUrl("data:image/png;base64,QQ")))); // setMessages encodes into the parameters map under "messages"; verify the // resulting JSON has the array form, which is what the upstream OAI chat // parser expects for multimodal routing. String json = params.toString(); - assertTrue(json.contains("\"messages\""), "messages array must be present"); - assertTrue(json.contains("\"image_url\""), "multimodal part type must be in the serialised JSON"); - assertTrue(json.contains("data:image/png;base64,QQ"), "data URI must round-trip into the request body"); + assertThat("messages array must be present", json, containsString("\"messages\"")); + assertThat("multimodal part type must be in the serialised JSON", json, containsString("\"image_url\"")); + assertThat("data URI must round-trip into the request body", json, containsString("data:image/png;base64,QQ")); } } diff --git a/src/test/java/net/ladenthin/llama/ReactorIntegrationTest.java b/src/test/java/net/ladenthin/llama/ReactorIntegrationTest.java index 36fe251c..5f0a2c91 100644 --- a/src/test/java/net/ladenthin/llama/ReactorIntegrationTest.java +++ b/src/test/java/net/ladenthin/llama/ReactorIntegrationTest.java @@ -14,6 +14,9 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -55,8 +58,7 @@ class ReactorIntegrationTest { @Test void mockIterable_requestBackpressureAndCancelClose() { AtomicBoolean closed = new AtomicBoolean(false); - List tokens = - Arrays.asList(out("a"), out("b"), out("c"), out("d"), out("e")); + List tokens = Arrays.asList(out("a"), out("b"), out("c"), out("d"), out("e")); // Flux.fromIterable(iterable) does NOT auto-close AutoCloseable iterables on cancel — // the canonical Reactor pattern for that is Flux.using(supplier, builder, cleanup). @@ -90,8 +92,7 @@ void mockIterable_requestBackpressureAndCancelClose() { @Test void realModel_cancelPropagatesToNativeCompletion() { Assumptions.assumeTrue( - new File(TestConstants.MODEL_PATH).exists(), - "real-model test requires " + TestConstants.MODEL_PATH); + new File(TestConstants.MODEL_PATH).exists(), "real-model test requires " + TestConstants.MODEL_PATH); ModelParameters mp = new ModelParameters() .setModel(TestConstants.MODEL_PATH) @@ -99,8 +100,9 @@ void realModel_cancelPropagatesToNativeCompletion() { try (LlamaModel model = new LlamaModel(mp)) { // First: stream via Reactor with Flux.using for proper cleanup, take 3 tokens, cancel. String first = Flux.using( - () -> model.generate( - new InferenceParameters("Q: 1+1=").withNPredict(20).withTemperature(0.0f)), + () -> model.generate(new InferenceParameters("Q: 1+1=") + .withNPredict(20) + .withTemperature(0.0f)), Flux::fromIterable, LlamaIterable::close) .subscribeOn(Schedulers.boundedElastic()) @@ -116,8 +118,8 @@ void realModel_cancelPropagatesToNativeCompletion() { // first generation's slot was released by Flux.using's cleanup function // routing through LlamaIterable.close() -> LlamaIterator.close() -> // native cancelCompletion. - String second = model.complete( - new InferenceParameters("Hi").withNPredict(2).withTemperature(0.0f)); + String second = + model.complete(new InferenceParameters("Hi").withNPredict(2).withTemperature(0.0f)); assertNotNull(second); } } diff --git a/src/test/java/net/ladenthin/llama/ReasoningBudgetTest.java b/src/test/java/net/ladenthin/llama/ReasoningBudgetTest.java index 2f516147..36164119 100644 --- a/src/test/java/net/ladenthin/llama/ReasoningBudgetTest.java +++ b/src/test/java/net/ladenthin/llama/ReasoningBudgetTest.java @@ -11,6 +11,9 @@ import java.util.Collections; import net.ladenthin.llama.args.ReasoningFormat; import net.ladenthin.llama.json.ChatResponseParser; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/RerankingModelTest.java b/src/test/java/net/ladenthin/llama/RerankingModelTest.java index 5976de11..937aa5a5 100644 --- a/src/test/java/net/ladenthin/llama/RerankingModelTest.java +++ b/src/test/java/net/ladenthin/llama/RerankingModelTest.java @@ -9,6 +9,9 @@ import java.util.List; import java.util.Map; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/ResponseJsonStructureTest.java b/src/test/java/net/ladenthin/llama/ResponseJsonStructureTest.java index aaaf24e0..c90eedd0 100644 --- a/src/test/java/net/ladenthin/llama/ResponseJsonStructureTest.java +++ b/src/test/java/net/ladenthin/llama/ResponseJsonStructureTest.java @@ -5,10 +5,17 @@ package net.ladenthin.llama; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import java.io.File; import net.ladenthin.llama.args.PoolingType; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; @@ -71,70 +78,70 @@ public static void tearDown() { public void testNonOaiCompletionHasContentField() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"content\""), "Response must contain 'content'"); + assertThat("Response must contain 'content'", result, containsString("\"content\"")); } @Test public void testNonOaiCompletionHasStopField() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"stop\""), "Response must contain 'stop'"); + assertThat("Response must contain 'stop'", result, containsString("\"stop\"")); } @Test public void testNonOaiCompletionHasStopType() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"stop_type\""), "Response must contain 'stop_type'"); + assertThat("Response must contain 'stop_type'", result, containsString("\"stop_type\"")); } @Test public void testNonOaiCompletionHasModelField() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"model\""), "Response must contain 'model'"); + assertThat("Response must contain 'model'", result, containsString("\"model\"")); } @Test public void testNonOaiCompletionHasTokensPredicted() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"tokens_predicted\""), "Response must contain 'tokens_predicted'"); + assertThat("Response must contain 'tokens_predicted'", result, containsString("\"tokens_predicted\"")); } @Test public void testNonOaiCompletionHasTokensEvaluated() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"tokens_evaluated\""), "Response must contain 'tokens_evaluated'"); + assertThat("Response must contain 'tokens_evaluated'", result, containsString("\"tokens_evaluated\"")); } @Test public void testNonOaiCompletionHasTimings() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"timings\""), "Response must contain 'timings'"); + assertThat("Response must contain 'timings'", result, containsString("\"timings\"")); } @Test public void testNonOaiCompletionHasGenerationSettings() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"generation_settings\""), "Response must contain 'generation_settings'"); + assertThat("Response must contain 'generation_settings'", result, containsString("\"generation_settings\"")); } @Test public void testNonOaiCompletionHasTokensCached() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"tokens_cached\""), "Response must contain 'tokens_cached'"); + assertThat("Response must contain 'tokens_cached'", result, containsString("\"tokens_cached\"")); } @Test public void testNonOaiCompletionHasIdSlot() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"id_slot\""), "Response must contain 'id_slot'"); + assertThat("Response must contain 'id_slot'", result, containsString("\"id_slot\"")); } // ------------------------------------------------------------------------- @@ -145,44 +152,45 @@ public void testNonOaiCompletionHasIdSlot() { public void testTimingsHasPromptN() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"prompt_n\""), "Timings must contain 'prompt_n'"); + assertThat("Timings must contain 'prompt_n'", result, containsString("\"prompt_n\"")); } @Test public void testTimingsHasPromptMs() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"prompt_ms\""), "Timings must contain 'prompt_ms'"); + assertThat("Timings must contain 'prompt_ms'", result, containsString("\"prompt_ms\"")); } @Test public void testTimingsHasPredictedN() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"predicted_n\""), "Timings must contain 'predicted_n'"); + assertThat("Timings must contain 'predicted_n'", result, containsString("\"predicted_n\"")); } @Test public void testTimingsHasPredictedMs() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"predicted_ms\""), "Timings must contain 'predicted_ms'"); + assertThat("Timings must contain 'predicted_ms'", result, containsString("\"predicted_ms\"")); } @Test public void testTimingsHasPerTokenFields() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"prompt_per_token_ms\""), "Timings must contain 'prompt_per_token_ms'"); - assertTrue(result.contains("\"predicted_per_token_ms\""), "Timings must contain 'predicted_per_token_ms'"); + assertThat("Timings must contain 'prompt_per_token_ms'", result, containsString("\"prompt_per_token_ms\"")); + assertThat( + "Timings must contain 'predicted_per_token_ms'", result, containsString("\"predicted_per_token_ms\"")); } @Test public void testTimingsHasPerSecondFields() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"prompt_per_second\""), "Timings must contain 'prompt_per_second'"); - assertTrue(result.contains("\"predicted_per_second\""), "Timings must contain 'predicted_per_second'"); + assertThat("Timings must contain 'prompt_per_second'", result, containsString("\"prompt_per_second\"")); + assertThat("Timings must contain 'predicted_per_second'", result, containsString("\"predicted_per_second\"")); } // ------------------------------------------------------------------------- @@ -194,7 +202,10 @@ public void testStopTypeLimitOnMaxTokens() { // n_predict=N_PREDICT with no stop string should result in "limit" stop_type String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"stop_type\":\"limit\""), "stop_type should be 'limit' when max tokens reached"); + assertThat( + "stop_type should be 'limit' when max tokens reached", + result, + containsString("\"stop_type\":\"limit\"")); } @Test @@ -202,11 +213,13 @@ public void testStopTypeWordOnStopString() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":50" + DETERMINISTIC + ",\"stop\":[\"return\"]}"; String result = model.handleCompletions(json); // May be "word" if stop string matched, or "limit" if n_predict reached first - assertTrue( - result.contains("\"stop_type\":\"word\"") - || result.contains("\"stop_type\":\"limit\"") - || result.contains("\"stop_type\":\"eos\""), - "stop_type should be present"); + assertThat( + "stop_type should be present", + result, + anyOf( + containsString("\"stop_type\":\"word\""), + containsString("\"stop_type\":\"limit\""), + containsString("\"stop_type\":\"eos\""))); } // ------------------------------------------------------------------------- @@ -217,67 +230,68 @@ public void testStopTypeWordOnStopString() { public void testOaiCompletionHasChoices() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"choices\""), "OAI response must contain 'choices'"); + assertThat("OAI response must contain 'choices'", result, containsString("\"choices\"")); } @Test public void testOaiCompletionHasUsage() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"usage\""), "OAI response must contain 'usage'"); + assertThat("OAI response must contain 'usage'", result, containsString("\"usage\"")); } @Test public void testOaiCompletionHasObject() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue( - result.contains("\"object\":\"text_completion\""), - "OAI response must contain 'object':'text_completion'"); + assertThat( + "OAI response must contain 'object':'text_completion'", + result, + containsString("\"object\":\"text_completion\"")); } @Test public void testOaiCompletionHasCreated() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"created\""), "OAI response must contain 'created'"); + assertThat("OAI response must contain 'created'", result, containsString("\"created\"")); } @Test public void testOaiCompletionHasModel() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"model\""), "OAI response must contain 'model'"); + assertThat("OAI response must contain 'model'", result, containsString("\"model\"")); } @Test public void testOaiCompletionHasSystemFingerprint() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"system_fingerprint\""), "OAI response must contain 'system_fingerprint'"); + assertThat("OAI response must contain 'system_fingerprint'", result, containsString("\"system_fingerprint\"")); } @Test public void testOaiCompletionHasId() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"id\""), "OAI response must contain 'id'"); + assertThat("OAI response must contain 'id'", result, containsString("\"id\"")); } @Test public void testOaiCompletionUsageFields() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"completion_tokens\""), "Usage must contain 'completion_tokens'"); - assertTrue(result.contains("\"prompt_tokens\""), "Usage must contain 'prompt_tokens'"); - assertTrue(result.contains("\"total_tokens\""), "Usage must contain 'total_tokens'"); + assertThat("Usage must contain 'completion_tokens'", result, containsString("\"completion_tokens\"")); + assertThat("Usage must contain 'prompt_tokens'", result, containsString("\"prompt_tokens\"")); + assertThat("Usage must contain 'total_tokens'", result, containsString("\"total_tokens\"")); } @Test public void testOaiCompletionChoiceHasFinishReason() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); - assertTrue(result.contains("\"finish_reason\""), "Choice must contain 'finish_reason'"); + assertThat("Choice must contain 'finish_reason'", result, containsString("\"finish_reason\"")); } @Test @@ -285,9 +299,10 @@ public void testOaiCompletionFinishReasonLength() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletionsOai(json); // With small n_predict, finish_reason should be "length" - assertTrue( - result.contains("\"finish_reason\":\"length\"") || result.contains("\"finish_reason\":\"stop\""), - "finish_reason should be 'length' or 'stop'"); + assertThat( + "finish_reason should be 'length' or 'stop'", + result, + anyOf(containsString("\"finish_reason\":\"length\""), containsString("\"finish_reason\":\"stop\""))); } // ------------------------------------------------------------------------- @@ -301,7 +316,7 @@ public void testOaiChatCompletionHasChoices() { .withNPredict(N_PREDICT) .withTemperature(0); String result = model.chatComplete(params); - assertTrue(result.contains("\"choices\""), "Chat response must contain 'choices'"); + assertThat("Chat response must contain 'choices'", result, containsString("\"choices\"")); } @Test @@ -311,7 +326,7 @@ public void testOaiChatCompletionHasUsage() { .withNPredict(N_PREDICT) .withTemperature(0); String result = model.chatComplete(params); - assertTrue(result.contains("\"usage\""), "Chat response must contain 'usage'"); + assertThat("Chat response must contain 'usage'", result, containsString("\"usage\"")); } @Test @@ -321,7 +336,7 @@ public void testOaiChatCompletionHasMessageObject() { .withNPredict(N_PREDICT) .withTemperature(0); String result = model.chatComplete(params); - assertTrue(result.contains("\"message\""), "Chat response must contain 'message'"); + assertThat("Chat response must contain 'message'", result, containsString("\"message\"")); } @Test @@ -331,8 +346,10 @@ public void testOaiChatCompletionObjectType() { .withNPredict(N_PREDICT) .withTemperature(0); String result = model.chatComplete(params); - assertTrue( - result.contains("\"object\":\"chat.completion\""), "Chat response 'object' must be 'chat.completion'"); + assertThat( + "Chat response 'object' must be 'chat.completion'", + result, + containsString("\"object\":\"chat.completion\"")); } @Test @@ -342,7 +359,7 @@ public void testOaiChatCompletionMessageHasRole() { .withNPredict(N_PREDICT) .withTemperature(0); String result = model.chatComplete(params); - assertTrue(result.contains("\"role\":\"assistant\""), "Message must contain 'role':'assistant'"); + assertThat("Message must contain 'role':'assistant'", result, containsString("\"role\":\"assistant\"")); } // ------------------------------------------------------------------------- @@ -353,18 +370,19 @@ public void testOaiChatCompletionMessageHasRole() { public void testEmbeddingOaiResponseStructure() { String json = "{\"input\":\"hello world\"}"; String result = model.handleEmbeddings(json, true); - assertTrue(result.contains("\"data\""), "OAI embedding must contain 'data'"); - assertTrue(result.contains("\"object\":\"embedding\""), "OAI embedding must contain 'object':'embedding'"); - assertTrue(result.contains("\"embedding\""), "OAI embedding must contain 'embedding' array"); - assertTrue(result.contains("\"usage\""), "OAI embedding must contain 'usage'"); + assertThat("OAI embedding must contain 'data'", result, containsString("\"data\"")); + assertThat( + "OAI embedding must contain 'object':'embedding'", result, containsString("\"object\":\"embedding\"")); + assertThat("OAI embedding must contain 'embedding' array", result, containsString("\"embedding\"")); + assertThat("OAI embedding must contain 'usage'", result, containsString("\"usage\"")); } @Test public void testEmbeddingNonOaiResponseStructure() { String json = "{\"input\":\"hello world\"}"; String result = model.handleEmbeddings(json, false); - assertTrue(result.contains("\"embedding\""), "Non-OAI embedding must contain 'embedding'"); - assertTrue(result.contains("\"index\""), "Non-OAI embedding must contain 'index'"); + assertThat("Non-OAI embedding must contain 'embedding'", result, containsString("\"embedding\"")); + assertThat("Non-OAI embedding must contain 'index'", result, containsString("\"index\"")); } // ------------------------------------------------------------------------- @@ -374,23 +392,23 @@ public void testEmbeddingNonOaiResponseStructure() { @Test public void testTokenizeResponseStructure() { String result = model.handleTokenize("hello world", false, false); - assertNotNull(result); - assertTrue(result.contains("\"tokens\""), "Tokenize response must contain 'tokens'"); + assertThat(result, is(notNullValue())); + assertThat("Tokenize response must contain 'tokens'", result, containsString("\"tokens\"")); } @Test public void testTokenizeWithPiecesResponseStructure() { String result = model.handleTokenize("hello world", false, true); - assertNotNull(result); - assertTrue(result.contains("\"tokens\""), "Tokenize with pieces must contain 'tokens'"); + assertThat(result, is(notNullValue())); + assertThat("Tokenize with pieces must contain 'tokens'", result, containsString("\"tokens\"")); } @Test public void testDetokenizeResponseStructure() { int[] tokens = model.encode("hello world"); String result = model.handleDetokenize(tokens); - assertNotNull(result); - assertTrue(result.contains("\"content\""), "Detokenize response must contain 'content'"); + assertThat(result, is(notNullValue())); + assertThat("Detokenize response must contain 'content'", result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -401,9 +419,10 @@ public void testDetokenizeResponseStructure() { public void testCompletionProbabilitiesStructure() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"n_probs\":3}"; String result = model.handleCompletions(json); - assertTrue( - result.contains("\"completion_probabilities\""), - "Response with n_probs should contain 'completion_probabilities'"); + assertThat( + "Response with n_probs should contain 'completion_probabilities'", + result, + containsString("\"completion_probabilities\"")); } // ------------------------------------------------------------------------- @@ -415,16 +434,16 @@ public void testGenerationSettingsContainsSamplingParams() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); // generation_settings should echo back the sampling parameters - assertTrue(result.contains("\"temperature\""), "generation_settings should contain 'temperature'"); - assertTrue(result.contains("\"top_k\""), "generation_settings should contain 'top_k'"); - assertTrue(result.contains("\"top_p\""), "generation_settings should contain 'top_p'"); - assertTrue(result.contains("\"min_p\""), "generation_settings should contain 'min_p'"); + assertThat("generation_settings should contain 'temperature'", result, containsString("\"temperature\"")); + assertThat("generation_settings should contain 'top_k'", result, containsString("\"top_k\"")); + assertThat("generation_settings should contain 'top_p'", result, containsString("\"top_p\"")); + assertThat("generation_settings should contain 'min_p'", result, containsString("\"min_p\"")); } @Test public void testGenerationSettingsContainsSamplers() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + "}"; String result = model.handleCompletions(json); - assertTrue(result.contains("\"samplers\""), "generation_settings should contain 'samplers'"); + assertThat("generation_settings should contain 'samplers'", result, containsString("\"samplers\"")); } } diff --git a/src/test/java/net/ladenthin/llama/SessionConcurrencyTest.java b/src/test/java/net/ladenthin/llama/SessionConcurrencyTest.java index edac3777..28f006d1 100644 --- a/src/test/java/net/ladenthin/llama/SessionConcurrencyTest.java +++ b/src/test/java/net/ladenthin/llama/SessionConcurrencyTest.java @@ -17,6 +17,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.LlamaOutput; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/net/ladenthin/llama/TestConstants.java b/src/test/java/net/ladenthin/llama/TestConstants.java index 6d4ed68d..566c6d56 100644 --- a/src/test/java/net/ladenthin/llama/TestConstants.java +++ b/src/test/java/net/ladenthin/llama/TestConstants.java @@ -5,21 +5,23 @@ package net.ladenthin.llama; -class TestConstants { +import net.ladenthin.llama.loader.LlamaSystemProperties; + +public class TestConstants { /** System property to override GPU layers used in tests. */ - static final String PROP_TEST_NGL = LlamaSystemProperties.PREFIX + ".test.ngl"; + public static final String PROP_TEST_NGL = LlamaSystemProperties.PREFIX + ".test.ngl"; - static final int DEFAULT_TEST_NGL = 43; + public static final int DEFAULT_TEST_NGL = 43; /** Path to the main text generation model used in tests. */ - static final String MODEL_PATH = "models/codellama-7b.Q2_K.gguf"; + public static final String MODEL_PATH = "models/codellama-7b.Q2_K.gguf"; /** Path to the draft model used for speculative decoding tests. */ - static final String DRAFT_MODEL_PATH = "models/AMD-Llama-135m-code.Q2_K.gguf"; + public static final String DRAFT_MODEL_PATH = "models/AMD-Llama-135m-code.Q2_K.gguf"; /** Path to the Qwen3 thinking model used for reasoning budget tests. */ - static final String REASONING_MODEL_PATH = "models/Qwen3-0.6B-Q4_K_M.gguf"; + public static final String REASONING_MODEL_PATH = "models/Qwen3-0.6B-Q4_K_M.gguf"; /** * System property holding a path to a Nomic embedding model @@ -28,10 +30,10 @@ class TestConstants { * issue #98 (BERT-encoder result_output assertion) stays resolved. * When the property is unset the test self-skips. */ - static final String PROP_NOMIC_MODEL_PATH = LlamaSystemProperties.PREFIX + ".nomic.path"; + public static final String PROP_NOMIC_MODEL_PATH = LlamaSystemProperties.PREFIX + ".nomic.path"; /** Expected embedding dimension of nomic-embed-text-v1.5 (hidden size = 768). */ - static final int NOMIC_EMBED_DIM = 768; + public static final int NOMIC_EMBED_DIM = 768; /** * System property holding a path to a vision-capable model GGUF. Consumed by @@ -39,10 +41,10 @@ class TestConstants { * SmolVLM-500M Q8_0 GGUF; the test self-skips when the property is unset or * the file is missing. */ - static final String PROP_VISION_MODEL_PATH = LlamaSystemProperties.PREFIX + ".vision.model"; + public static final String PROP_VISION_MODEL_PATH = LlamaSystemProperties.PREFIX + ".vision.model"; /** System property holding a path to the matching mmproj GGUF for the vision model. */ - static final String PROP_VISION_MMPROJ_PATH = LlamaSystemProperties.PREFIX + ".vision.mmproj"; + public static final String PROP_VISION_MMPROJ_PATH = LlamaSystemProperties.PREFIX + ".vision.mmproj"; /** * System property holding a path to an image used as the visual prompt in @@ -52,12 +54,12 @@ class TestConstants { * works; the matching extension drives MIME detection in * {@code ContentPart.imageFile(Path)}. */ - static final String PROP_VISION_IMAGE_PATH = LlamaSystemProperties.PREFIX + ".vision.image"; + public static final String PROP_VISION_IMAGE_PATH = LlamaSystemProperties.PREFIX + ".vision.image"; /** * Path used by {@code MultimodalIntegrationTest} when * {@link #PROP_VISION_IMAGE_PATH} is unset. Points at the committed test * resource so the test needs no network access for the visual prompt. */ - static final String DEFAULT_VISION_IMAGE_PATH = "src/test/resources/images/test-image.jpg"; + public static final String DEFAULT_VISION_IMAGE_PATH = "src/test/resources/images/test-image.jpg"; } diff --git a/src/test/java/net/ladenthin/llama/args/ContinuationModeTest.java b/src/test/java/net/ladenthin/llama/args/ContinuationModeTest.java new file mode 100644 index 00000000..08ecf11a --- /dev/null +++ b/src/test/java/net/ladenthin/llama/args/ContinuationModeTest.java @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.args; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +import org.junit.jupiter.api.Test; + +public class ContinuationModeTest { + + @Test + public void getValueReturnsWireFormatStrings() { + // Pinning the exact wire strings kills the empty-string return mutant on getValue(). + assertThat(ContinuationMode.REASONING_CONTENT.getValue(), is("reasoning_content")); + assertThat(ContinuationMode.CONTENT.getValue(), is("content")); + } +} diff --git a/src/test/java/net/ladenthin/llama/benchmark/InferenceParametersBenchmark.java b/src/test/java/net/ladenthin/llama/benchmark/InferenceParametersBenchmark.java index ccce4444..1f49745c 100644 --- a/src/test/java/net/ladenthin/llama/benchmark/InferenceParametersBenchmark.java +++ b/src/test/java/net/ladenthin/llama/benchmark/InferenceParametersBenchmark.java @@ -4,7 +4,7 @@ package net.ladenthin.llama.benchmark; import java.util.concurrent.TimeUnit; -import net.ladenthin.llama.InferenceParameters; +import net.ladenthin.llama.parameters.InferenceParameters; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; diff --git a/src/test/java/net/ladenthin/llama/CancellationTokenLincheckTest.java b/src/test/java/net/ladenthin/llama/callback/CancellationTokenLincheckTest.java similarity index 97% rename from src/test/java/net/ladenthin/llama/CancellationTokenLincheckTest.java rename to src/test/java/net/ladenthin/llama/callback/CancellationTokenLincheckTest.java index 4119e832..4a53bb30 100644 --- a/src/test/java/net/ladenthin/llama/CancellationTokenLincheckTest.java +++ b/src/test/java/net/ladenthin/llama/callback/CancellationTokenLincheckTest.java @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2026 Bernard Ladenthin // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; import org.jetbrains.kotlinx.lincheck.LinChecker; import org.jetbrains.lincheck.datastructures.ModelCheckingOptions; diff --git a/src/test/java/net/ladenthin/llama/CancellationTokenTest.java b/src/test/java/net/ladenthin/llama/callback/CancellationTokenTest.java similarity index 96% rename from src/test/java/net/ladenthin/llama/CancellationTokenTest.java rename to src/test/java/net/ladenthin/llama/callback/CancellationTokenTest.java index 49ebfd6a..1469368f 100644 --- a/src/test/java/net/ladenthin/llama/CancellationTokenTest.java +++ b/src/test/java/net/ladenthin/llama/callback/CancellationTokenTest.java @@ -2,11 +2,12 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated( diff --git a/src/test/java/net/ladenthin/llama/LoadProgressCallbackTest.java b/src/test/java/net/ladenthin/llama/callback/LoadProgressCallbackTest.java similarity index 93% rename from src/test/java/net/ladenthin/llama/LoadProgressCallbackTest.java rename to src/test/java/net/ladenthin/llama/callback/LoadProgressCallbackTest.java index b3ec4b3f..99e2bd29 100644 --- a/src/test/java/net/ladenthin/llama/LoadProgressCallbackTest.java +++ b/src/test/java/net/ladenthin/llama/callback/LoadProgressCallbackTest.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.callback; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -11,6 +11,11 @@ import java.util.ArrayList; import java.util.List; +import net.ladenthin.llama.ClaudeGenerated; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.TestConstants; +import net.ladenthin.llama.exception.LlamaException; +import net.ladenthin.llama.parameters.ModelParameters; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/LlamaExceptionTest.java b/src/test/java/net/ladenthin/llama/exception/LlamaExceptionTest.java similarity index 65% rename from src/test/java/net/ladenthin/llama/LlamaExceptionTest.java rename to src/test/java/net/ladenthin/llama/exception/LlamaExceptionTest.java index 7386b5de..5df7fb5a 100644 --- a/src/test/java/net/ladenthin/llama/LlamaExceptionTest.java +++ b/src/test/java/net/ladenthin/llama/exception/LlamaExceptionTest.java @@ -3,10 +3,14 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.exception; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -18,25 +22,25 @@ public class LlamaExceptionTest { @Test public void testMessageIsPreserved() { LlamaException ex = new LlamaException("something went wrong"); - assertEquals("something went wrong", ex.getMessage()); + assertThat(ex.getMessage(), is("something went wrong")); } @Test public void testIsRuntimeException() { LlamaException ex = new LlamaException("error"); - assertTrue(ex instanceof RuntimeException); + assertThat(ex, is(instanceOf(RuntimeException.class))); } @Test public void testEmptyMessage() { LlamaException ex = new LlamaException(""); - assertEquals("", ex.getMessage()); + assertThat(ex.getMessage(), is("")); } @Test public void testNullMessage() { LlamaException ex = new LlamaException(null); - assertNull(ex.getMessage()); + assertThat(ex.getMessage(), is(nullValue())); } @Test @@ -45,9 +49,9 @@ public void testCanBeThrown() { try { throw new LlamaException("thrown"); } catch (LlamaException e) { - assertEquals("thrown", e.getMessage()); + assertThat(e.getMessage(), is("thrown")); caught = true; } - assertTrue(caught, "Expected LlamaException to be thrown"); + assertThat("Expected LlamaException to be thrown", caught, is(true)); } } diff --git a/src/test/java/net/ladenthin/llama/exception/ModelUnavailableExceptionTest.java b/src/test/java/net/ladenthin/llama/exception/ModelUnavailableExceptionTest.java new file mode 100644 index 00000000..2969ab01 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/exception/ModelUnavailableExceptionTest.java @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// SPDX-FileCopyrightText: 2023-2025 Konstantin Herud +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.exception; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; + +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +@ClaudeGenerated( + purpose = "Verify the typed-exception unification shape of ModelUnavailableException: the " + + "(message) and (message, cause) constructor matrix, that it is a typed subclass of " + + "LlamaException (so callers can catch it by the common base), and that it can be " + + "thrown and caught.") +public class ModelUnavailableExceptionTest { + + @Test + public void testMessageIsPreserved() { + ModelUnavailableException ex = new ModelUnavailableException("model file missing"); + assertThat(ex.getMessage(), is("model file missing")); + } + + @Test + public void testMessageAndCausePreserved() { + Throwable cause = new IllegalStateException("skip-download set"); + ModelUnavailableException ex = new ModelUnavailableException("model file missing", cause); + assertThat(ex.getMessage(), is("model file missing")); + assertThat(ex.getCause(), is(sameInstance(cause))); + } + + @Test + public void testIsLlamaException() { + ModelUnavailableException ex = new ModelUnavailableException("error"); + assertThat(ex, is(instanceOf(LlamaException.class))); + } + + @Test + public void testIsRuntimeException() { + ModelUnavailableException ex = new ModelUnavailableException("error"); + assertThat(ex, is(instanceOf(RuntimeException.class))); + } + + @Test + public void testNullMessage() { + ModelUnavailableException ex = new ModelUnavailableException(null); + assertThat(ex.getMessage(), is(nullValue())); + } + + @Test + public void testCanBeCaughtAsLlamaException() { + boolean caught = false; + try { + throw new ModelUnavailableException("thrown"); + } catch (LlamaException e) { + assertThat(e.getMessage(), is("thrown")); + caught = true; + } + assertThat("Expected ModelUnavailableException to be catchable as LlamaException", caught, is(true)); + } +} diff --git a/src/test/java/net/ladenthin/llama/jcstress/CancellationTokenRace.java b/src/test/java/net/ladenthin/llama/jcstress/CancellationTokenRace.java index d5815489..70afe902 100644 --- a/src/test/java/net/ladenthin/llama/jcstress/CancellationTokenRace.java +++ b/src/test/java/net/ladenthin/llama/jcstress/CancellationTokenRace.java @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT package net.ladenthin.llama.jcstress; -import net.ladenthin.llama.CancellationToken; +import net.ladenthin.llama.callback.CancellationToken; import org.openjdk.jcstress.annotations.Actor; import org.openjdk.jcstress.annotations.Arbiter; import org.openjdk.jcstress.annotations.Description; diff --git a/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java b/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java index 57c640df..0a7a875c 100644 --- a/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java +++ b/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java @@ -9,6 +9,12 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.List; +import net.ladenthin.llama.value.ChatChoice; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.ToolCall; +import nl.altindag.log.LogCaptor; import org.junit.jupiter.api.Test; /** @@ -211,4 +217,136 @@ public void testCountChoices_absent() throws Exception { JsonNode node = MAPPER.readTree("{\"id\":\"x\"}"); assertEquals(0, parser.countChoices(node)); } + + // ------------------------------------------------------------------ + // parseResponse(String) — full typed parse + // ------------------------------------------------------------------ + + @Test + public void testParseResponse_fullResponse() { + String json = "{\"id\":\"chatcmpl-abc\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"Hi there\"}," + + "\"finish_reason\":\"stop\"}]," + + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}}"; + ChatResponse r = parser.parseResponse(json); + + assertEquals("chatcmpl-abc", r.getId()); + assertEquals(1, r.getChoices().size()); + ChatChoice c = r.getChoices().get(0); + assertEquals(0, c.getIndex()); + assertEquals("assistant", c.getMessage().getRole()); + assertEquals("Hi there", c.getMessage().getContent()); + assertEquals("stop", c.getFinishReason()); + assertEquals(7L, r.getUsage().getPromptTokens()); + assertEquals(3L, r.getUsage().getCompletionTokens()); + assertEquals(json, r.getRawJson()); + } + + @Test + public void testParseResponse_multipleChoicesPreserveIndexAndOrder() { + String json = "{\"id\":\"x\",\"choices\":[" + + "{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"first\"},\"finish_reason\":\"stop\"}," + + "{\"index\":1,\"message\":{\"role\":\"assistant\",\"content\":\"second\"},\"finish_reason\":\"length\"}" + + "]}"; + ChatResponse r = parser.parseResponse(json); + + assertEquals(2, r.getChoices().size()); + assertEquals(0, r.getChoices().get(0).getIndex()); + assertEquals("first", r.getChoices().get(0).getMessage().getContent()); + assertEquals(1, r.getChoices().get(1).getIndex()); + assertEquals("second", r.getChoices().get(1).getMessage().getContent()); + assertEquals("length", r.getChoices().get(1).getFinishReason()); + } + + @Test + public void testParseResponse_toolCallsWithStringArguments() { + String json = "{\"id\":\"x\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"\"," + + "\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\"," + + "\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"city\\\":\\\"NYC\\\"}\"}}]}," + + "\"finish_reason\":\"tool_calls\"}]}"; + ChatResponse r = parser.parseResponse(json); + + ChatMessage m = r.getChoices().get(0).getMessage(); + List tcs = m.getToolCalls(); + assertEquals(1, tcs.size()); + assertEquals("call_1", tcs.get(0).getId()); + assertEquals("get_weather", tcs.get(0).getName()); + // arguments is a JSON string in the wire form → unwrapped verbatim, not re-quoted. + assertEquals("{\"city\":\"NYC\"}", tcs.get(0).getArgumentsJson()); + } + + @Test + public void testParseResponse_toolCallsWithObjectArguments() { + // Some shapes emit arguments as a nested object rather than a string; + // the parser serialises it back to its JSON text. + String json = "{\"id\":\"x\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"\"," + + "\"tool_calls\":[{\"id\":\"call_2\"," + + "\"function\":{\"name\":\"f\",\"arguments\":{\"a\":1}}}]}}]}"; + ChatResponse r = parser.parseResponse(json); + + ToolCall tc = r.getChoices().get(0).getMessage().getToolCalls().get(0); + assertEquals("{\"a\":1}", tc.getArgumentsJson()); + } + + @Test + public void testParseResponse_noToolCalls_plainAssistantMessage() { + String json = "{\"id\":\"x\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"plain\"}}]}"; + ChatResponse r = parser.parseResponse(json); + + ChatMessage m = r.getChoices().get(0).getMessage(); + assertEquals("plain", m.getContent()); + assertTrue(m.getToolCalls().isEmpty(), "plain message carries no tool calls"); + } + + @Test + public void testParseResponse_emptyChoicesArray_returnsMutableEmptyList() { + ChatResponse r = parser.parseResponse("{\"id\":\"x\",\"choices\":[]}"); + assertTrue(r.getChoices().isEmpty()); + // The choices list is exposed by reference and documented as mutable — + // adding to it must not throw (kills the immutable-emptyList() mutant). + r.getChoices().add(new ChatChoice(0, new ChatMessage("assistant", "added"), "stop")); + assertEquals(1, r.getChoices().size()); + } + + @Test + public void testParseResponse_absentChoices_returnsEmptyList() { + ChatResponse r = parser.parseResponse("{\"id\":\"x\"}"); + assertEquals("x", r.getId()); + assertTrue(r.getChoices().isEmpty()); + } + + @Test + public void testParseResponse_malformedJson_returnsEmptyResponsePreservingRawJson() { + String bad = "{not valid json"; + ChatResponse r = parser.parseResponse(bad); + assertEquals("", r.getId()); + assertTrue(r.getChoices().isEmpty()); + assertEquals(0L, r.getUsage().getPromptTokens()); + assertEquals(0L, r.getUsage().getCompletionTokens()); + // Raw JSON is preserved verbatim even on parse failure (escape hatch). + assertEquals(bad, r.getRawJson()); + } + + /** + * Parsing a response carrying real timings must emit exactly one per-run + * timing line through the dedicated SLF4J logger — pins the {@code + * TimingsLogger.log(...)} side-effect so its removal (VoidMethodCall mutant) + * is detected. + */ + @Test + public void testParseResponse_emitsTimingLine() { + String json = "{\"id\":\"x\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"ok\"}}]," + + "\"timings\":{\"prompt_n\":7,\"prompt_ms\":10.0,\"prompt_per_second\":700.0," + + "\"predicted_n\":3,\"predicted_ms\":20.0,\"predicted_per_second\":150.0}}"; + + try (LogCaptor captor = LogCaptor.forName(TimingsLogger.LOGGER_NAME)) { + ChatResponse r = parser.parseResponse(json); + assertEquals(7, r.getTimings().getPromptN()); + assertEquals(1, captor.getInfoLogs().size(), "exactly one timing line must be emitted"); + } + } } diff --git a/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java b/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java index fcdc5f1b..1d7e7149 100644 --- a/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java +++ b/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java @@ -9,9 +9,14 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Collections; +import java.util.List; import java.util.Map; -import net.ladenthin.llama.LlamaOutput; -import net.ladenthin.llama.StopReason; +import net.ladenthin.llama.value.CompletionResult; +import net.ladenthin.llama.value.LlamaOutput; +import net.ladenthin.llama.value.StopReason; +import net.ladenthin.llama.value.TokenLogprob; +import nl.altindag.log.LogCaptor; import org.junit.jupiter.api.Test; /** @@ -202,4 +207,141 @@ public void testParseProbabilities_topProbs_notIncluded() throws Exception { assertTrue(probs.containsKey("A"), "only outer token 'A' should be present"); assertFalse(probs.containsKey("B"), "inner top_probs token 'B' must not appear"); } + + // ------------------------------------------------------------------ + // parseLogprobs — typed per-token entries + // ------------------------------------------------------------------ + + @Test + public void testParseLogprobs_postSamplingWithNestedTopProbs() throws Exception { + String json = "{\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"id\":15043,\"prob\":0.82," + + "\"top_probs\":[{\"token\":\"Hi\",\"id\":9932,\"prob\":0.1}]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + List lp = parser.parseLogprobs(node); + + assertEquals(1, lp.size()); + TokenLogprob e = lp.get(0); + assertEquals("Hello", e.getToken()); + assertEquals(15043, e.getTokenId()); + assertEquals(0.82f, e.getLogprob(), 0.001f); + // Nested alternatives are parsed recursively from top_probs. + assertEquals(1, e.getTopLogprobs().size()); + assertEquals("Hi", e.getTopLogprobs().get(0).getToken()); + assertEquals(9932, e.getTopLogprobs().get(0).getTokenId()); + assertEquals(0.1f, e.getTopLogprobs().get(0).getLogprob(), 0.001f); + } + + @Test + public void testParseLogprobs_preSamplingUsesLogprobAndTopLogprobs() throws Exception { + // No "prob"/"top_probs" — the parser falls back to "logprob"/"top_logprobs". + String json = "{\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"id\":15043,\"logprob\":-0.2," + + "\"top_logprobs\":[{\"token\":\"Hi\",\"id\":9932,\"logprob\":-2.3}]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + List lp = parser.parseLogprobs(node); + + assertEquals(1, lp.size()); + TokenLogprob e = lp.get(0); + assertEquals(-0.2f, e.getLogprob(), 0.001f); + assertEquals(1, e.getTopLogprobs().size()); + assertEquals("Hi", e.getTopLogprobs().get(0).getToken()); + assertEquals(-2.3f, e.getTopLogprobs().get(0).getLogprob(), 0.001f); + } + + @Test + public void testParseLogprobs_entryWithoutAlternatives_hasEmptyTopLogprobs() throws Exception { + String json = "{\"completion_probabilities\":[" + "{\"token\":\"x\",\"id\":1,\"prob\":0.5}" + "]}"; + JsonNode node = MAPPER.readTree(json); + List lp = parser.parseLogprobs(node); + + assertEquals(1, lp.size()); + assertEquals(1, lp.get(0).getTokenId()); + assertTrue(lp.get(0).getTopLogprobs().isEmpty(), "no top_probs/top_logprobs → empty alternatives"); + } + + @Test + public void testParseLogprobs_missingId_defaultsToMinusOne() throws Exception { + String json = "{\"completion_probabilities\":[" + "{\"token\":\"x\",\"prob\":0.5}" + "]}"; + JsonNode node = MAPPER.readTree(json); + List lp = parser.parseLogprobs(node); + assertEquals(-1, lp.get(0).getTokenId()); + } + + @Test + public void testParseLogprobs_absentArray_returnsMutableEmptyList() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"hi\",\"stop\":true}"); + List lp = parser.parseLogprobs(node); + assertTrue(lp.isEmpty()); + // Documented to be a mutable empty list — adding must not throw + // (kills the immutable-emptyList() return mutant). + lp.add(new TokenLogprob("x", 1, 0.5f, Collections.emptyList())); + assertEquals(1, lp.size()); + } + + @Test + public void testParseLogprobs_emptyArray_returnsEmptyList() throws Exception { + JsonNode node = MAPPER.readTree("{\"completion_probabilities\":[]}"); + assertTrue(parser.parseLogprobs(node).isEmpty()); + } + + // ------------------------------------------------------------------ + // parseCompletionResult(String) — non-streaming typed result + // ------------------------------------------------------------------ + + @Test + public void testParseCompletionResult_fullResult() { + String json = "{\"content\":\"final answer\"," + + "\"tokens_evaluated\":11,\"tokens_predicted\":4," + + "\"stop_type\":\"eos\"," + + "\"completion_probabilities\":[{\"token\":\"final\",\"id\":1,\"prob\":0.7}]}"; + CompletionResult r = parser.parseCompletionResult(json); + + assertEquals("final answer", r.getText()); + assertEquals(11L, r.getUsage().getPromptTokens()); + assertEquals(4L, r.getUsage().getCompletionTokens()); + assertEquals(StopReason.EOS, r.getStopReason()); + assertEquals(1, r.getLogprobs().size()); + assertEquals("final", r.getLogprobs().get(0).getToken()); + assertEquals(json, r.getRawJson()); + } + + @Test + public void testParseCompletionResult_limitStopType() { + String json = "{\"content\":\"trunc\",\"tokens_evaluated\":2,\"tokens_predicted\":8,\"stop_type\":\"limit\"}"; + CompletionResult r = parser.parseCompletionResult(json); + assertEquals(StopReason.MAX_TOKENS, r.getStopReason()); + assertTrue(r.getLogprobs().isEmpty()); + } + + @Test + public void testParseCompletionResult_malformedJson_returnsEmptyResultPreservingRawJson() { + String bad = "{not valid json"; + CompletionResult r = parser.parseCompletionResult(bad); + assertEquals("", r.getText()); + assertEquals(0L, r.getUsage().getPromptTokens()); + assertEquals(0L, r.getUsage().getCompletionTokens()); + assertEquals(StopReason.NONE, r.getStopReason()); + assertTrue(r.getLogprobs().isEmpty()); + assertEquals(bad, r.getRawJson()); + } + + /** + * Parsing a completion result carrying real timings must emit exactly one + * per-run timing line — pins the {@code TimingsLogger.log(...)} side-effect. + */ + @Test + public void testParseCompletionResult_emitsTimingLine() { + String json = "{\"content\":\"done\",\"tokens_evaluated\":7,\"tokens_predicted\":3,\"stop_type\":\"eos\"," + + "\"timings\":{\"prompt_n\":7,\"prompt_ms\":10.0,\"prompt_per_second\":700.0," + + "\"predicted_n\":3,\"predicted_ms\":20.0,\"predicted_per_second\":150.0}}"; + + try (LogCaptor captor = LogCaptor.forName(TimingsLogger.LOGGER_NAME)) { + CompletionResult r = parser.parseCompletionResult(json); + assertEquals(7, r.getTimings().getPromptN()); + assertEquals(1, captor.getInfoLogs().size(), "exactly one timing line must be emitted"); + } + } } diff --git a/src/test/java/net/ladenthin/llama/json/ParameterJsonSerializerTest.java b/src/test/java/net/ladenthin/llama/json/ParameterJsonSerializerTest.java index 33b95a87..6ed085cd 100644 --- a/src/test/java/net/ladenthin/llama/json/ParameterJsonSerializerTest.java +++ b/src/test/java/net/ladenthin/llama/json/ParameterJsonSerializerTest.java @@ -5,7 +5,10 @@ package net.ladenthin.llama.json; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; @@ -15,8 +18,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import net.ladenthin.llama.Pair; import net.ladenthin.llama.args.Sampler; +import net.ladenthin.llama.parameters.ParameterJsonSerializer; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.Test; /** @@ -33,42 +37,42 @@ public class ParameterJsonSerializerTest { @Test public void testToJsonString_simple() { - assertEquals("\"hello\"", serializer.toJsonString("hello")); + assertThat(serializer.toJsonString("hello"), is("\"hello\"")); } @Test public void testToJsonString_null() { - assertEquals("null", serializer.toJsonString(null)); + assertThat(serializer.toJsonString(null), is("null")); } @Test public void testToJsonString_emptyString() { - assertEquals("\"\"", serializer.toJsonString("")); + assertThat(serializer.toJsonString(""), is("\"\"")); } @Test public void testToJsonString_newline() { - assertEquals("\"line1\\nline2\"", serializer.toJsonString("line1\nline2")); + assertThat(serializer.toJsonString("line1\nline2"), is("\"line1\\nline2\"")); } @Test public void testToJsonString_tab() { - assertEquals("\"a\\tb\"", serializer.toJsonString("a\tb")); + assertThat(serializer.toJsonString("a\tb"), is("\"a\\tb\"")); } @Test public void testToJsonString_quote() { - assertEquals("\"say \\\"hi\\\"\"", serializer.toJsonString("say \"hi\"")); + assertThat(serializer.toJsonString("say \"hi\""), is("\"say \\\"hi\\\"\"")); } @Test public void testToJsonString_backslash() { - assertEquals("\"path\\\\file\"", serializer.toJsonString("path\\file")); + assertThat(serializer.toJsonString("path\\file"), is("\"path\\\\file\"")); } @Test public void testToJsonString_unicode() { - assertEquals("\"café\"", serializer.toJsonString("café")); + assertThat(serializer.toJsonString("café"), is("\"café\"")); } // ------------------------------------------------------------------ @@ -79,11 +83,11 @@ public void testToJsonString_unicode() { public void testBuildMessages_withSystemMessage() { List> msgs = Collections.singletonList(new Pair<>("user", "Hello")); ArrayNode arr = serializer.buildMessages("You are helpful.", msgs); - assertEquals(2, arr.size()); - assertEquals("system", arr.get(0).path("role").asText()); - assertEquals("You are helpful.", arr.get(0).path("content").asText()); - assertEquals("user", arr.get(1).path("role").asText()); - assertEquals("Hello", arr.get(1).path("content").asText()); + assertThat(arr.size(), is(2)); + assertThat(arr.get(0).path("role").asText(), is("system")); + assertThat(arr.get(0).path("content").asText(), is("You are helpful.")); + assertThat(arr.get(1).path("role").asText(), is("user")); + assertThat(arr.get(1).path("content").asText(), is("Hello")); } @Test @@ -91,24 +95,24 @@ public void testBuildMessages_withoutSystemMessage() { List> msgs = Arrays.asList(new Pair<>("user", "Hi"), new Pair<>("assistant", "Hello there")); ArrayNode arr = serializer.buildMessages(null, msgs); - assertEquals(2, arr.size()); - assertEquals("user", arr.get(0).path("role").asText()); - assertEquals("assistant", arr.get(1).path("role").asText()); + assertThat(arr.size(), is(2)); + assertThat(arr.get(0).path("role").asText(), is("user")); + assertThat(arr.get(1).path("role").asText(), is("assistant")); } @Test public void testBuildMessages_emptySystemMessage_skipped() { List> msgs = Collections.singletonList(new Pair<>("user", "Hi")); ArrayNode arr = serializer.buildMessages("", msgs); - assertEquals(1, arr.size()); - assertEquals("user", arr.get(0).path("role").asText()); + assertThat(arr.size(), is(1)); + assertThat(arr.get(0).path("role").asText(), is("user")); } @Test public void testBuildMessages_specialCharsInContent() { List> msgs = Collections.singletonList(new Pair<>("user", "line1\nline2\t\"quoted\"")); ArrayNode arr = serializer.buildMessages(null, msgs); - assertEquals("line1\nline2\t\"quoted\"", arr.get(0).path("content").asText()); + assertThat(arr.get(0).path("content").asText(), is("line1\nline2\t\"quoted\"")); } @Test @@ -123,10 +127,10 @@ public void testBuildMessages_roundtripsAsJson() throws Exception { ArrayNode arr = serializer.buildMessages("Sys", msgs); String json = arr.toString(); JsonNode parsed = serializer.OBJECT_MAPPER.readTree(json); - assertEquals("system", parsed.get(0).path("role").asText()); - assertEquals("Sys", parsed.get(0).path("content").asText()); - assertEquals("user", parsed.get(1).path("role").asText()); - assertEquals("Hello", parsed.get(1).path("content").asText()); + assertThat(parsed.get(0).path("role").asText(), is("system")); + assertThat(parsed.get(0).path("content").asText(), is("Sys")); + assertThat(parsed.get(1).path("role").asText(), is("user")); + assertThat(parsed.get(1).path("content").asText(), is("Hello")); } // ------------------------------------------------------------------ @@ -136,31 +140,31 @@ public void testBuildMessages_roundtripsAsJson() throws Exception { @Test public void testBuildStopStrings_single() { ArrayNode arr = serializer.buildStopStrings("<|endoftext|>"); - assertEquals(1, arr.size()); - assertEquals("<|endoftext|>", arr.get(0).asText()); + assertThat(arr.size(), is(1)); + assertThat(arr.get(0).asText(), is("<|endoftext|>")); } @Test public void testBuildStopStrings_multiple() { ArrayNode arr = serializer.buildStopStrings("stop1", "stop2", "stop3"); - assertEquals(3, arr.size()); - assertEquals("stop1", arr.get(0).asText()); - assertEquals("stop3", arr.get(2).asText()); + assertThat(arr.size(), is(3)); + assertThat(arr.get(0).asText(), is("stop1")); + assertThat(arr.get(2).asText(), is("stop3")); } @Test public void testBuildStopStrings_withSpecialChars() { ArrayNode arr = serializer.buildStopStrings("line\nnewline", "tab\there"); - assertEquals("line\nnewline", arr.get(0).asText()); - assertEquals("tab\there", arr.get(1).asText()); + assertThat(arr.get(0).asText(), is("line\nnewline")); + assertThat(arr.get(1).asText(), is("tab\there")); } @Test public void testBuildStopStrings_roundtripsAsJson() throws Exception { ArrayNode arr = serializer.buildStopStrings("a", "b"); JsonNode parsed = serializer.OBJECT_MAPPER.readTree(arr.toString()); - assertTrue(parsed.isArray()); - assertEquals("a", parsed.get(0).asText()); + assertThat(parsed.isArray(), is(true)); + assertThat(parsed.get(0).asText(), is("a")); } // ------------------------------------------------------------------ @@ -170,18 +174,18 @@ public void testBuildStopStrings_roundtripsAsJson() throws Exception { @Test public void testBuildSamplers_allTypes() { ArrayNode arr = serializer.buildSamplers(Sampler.TOP_K, Sampler.TOP_P, Sampler.MIN_P, Sampler.TEMPERATURE); - assertEquals(4, arr.size()); - assertEquals("top_k", arr.get(0).asText()); - assertEquals("top_p", arr.get(1).asText()); - assertEquals("min_p", arr.get(2).asText()); - assertEquals("temperature", arr.get(3).asText()); + assertThat(arr.size(), is(4)); + assertThat(arr.get(0).asText(), is("top_k")); + assertThat(arr.get(1).asText(), is("top_p")); + assertThat(arr.get(2).asText(), is("min_p")); + assertThat(arr.get(3).asText(), is("temperature")); } @Test public void testBuildSamplers_single() { ArrayNode arr = serializer.buildSamplers(Sampler.TEMPERATURE); - assertEquals(1, arr.size()); - assertEquals("temperature", arr.get(0).asText()); + assertThat(arr.size(), is(1)); + assertThat(arr.get(0).asText(), is("temperature")); } // ------------------------------------------------------------------ @@ -191,23 +195,23 @@ public void testBuildSamplers_single() { @Test public void testBuildIntArray_values() { ArrayNode arr = serializer.buildIntArray(new int[] {1, 2, 3}); - assertEquals(3, arr.size()); - assertEquals(1, arr.get(0).asInt()); - assertEquals(3, arr.get(2).asInt()); + assertThat(arr.size(), is(3)); + assertThat(arr.get(0).asInt(), is(1)); + assertThat(arr.get(2).asInt(), is(3)); } @Test public void testBuildIntArray_empty() { ArrayNode arr = serializer.buildIntArray(new int[] {}); - assertEquals(0, arr.size()); + assertThat(arr.size(), is(0)); } @Test public void testBuildIntArray_roundtripsAsJson() throws Exception { ArrayNode arr = serializer.buildIntArray(new int[] {10, 20}); JsonNode parsed = serializer.OBJECT_MAPPER.readTree(arr.toString()); - assertTrue(parsed.isArray()); - assertEquals(10, parsed.get(0).asInt()); + assertThat(parsed.isArray(), is(true)); + assertThat(parsed.get(0).asInt(), is(10)); } // ------------------------------------------------------------------ @@ -220,17 +224,17 @@ public void testBuildTokenIdBiasArray_structure() { biases.put(15043, 1.0f); biases.put(50256, -0.5f); ArrayNode arr = serializer.buildTokenIdBiasArray(biases); - assertEquals(2, arr.size()); - assertEquals(15043, arr.get(0).get(0).asInt()); + assertThat(arr.size(), is(2)); + assertThat(arr.get(0).get(0).asInt(), is(15043)); assertEquals(1.0, arr.get(0).get(1).asDouble(), 0.001); - assertEquals(50256, arr.get(1).get(0).asInt()); + assertThat(arr.get(1).get(0).asInt(), is(50256)); assertEquals(-0.5, arr.get(1).get(1).asDouble(), 0.001); } @Test public void testBuildTokenIdBiasArray_empty() { ArrayNode arr = serializer.buildTokenIdBiasArray(Collections.emptyMap()); - assertEquals(0, arr.size()); + assertThat(arr.size(), is(0)); } // ------------------------------------------------------------------ @@ -243,10 +247,10 @@ public void testBuildTokenStringBiasArray_structure() { biases.put("Hello", 1.0f); biases.put(" world", -0.5f); ArrayNode arr = serializer.buildTokenStringBiasArray(biases); - assertEquals(2, arr.size()); - assertEquals("Hello", arr.get(0).get(0).asText()); + assertThat(arr.size(), is(2)); + assertThat(arr.get(0).get(0).asText(), is("Hello")); assertEquals(1.0, arr.get(0).get(1).asDouble(), 0.001); - assertEquals(" world", arr.get(1).get(0).asText()); + assertThat(arr.get(1).get(0).asText(), is(" world")); } @Test @@ -254,7 +258,7 @@ public void testBuildTokenStringBiasArray_specialCharsInKey() { Map biases = new LinkedHashMap<>(); biases.put("line\nnewline", 2.0f); ArrayNode arr = serializer.buildTokenStringBiasArray(biases); - assertEquals("line\nnewline", arr.get(0).get(0).asText()); + assertThat(arr.get(0).get(0).asText(), is("line\nnewline")); } // ------------------------------------------------------------------ @@ -264,17 +268,17 @@ public void testBuildTokenStringBiasArray_specialCharsInKey() { @Test public void testBuildDisableTokenIdArray_structure() { ArrayNode arr = serializer.buildDisableTokenIdArray(Arrays.asList(100, 200, 300)); - assertEquals(3, arr.size()); + assertThat(arr.size(), is(3)); for (int i = 0; i < arr.size(); i++) { - assertFalse(arr.get(i).get(1).asBoolean()); + assertThat(arr.get(i).get(1).asBoolean(), is(false)); } - assertEquals(100, arr.get(0).get(0).asInt()); + assertThat(arr.get(0).get(0).asInt(), is(100)); } @Test public void testBuildDisableTokenIdArray_empty() { ArrayNode arr = serializer.buildDisableTokenIdArray(Collections.emptyList()); - assertEquals(0, arr.size()); + assertThat(arr.size(), is(0)); } // ------------------------------------------------------------------ @@ -284,10 +288,10 @@ public void testBuildDisableTokenIdArray_empty() { @Test public void testBuildDisableTokenStringArray_structure() { ArrayNode arr = serializer.buildDisableTokenStringArray(Arrays.asList("foo", "bar")); - assertEquals(2, arr.size()); - assertEquals("foo", arr.get(0).get(0).asText()); - assertFalse(arr.get(0).get(1).asBoolean()); - assertEquals("bar", arr.get(1).get(0).asText()); + assertThat(arr.size(), is(2)); + assertThat(arr.get(0).get(0).asText(), is("foo")); + assertThat(arr.get(0).get(1).asBoolean(), is(false)); + assertThat(arr.get(1).get(0).asText(), is("bar")); } // ------------------------------------------------------------------ @@ -298,8 +302,8 @@ public void testBuildDisableTokenStringArray_structure() { public void testBuildRawValueObject_booleanValue() { Map map = Collections.singletonMap("enable_thinking", "true"); ObjectNode node = serializer.buildRawValueObject(map); - assertTrue(node.path("enable_thinking").isBoolean()); - assertTrue(node.path("enable_thinking").asBoolean()); + assertThat(node.path("enable_thinking").isBoolean(), is(true)); + assertThat(node.path("enable_thinking").asBoolean(), is(true)); } @Test @@ -313,14 +317,14 @@ public void testBuildRawValueObject_numberValue() { public void testBuildRawValueObject_stringValue() { Map map = Collections.singletonMap("mode", "\"fast\""); ObjectNode node = serializer.buildRawValueObject(map); - assertEquals("fast", node.path("mode").asText()); + assertThat(node.path("mode").asText(), is("fast")); } @Test public void testBuildRawValueObject_invalidJsonFallsBackToString() { Map map = Collections.singletonMap("key", "not-valid-json{{{"); ObjectNode node = serializer.buildRawValueObject(map); - assertEquals("not-valid-json{{{", node.path("key").asText()); + assertThat(node.path("key").asText(), is("not-valid-json{{{")); } @Test @@ -330,7 +334,7 @@ public void testBuildRawValueObject_roundtripsAsJson() throws Exception { map.put("count", "3"); ObjectNode node = serializer.buildRawValueObject(map); JsonNode parsed = serializer.OBJECT_MAPPER.readTree(node.toString()); - assertTrue(parsed.path("flag").asBoolean()); - assertEquals(3, parsed.path("count").asInt()); + assertThat(parsed.path("flag").asBoolean(), is(true)); + assertThat(parsed.path("count").asInt(), is(3)); } } diff --git a/src/test/java/net/ladenthin/llama/json/RerankResponseParserTest.java b/src/test/java/net/ladenthin/llama/json/RerankResponseParserTest.java index c74f4bfa..84ef9cdc 100644 --- a/src/test/java/net/ladenthin/llama/json/RerankResponseParserTest.java +++ b/src/test/java/net/ladenthin/llama/json/RerankResponseParserTest.java @@ -5,12 +5,16 @@ package net.ladenthin.llama.json; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.List; -import net.ladenthin.llama.Pair; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.Test; /** @@ -30,8 +34,8 @@ public class RerankResponseParserTest { public void testParseString_singleEntry() { String json = "[{\"document\":\"The quick brown fox\",\"index\":0,\"score\":0.92}]"; List> result = parser.parse(json); - assertEquals(1, result.size()); - assertEquals("The quick brown fox", result.get(0).getKey()); + assertThat(result, hasSize(1)); + assertThat(result.get(0).getKey(), is("The quick brown fox")); assertEquals(0.92f, result.get(0).getValue(), 0.001f); } @@ -42,10 +46,10 @@ public void testParseString_multipleEntries() { + "{\"document\":\"Third\",\"index\":2,\"score\":0.1}" + "]"; List> result = parser.parse(json); - assertEquals(3, result.size()); - assertEquals("First", result.get(0).getKey()); - assertEquals("Second", result.get(1).getKey()); - assertEquals("Third", result.get(2).getKey()); + assertThat(result, hasSize(3)); + assertThat(result.get(0).getKey(), is("First")); + assertThat(result.get(1).getKey(), is("Second")); + assertThat(result.get(2).getKey(), is("Third")); assertEquals(0.9f, result.get(0).getValue(), 0.001f); assertEquals(0.5f, result.get(1).getValue(), 0.001f); assertEquals(0.1f, result.get(2).getValue(), 0.001f); @@ -54,34 +58,34 @@ public void testParseString_multipleEntries() { @Test public void testParseString_emptyArray() { List> result = parser.parse("[]"); - assertTrue(result.isEmpty()); + assertThat(result, is(empty())); } @Test public void testParseString_malformed() { List> result = parser.parse("{not json"); - assertTrue(result.isEmpty()); + assertThat(result, is(empty())); } @Test public void testParseString_notAnArray() { List> result = parser.parse("{\"document\":\"x\",\"score\":0.5}"); - assertTrue(result.isEmpty()); + assertThat(result, is(empty())); } @Test public void testParseString_documentWithSpecialChars() { String json = "[{\"document\":\"line1\\nline2\\t\\\"quoted\\\"\",\"index\":0,\"score\":0.75}]"; List> result = parser.parse(json); - assertEquals(1, result.size()); - assertEquals("line1\nline2\t\"quoted\"", result.get(0).getKey()); + assertThat(result, hasSize(1)); + assertThat(result.get(0).getKey(), is("line1\nline2\t\"quoted\"")); } @Test public void testParseString_scoreZero() { String json = "[{\"document\":\"irrelevant\",\"index\":0,\"score\":0.0}]"; List> result = parser.parse(json); - assertEquals(1, result.size()); + assertThat(result, hasSize(1)); assertEquals(0.0f, result.get(0).getValue(), 0.001f); } @@ -96,22 +100,34 @@ public void testParseNode_preservesOrder() throws Exception { + "]"; JsonNode arr = MAPPER.readTree(json); List> result = parser.parse(arr); - assertEquals(2, result.size()); - assertEquals("A", result.get(0).getKey()); - assertEquals("B", result.get(1).getKey()); + assertThat(result, hasSize(2)); + assertThat(result.get(0).getKey(), is("A")); + assertThat(result.get(1).getKey(), is("B")); } @Test public void testParseNode_notArray() throws Exception { JsonNode obj = MAPPER.readTree("{\"document\":\"x\",\"score\":0.5}"); - assertTrue(parser.parse(obj).isEmpty()); + assertThat(parser.parse(obj), is(empty())); + } + + @Test + public void testParseNode_notArray_returnsMutableEmptyList() throws Exception { + // The non-array branch returns a MUTABLE empty list (matches the non-empty path, + // for Error Prone MixedMutabilityReturnType). Mutating it must succeed — which also + // kills the EmptyObjectReturnVals mutant that would return an immutable emptyList(). + JsonNode obj = MAPPER.readTree("{\"document\":\"x\",\"score\":0.5}"); + List> result = parser.parse(obj); + assertThat(result, is(empty())); + result.add(new Pair<>("added", 1.0f)); + assertThat(result, hasSize(1)); } @Test public void testParseNode_missingScore_defaultsToZero() throws Exception { JsonNode arr = MAPPER.readTree("[{\"document\":\"doc\",\"index\":0}]"); List> result = parser.parse(arr); - assertEquals(1, result.size()); + assertThat(result, hasSize(1)); assertEquals(0.0f, result.get(0).getValue(), 0.001f); } @@ -119,7 +135,7 @@ public void testParseNode_missingScore_defaultsToZero() throws Exception { public void testParseNode_missingDocument_defaultsToEmpty() throws Exception { JsonNode arr = MAPPER.readTree("[{\"index\":0,\"score\":0.5}]"); List> result = parser.parse(arr); - assertEquals(1, result.size()); - assertEquals("", result.get(0).getKey()); + assertThat(result, hasSize(1)); + assertThat(result.get(0).getKey(), is("")); } } diff --git a/src/test/java/net/ladenthin/llama/TimingsLoggerTest.java b/src/test/java/net/ladenthin/llama/json/TimingsLoggerTest.java similarity index 97% rename from src/test/java/net/ladenthin/llama/TimingsLoggerTest.java rename to src/test/java/net/ladenthin/llama/json/TimingsLoggerTest.java index 5f15d259..3e2e2e8b 100644 --- a/src/test/java/net/ladenthin/llama/TimingsLoggerTest.java +++ b/src/test/java/net/ladenthin/llama/json/TimingsLoggerTest.java @@ -2,12 +2,14 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.json; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import net.ladenthin.llama.ClaudeGenerated; +import net.ladenthin.llama.value.Timings; import nl.altindag.log.LogCaptor; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/LlamaLoaderTest.java b/src/test/java/net/ladenthin/llama/loader/LlamaLoaderTest.java similarity index 85% rename from src/test/java/net/ladenthin/llama/LlamaLoaderTest.java rename to src/test/java/net/ladenthin/llama/loader/LlamaLoaderTest.java index 14aab11d..d3fd3fb4 100644 --- a/src/test/java/net/ladenthin/llama/LlamaLoaderTest.java +++ b/src/test/java/net/ladenthin/llama/loader/LlamaLoaderTest.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import static org.junit.jupiter.api.Assertions.*; @@ -12,6 +12,7 @@ import java.io.File; import java.io.IOException; import java.nio.file.Paths; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -186,4 +187,24 @@ public void testGetNativeResourcePathContainsOsAndArch() { String osArch = OSInfo.getNativeLibFolderPathForCurrentOS(); assertTrue(path.endsWith(osArch), "Resource path should end with OS/arch: " + path); } + + /** + * Regression for the layered-restructure bug: the native-library classpath + * root is fixed at {@code /net/ladenthin/llama//} by CMakeLists + + * the publish workflow, so it must NOT track the loader's own Java package + * (which moved to {@code net.ladenthin.llama.loader}). Deriving it from + * {@code LlamaLoader.class.getPackage()} produced {@code .../llama/loader/...}, + * one level too deep, so {@code getResource(...)} returned null and every + * native-backed test failed with "No native library found". + */ + @Test + public void testGetNativeResourcePathIsPackageIndependent() { + String path = LlamaLoader.getNativeResourcePath(); + String osArch = OSInfo.getNativeLibFolderPathForCurrentOS(); + assertEquals("/net/ladenthin/llama/" + osArch, path); + assertFalse( + path.contains("/loader/"), + "Resource path must not include the loader subpackage — the native libs live at " + + "/net/ladenthin/llama//, not under the loader package: " + path); + } } diff --git a/src/test/java/net/ladenthin/llama/LoggingSmokeTest.java b/src/test/java/net/ladenthin/llama/loader/LoggingSmokeTest.java similarity index 96% rename from src/test/java/net/ladenthin/llama/LoggingSmokeTest.java rename to src/test/java/net/ladenthin/llama/loader/LoggingSmokeTest.java index 9fb193ed..4849e653 100644 --- a/src/test/java/net/ladenthin/llama/LoggingSmokeTest.java +++ b/src/test/java/net/ladenthin/llama/loader/LoggingSmokeTest.java @@ -2,12 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; +import net.ladenthin.llama.ClaudeGenerated; import nl.altindag.log.LogCaptor; import org.junit.jupiter.api.Test; import org.slf4j.LoggerFactory; diff --git a/src/test/java/net/ladenthin/llama/loader/NativeLibraryLoadSmokeTest.java b/src/test/java/net/ladenthin/llama/loader/NativeLibraryLoadSmokeTest.java new file mode 100644 index 00000000..9ec3c555 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/loader/NativeLibraryLoadSmokeTest.java @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.loader; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +/** + * Model-free smoke test that the bundled native library actually loads and its + * {@code JNI_OnLoad} resolves every Java class it looks up by name. + * + *

    Forcing {@code LlamaModel.} runs + * {@code LlamaLoader.initialize() -> System.load() -> JNI_OnLoad}, which calls + * {@code FindClass(...)} for the JNI-referenced classes ({@code LlamaException}, + * {@code LogLevel}, {@code LogFormat}, ...). No GGUF model is required, so this + * catches the two failure modes that the model-gated tests cannot exercise when + * models are absent (e.g. in a restricted-network sandbox): + * + *

      + *
    • a wrong native-resource path in {@link LlamaLoader} (lib not found), and
    • + *
    • a stale {@code FindClass} FQN in {@code jllama.cpp} after a Java package + * move (lib loads but {@code JNI_OnLoad} throws + * {@code NoClassDefFoundError}).
    • + *
    + * + *

    Both bugs shipped once on this branch precisely because they only surface + * when the library is loaded — see the regression history in {@code CLAUDE.md}. + * + *

    The test self-skips when {@code libjllama} is not on the classpath (a + * pure-Java checkout with no native build), so a plain {@code mvn test} stays + * green without a CMake build; CI's {@code test-java-*} jobs and any local build + * have the library and run it for real. The presence check uses the canonical + * resource layout directly (not {@link LlamaLoader#getNativeResourcePath()}) so + * a regression in that method cannot silently skip this guard. + */ +@ClaudeGenerated( + purpose = "Model-free native-load smoke: force LlamaModel. so System.load + JNI_OnLoad " + + "run and resolve every FindClass'd Java class. Guards against native-resource-path and " + + "stale-JNI-FQN regressions that only appear when the library is actually loaded; skips " + + "cleanly when libjllama is not on the classpath.") +class NativeLibraryLoadSmokeTest { + + private static boolean nativeLibraryOnClasspath() { + String resource = "/net/ladenthin/llama/" + OSInfo.getNativeLibFolderPathForCurrentOS() + "/" + + System.mapLibraryName("jllama"); + return NativeLibraryLoadSmokeTest.class.getResource(resource) != null; + } + + @Test + void loadingNativeLibraryRunsJniOnLoadWithoutError() { + assumeTrue(nativeLibraryOnClasspath(), "libjllama not on classpath — skipping native-load smoke"); + assertDoesNotThrow( + () -> Class.forName("net.ladenthin.llama.LlamaModel"), + "LlamaModel. must load the native library and JNI_OnLoad must resolve " + + "every FindClass'd Java class"); + } +} diff --git a/src/test/java/net/ladenthin/llama/NativeLibraryPermissionSetterTest.java b/src/test/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetterTest.java similarity index 97% rename from src/test/java/net/ladenthin/llama/NativeLibraryPermissionSetterTest.java rename to src/test/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetterTest.java index ed238fc4..7d9fcf71 100644 --- a/src/test/java/net/ladenthin/llama/NativeLibraryPermissionSetterTest.java +++ b/src/test/java/net/ladenthin/llama/loader/NativeLibraryPermissionSetterTest.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -12,6 +12,7 @@ import java.io.ByteArrayOutputStream; import java.io.File; import java.io.PrintStream; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated( diff --git a/src/test/java/net/ladenthin/llama/OSInfoTest.java b/src/test/java/net/ladenthin/llama/loader/OSInfoTest.java similarity index 98% rename from src/test/java/net/ladenthin/llama/OSInfoTest.java rename to src/test/java/net/ladenthin/llama/loader/OSInfoTest.java index 764f5c92..36e15295 100644 --- a/src/test/java/net/ladenthin/llama/OSInfoTest.java +++ b/src/test/java/net/ladenthin/llama/loader/OSInfoTest.java @@ -3,10 +3,11 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.loader; import static org.junit.jupiter.api.Assertions.*; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/ChatRequestTest.java b/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java similarity index 59% rename from src/test/java/net/ladenthin/llama/ChatRequestTest.java rename to src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java index cde53682..388d0502 100644 --- a/src/test/java/net/ladenthin/llama/ChatRequestTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java @@ -2,16 +2,18 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertSame; +package net.ladenthin.llama.parameters; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import net.ladenthin.llama.value.ToolDefinition; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -32,43 +34,43 @@ class Immutability { void appendMessageReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); ChatRequest derived = original.appendMessage("user", "hi"); - assertNotSame(original, derived); - assertEquals(0, original.getMessages().size(), "original is untouched"); - assertEquals(1, derived.getMessages().size(), "derived has the message"); + assertThat(derived, is(not(sameInstance(original)))); + assertThat("original is untouched", original.getMessages(), is(empty())); + assertThat("derived has the message", derived.getMessages(), hasSize(1)); } @Test void appendToolReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); ChatRequest derived = original.appendTool(new ToolDefinition("echo", "Echo", "{}")); - assertNotSame(original, derived); - assertEquals(0, original.getTools().size()); - assertEquals(1, derived.getTools().size()); + assertThat(derived, is(not(sameInstance(original)))); + assertThat(original.getTools(), is(empty())); + assertThat(derived.getTools(), hasSize(1)); } @Test void withToolChoiceReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); ChatRequest derived = original.withToolChoice("auto"); - assertNotSame(original, derived); - assertFalse(original.getToolChoice().isPresent(), "original toolChoice unset"); - assertEquals("auto", derived.getToolChoice().orElseThrow()); + assertThat(derived, is(not(sameInstance(original)))); + assertThat("original toolChoice unset", original.getToolChoice().isPresent(), is(false)); + assertThat(derived.getToolChoice().orElseThrow(), is("auto")); } @Test void withMaxToolRoundsReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); ChatRequest derived = original.withMaxToolRounds(2); - assertNotSame(original, derived); - assertEquals(ChatRequest.DEFAULT_MAX_TOOL_ROUNDS, original.getMaxToolRounds()); - assertEquals(2, derived.getMaxToolRounds()); + assertThat(derived, is(not(sameInstance(original)))); + assertThat(original.getMaxToolRounds(), is(ChatRequest.DEFAULT_MAX_TOOL_ROUNDS)); + assertThat(derived.getMaxToolRounds(), is(2)); } @Test void withInferenceCustomizerReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); ChatRequest derived = original.withInferenceCustomizer(p -> p.withSeed(42)); - assertNotSame(original, derived); + assertThat(derived, is(not(sameInstance(original)))); } @Test @@ -79,26 +81,28 @@ void chainedDerivationsLeaveIntermediatesUntouched() { ChatRequest c = b.appendMessage("assistant", "hello"); ChatRequest d = c.withMaxToolRounds(3); - assertEquals(0, a.getMessages().size()); - assertEquals(1, b.getMessages().size()); - assertEquals(2, c.getMessages().size()); - assertEquals(2, d.getMessages().size()); - assertEquals(ChatRequest.DEFAULT_MAX_TOOL_ROUNDS, c.getMaxToolRounds()); - assertEquals(3, d.getMaxToolRounds()); + assertThat(a.getMessages(), is(empty())); + assertThat(b.getMessages(), hasSize(1)); + assertThat(c.getMessages(), hasSize(2)); + assertThat(d.getMessages(), hasSize(2)); + assertThat(c.getMaxToolRounds(), is(ChatRequest.DEFAULT_MAX_TOOL_ROUNDS)); + assertThat(d.getMaxToolRounds(), is(3)); } @Test @DisplayName("the messages accessor returns an unmodifiable view") void messagesAccessorIsUnmodifiable() { ChatRequest req = ChatRequest.empty().appendMessage("user", "hi"); - assertThrows(UnsupportedOperationException.class, () -> req.getMessages().clear()); + assertThrows( + UnsupportedOperationException.class, () -> req.getMessages().clear()); } @Test @DisplayName("the tools accessor returns an unmodifiable view") void toolsAccessorIsUnmodifiable() { ChatRequest req = ChatRequest.empty().appendTool(new ToolDefinition("e", "d", "{}")); - assertThrows(UnsupportedOperationException.class, () -> req.getTools().clear()); + assertThrows( + UnsupportedOperationException.class, () -> req.getTools().clear()); } } @@ -108,37 +112,38 @@ class Equality { @Test void twoEmptyRequestsAreEqual() { - assertEquals(ChatRequest.empty(), ChatRequest.empty()); + assertThat(ChatRequest.empty(), is(ChatRequest.empty())); } @Test void sameContentSameEquality() { ChatRequest a = ChatRequest.empty().appendMessage("user", "hi").withMaxToolRounds(3); ChatRequest b = ChatRequest.empty().appendMessage("user", "hi").withMaxToolRounds(3); - assertEquals(a, b); - assertEquals(a.hashCode(), b.hashCode()); + assertThat(a, is(b)); + assertThat(a.hashCode(), is(b.hashCode())); } @Test void differentMessagesNotEqual() { ChatRequest a = ChatRequest.empty().appendMessage("user", "hi"); ChatRequest b = ChatRequest.empty().appendMessage("user", "bye"); - assertNotEquals(a, b); + assertThat(a, is(not(b))); } @Test void differentMaxToolRoundsNotEqual() { ChatRequest a = ChatRequest.empty().withMaxToolRounds(2); ChatRequest b = ChatRequest.empty().withMaxToolRounds(3); - assertNotEquals(a, b); + assertThat(a, is(not(b))); } @Test - @DisplayName("the customiser is excluded from equality — two requests with the same content but different lambdas are equal") + @DisplayName( + "the customiser is excluded from equality — two requests with the same content but different lambdas are equal") void customizerExcludedFromEquality() { ChatRequest a = ChatRequest.empty().withInferenceCustomizer(p -> p.withSeed(1)); ChatRequest b = ChatRequest.empty().withInferenceCustomizer(p -> p.withSeed(2)); - assertEquals(a, b, "different lambda identities must NOT make the requests unequal"); + assertThat("different lambda identities must NOT make the requests unequal", a, is(b)); } } @@ -148,17 +153,19 @@ class Validation { @Test void withMaxToolRoundsRejectsZero() { - assertThrows(IllegalArgumentException.class, () -> ChatRequest.empty().withMaxToolRounds(0)); + assertThrows( + IllegalArgumentException.class, () -> ChatRequest.empty().withMaxToolRounds(0)); } @Test void withMaxToolRoundsRejectsNegative() { - assertThrows(IllegalArgumentException.class, () -> ChatRequest.empty().withMaxToolRounds(-1)); + assertThrows( + IllegalArgumentException.class, () -> ChatRequest.empty().withMaxToolRounds(-1)); } @Test void emptyMessageIsTheCanonicalStartingPoint() { - assertSame(ChatRequest.empty(), ChatRequest.empty(), "empty() is a cached singleton"); + assertThat("empty() is a cached singleton", ChatRequest.empty(), is(sameInstance(ChatRequest.empty()))); } } @@ -170,13 +177,13 @@ class JsonHelpers { void buildMessagesJsonDoesNotMutate() { ChatRequest req = ChatRequest.empty().appendMessage("user", "hi"); String json = req.buildMessagesJson(); - assertTrue(json.contains("\"user\""), json); - assertEquals(1, req.getMessages().size(), "build did not mutate the messages list"); + assertThat(json, json, containsString("\"user\"")); + assertThat("build did not mutate the messages list", req.getMessages(), hasSize(1)); } @Test void buildToolsJsonEmptyWhenNoTools() { - assertFalse(ChatRequest.empty().buildToolsJson().isPresent()); + assertThat(ChatRequest.empty().buildToolsJson().isPresent(), is(false)); } } } diff --git a/src/test/java/net/ladenthin/llama/InferenceParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java similarity index 71% rename from src/test/java/net/ladenthin/llama/InferenceParametersTest.java rename to src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java index add91850..33b7f494 100644 --- a/src/test/java/net/ladenthin/llama/InferenceParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java @@ -3,19 +3,31 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.*; +package net.ladenthin.llama.parameters; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.hamcrest.Matchers.startsWith; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.args.ContinuationMode; import net.ladenthin.llama.args.MiroStat; import net.ladenthin.llama.args.ReasoningFormat; import net.ladenthin.llama.args.Sampler; +import net.ladenthin.llama.value.Pair; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -34,21 +46,21 @@ public class InferenceParametersTest { @Test public void testConstructorSetsPrompt() { InferenceParameters params = new InferenceParameters("hello"); - assertTrue(params.parameters.containsKey("prompt")); - assertEquals("\"hello\"", params.parameters.get("prompt")); + assertThat(params.parameters, hasKey("prompt")); + assertThat(params.parameters.get("prompt"), is("\"hello\"")); } @Test public void testConstructorWithEmptyPrompt() { InferenceParameters params = new InferenceParameters(""); - assertEquals("\"\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"\"")); } @Test public void testSetPromptOverrides() { InferenceParameters params = new InferenceParameters("first"); params = params.withPrompt("second"); - assertEquals("\"second\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"second\"")); } // ------------------------------------------------------------------------- @@ -58,121 +70,121 @@ public void testSetPromptOverrides() { @Test public void testSetNPredict() { InferenceParameters params = new InferenceParameters("").withNPredict(42); - assertEquals("42", params.parameters.get("n_predict")); + assertThat(params.parameters.get("n_predict"), is("42")); } @Test public void testSetTemperature() { InferenceParameters params = new InferenceParameters("").withTemperature(0.5f); - assertEquals("0.5", params.parameters.get("temperature")); + assertThat(params.parameters.get("temperature"), is("0.5")); } @Test public void testSetTopK() { InferenceParameters params = new InferenceParameters("").withTopK(10); - assertEquals("10", params.parameters.get("top_k")); + assertThat(params.parameters.get("top_k"), is("10")); } @Test public void testSetTopP() { InferenceParameters params = new InferenceParameters("").withTopP(0.9f); - assertEquals("0.9", params.parameters.get("top_p")); + assertThat(params.parameters.get("top_p"), is("0.9")); } @Test public void testSetMinP() { InferenceParameters params = new InferenceParameters("").withMinP(0.1f); - assertEquals("0.1", params.parameters.get("min_p")); + assertThat(params.parameters.get("min_p"), is("0.1")); } @Test public void testSetTfsZ() { InferenceParameters params = new InferenceParameters("").withTfsZ(1.0f); - assertEquals("1.0", params.parameters.get("tfs_z")); + assertThat(params.parameters.get("tfs_z"), is("1.0")); } @Test public void testSetTypicalP() { InferenceParameters params = new InferenceParameters("").withTypicalP(0.8f); - assertEquals("0.8", params.parameters.get("typical_p")); + assertThat(params.parameters.get("typical_p"), is("0.8")); } @Test public void testSetRepeatLastN() { InferenceParameters params = new InferenceParameters("").withRepeatLastN(64); - assertEquals("64", params.parameters.get("repeat_last_n")); + assertThat(params.parameters.get("repeat_last_n"), is("64")); } @Test public void testSetRepeatPenalty() { InferenceParameters params = new InferenceParameters("").withRepeatPenalty(1.1f); - assertEquals("1.1", params.parameters.get("repeat_penalty")); + assertThat(params.parameters.get("repeat_penalty"), is("1.1")); } @Test public void testSetFrequencyPenalty() { InferenceParameters params = new InferenceParameters("").withFrequencyPenalty(0.2f); - assertEquals("0.2", params.parameters.get("frequency_penalty")); + assertThat(params.parameters.get("frequency_penalty"), is("0.2")); } @Test public void testSetPresencePenalty() { InferenceParameters params = new InferenceParameters("").withPresencePenalty(0.3f); - assertEquals("0.3", params.parameters.get("presence_penalty")); + assertThat(params.parameters.get("presence_penalty"), is("0.3")); } @Test public void testSetSeed() { InferenceParameters params = new InferenceParameters("").withSeed(1234); - assertEquals("1234", params.parameters.get("seed")); + assertThat(params.parameters.get("seed"), is("1234")); } @Test public void testSetNProbs() { InferenceParameters params = new InferenceParameters("").withNProbs(5); - assertEquals("5", params.parameters.get("n_probs")); + assertThat(params.parameters.get("n_probs"), is("5")); } @Test public void testSetMinKeep() { InferenceParameters params = new InferenceParameters("").withMinKeep(2); - assertEquals("2", params.parameters.get("min_keep")); + assertThat(params.parameters.get("min_keep"), is("2")); } @Test public void testSetNKeep() { InferenceParameters params = new InferenceParameters("").withNKeep(-1); - assertEquals("-1", params.parameters.get("n_keep")); + assertThat(params.parameters.get("n_keep"), is("-1")); } @Test public void testSetCachePrompt() { InferenceParameters params = new InferenceParameters("").withCachePrompt(true); - assertEquals("true", params.parameters.get("cache_prompt")); + assertThat(params.parameters.get("cache_prompt"), is("true")); } @Test public void testSetIgnoreEos() { InferenceParameters params = new InferenceParameters("").withIgnoreEos(true); - assertEquals("true", params.parameters.get("ignore_eos")); + assertThat(params.parameters.get("ignore_eos"), is("true")); } @Test public void testSetPenalizeNl() { InferenceParameters params = new InferenceParameters("").withPenalizeNl(false); - assertEquals("false", params.parameters.get("penalize_nl")); + assertThat(params.parameters.get("penalize_nl"), is("false")); } @Test public void testSetDynamicTemperatureRange() { InferenceParameters params = new InferenceParameters("").withDynamicTemperatureRange(0.5f); - assertEquals("0.5", params.parameters.get("dynatemp_range")); + assertThat(params.parameters.get("dynatemp_range"), is("0.5")); } @Test public void testSetDynamicTemperatureExponent() { InferenceParameters params = new InferenceParameters("").withDynamicTemperatureExponent(2.0f); - assertEquals("2.0", params.parameters.get("dynatemp_exponent")); + assertThat(params.parameters.get("dynatemp_exponent"), is("2.0")); } // ------------------------------------------------------------------------- @@ -182,45 +194,45 @@ public void testSetDynamicTemperatureExponent() { @Test public void testSetInputPrefix() { InferenceParameters params = new InferenceParameters("").withInputPrefix("prefix"); - assertEquals("\"prefix\"", params.parameters.get("input_prefix")); + assertThat(params.parameters.get("input_prefix"), is("\"prefix\"")); } @Test public void testSetInputSuffix() { InferenceParameters params = new InferenceParameters("").withInputSuffix("suffix"); - assertEquals("\"suffix\"", params.parameters.get("input_suffix")); + assertThat(params.parameters.get("input_suffix"), is("\"suffix\"")); } @Test public void testSetGrammar() { InferenceParameters params = new InferenceParameters("").withGrammar("root ::= \"a\""); - assertEquals("\"root ::= \\\"a\\\"\"", params.parameters.get("grammar")); + assertThat(params.parameters.get("grammar"), is("\"root ::= \\\"a\\\"\"")); } @Test public void testSetJsonSchemaStoresVerbatim() { String schema = "{\"type\":\"object\",\"properties\":{\"name\":{\"type\":\"string\"}},\"required\":[\"name\"]}"; InferenceParameters params = new InferenceParameters("").withJsonSchema(schema); - assertEquals(schema, params.parameters.get("json_schema")); - assertTrue(params.toString().contains("\"json_schema\": " + schema)); + assertThat(params.parameters.get("json_schema"), is(schema)); + assertThat(params.toString(), containsString("\"json_schema\": " + schema)); } @Test public void testSetPenaltyPromptString() { InferenceParameters params = new InferenceParameters("").withPenaltyPrompt("Hello!"); - assertEquals("\"Hello!\"", params.parameters.get("penalty_prompt")); + assertThat(params.parameters.get("penalty_prompt"), is("\"Hello!\"")); } @Test public void testSetUseChatTemplate() { InferenceParameters params = new InferenceParameters("").withUseChatTemplate(true); - assertEquals("true", params.parameters.get("use_jinja")); + assertThat(params.parameters.get("use_jinja"), is("true")); } @Test public void testSetChatTemplate() { InferenceParameters params = new InferenceParameters("").withChatTemplate("{{messages}}"); - assertEquals("\"{{messages}}\"", params.parameters.get("chat_template")); + assertThat(params.parameters.get("chat_template"), is("\"{{messages}}\"")); } @Test @@ -230,16 +242,16 @@ public void testSetChatTemplateKwargs() { kwargs.put("max_tokens", "1024"); InferenceParameters params = new InferenceParameters("").withChatTemplateKwargs(kwargs); String value = params.parameters.get("chat_template_kwargs"); - assertNotNull(value); - assertTrue(value.contains("\"enable_thinking\":true")); - assertTrue(value.contains("\"max_tokens\":1024")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("\"enable_thinking\":true")); + assertThat(value, containsString("\"max_tokens\":1024")); } @Test public void testSetChatTemplateKwargsEmpty() { java.util.Map kwargs = new java.util.LinkedHashMap<>(); InferenceParameters params = new InferenceParameters("").withChatTemplateKwargs(kwargs); - assertEquals("{}", params.parameters.get("chat_template_kwargs")); + assertThat(params.parameters.get("chat_template_kwargs"), is("{}")); } // ------------------------------------------------------------------------- @@ -249,13 +261,13 @@ public void testSetChatTemplateKwargsEmpty() { @Test public void testSetTopNSigmaEnabled() { InferenceParameters params = new InferenceParameters("").withTopNSigma(2.0f); - assertEquals("2.0", params.parameters.get("top_n_sigma")); + assertThat(params.parameters.get("top_n_sigma"), is("2.0")); } @Test public void testSetTopNSigmaDisabled() { InferenceParameters params = new InferenceParameters("").withTopNSigma(-1.0f); - assertEquals("-1.0", params.parameters.get("top_n_sigma")); + assertThat(params.parameters.get("top_n_sigma"), is("-1.0")); } // ------------------------------------------------------------------------- @@ -265,68 +277,68 @@ public void testSetTopNSigmaDisabled() { @Test public void testSetReasoningFormatNone() { InferenceParameters params = new InferenceParameters("").withReasoningFormat(ReasoningFormat.NONE); - assertEquals("\"none\"", params.parameters.get("reasoning_format")); + assertThat(params.parameters.get("reasoning_format"), is("\"none\"")); } @Test public void testSetReasoningFormatAuto() { InferenceParameters params = new InferenceParameters("").withReasoningFormat(ReasoningFormat.AUTO); - assertEquals("\"auto\"", params.parameters.get("reasoning_format")); + assertThat(params.parameters.get("reasoning_format"), is("\"auto\"")); } @Test public void testSetReasoningFormatDeepseek() { InferenceParameters params = new InferenceParameters("").withReasoningFormat(ReasoningFormat.DEEPSEEK); - assertEquals("\"deepseek\"", params.parameters.get("reasoning_format")); + assertThat(params.parameters.get("reasoning_format"), is("\"deepseek\"")); } @Test public void testSetReasoningFormatDeepseekLegacy() { InferenceParameters params = new InferenceParameters("").withReasoningFormat(ReasoningFormat.DEEPSEEK_LEGACY); - assertEquals("\"deepseek-legacy\"", params.parameters.get("reasoning_format")); + assertThat(params.parameters.get("reasoning_format"), is("\"deepseek-legacy\"")); } @Test public void testSetReasoningBudgetTokensPositive() { InferenceParameters params = new InferenceParameters("").withReasoningBudgetTokens(512); - assertEquals("512", params.parameters.get("reasoning_budget_tokens")); + assertThat(params.parameters.get("reasoning_budget_tokens"), is("512")); } @Test public void testSetReasoningBudgetTokensZero() { InferenceParameters params = new InferenceParameters("").withReasoningBudgetTokens(0); - assertEquals("0", params.parameters.get("reasoning_budget_tokens")); + assertThat(params.parameters.get("reasoning_budget_tokens"), is("0")); } @Test public void testSetReasoningBudgetTokensDisabled() { InferenceParameters params = new InferenceParameters("").withReasoningBudgetTokens(-1); - assertEquals("-1", params.parameters.get("reasoning_budget_tokens")); + assertThat(params.parameters.get("reasoning_budget_tokens"), is("-1")); } @Test public void testSetContinueFinalMessageTrue() { InferenceParameters params = new InferenceParameters("").withContinueFinalMessage(true); - assertEquals("true", params.parameters.get("continue_final_message")); + assertThat(params.parameters.get("continue_final_message"), is("true")); } @Test public void testSetContinueFinalMessageFalse() { InferenceParameters params = new InferenceParameters("").withContinueFinalMessage(false); - assertEquals("false", params.parameters.get("continue_final_message")); + assertThat(params.parameters.get("continue_final_message"), is("false")); } @Test public void testSetContinueFinalMessageReasoningContent() { InferenceParameters params = new InferenceParameters("").withContinueFinalMessage(ContinuationMode.REASONING_CONTENT); - assertEquals("\"reasoning_content\"", params.parameters.get("continue_final_message")); + assertThat(params.parameters.get("continue_final_message"), is("\"reasoning_content\"")); } @Test public void testSetContinueFinalMessageContent() { InferenceParameters params = new InferenceParameters("").withContinueFinalMessage(ContinuationMode.CONTENT); - assertEquals("\"content\"", params.parameters.get("continue_final_message")); + assertThat(params.parameters.get("continue_final_message"), is("\"content\"")); } // ------------------------------------------------------------------------- @@ -336,31 +348,31 @@ public void testSetContinueFinalMessageContent() { @Test public void testSetMiroStatDisabled() { InferenceParameters params = new InferenceParameters("").withMiroStat(MiroStat.DISABLED); - assertEquals("0", params.parameters.get("mirostat")); + assertThat(params.parameters.get("mirostat"), is("0")); } @Test public void testSetMiroStatV1() { InferenceParameters params = new InferenceParameters("").withMiroStat(MiroStat.V1); - assertEquals("1", params.parameters.get("mirostat")); + assertThat(params.parameters.get("mirostat"), is("1")); } @Test public void testSetMiroStatV2() { InferenceParameters params = new InferenceParameters("").withMiroStat(MiroStat.V2); - assertEquals("2", params.parameters.get("mirostat")); + assertThat(params.parameters.get("mirostat"), is("2")); } @Test public void testSetMiroStatTau() { InferenceParameters params = new InferenceParameters("").withMiroStatTau(5.0f); - assertEquals("5.0", params.parameters.get("mirostat_tau")); + assertThat(params.parameters.get("mirostat_tau"), is("5.0")); } @Test public void testSetMiroStatEta() { InferenceParameters params = new InferenceParameters("").withMiroStatEta(0.1f); - assertEquals("0.1", params.parameters.get("mirostat_eta")); + assertThat(params.parameters.get("mirostat_eta"), is("0.1")); } // ------------------------------------------------------------------------- @@ -370,20 +382,20 @@ public void testSetMiroStatEta() { @Test public void testSetStopStringsSingle() { InferenceParameters params = new InferenceParameters("").withStopStrings("stop"); - assertEquals("[\"stop\"]", params.parameters.get("stop")); + assertThat(params.parameters.get("stop"), is("[\"stop\"]")); } @Test public void testSetStopStringsMultiple() { InferenceParameters params = new InferenceParameters("").withStopStrings("stop1", "stop2"); - assertEquals("[\"stop1\",\"stop2\"]", params.parameters.get("stop")); + assertThat(params.parameters.get("stop"), is("[\"stop1\",\"stop2\"]")); } @Test public void testSetStopStringsEmpty() { InferenceParameters params = new InferenceParameters(""); params = params.withStopStrings(); - assertFalse(params.parameters.containsKey("stop")); + assertThat(params.parameters, not(hasKey("stop"))); } // ------------------------------------------------------------------------- @@ -393,27 +405,27 @@ public void testSetStopStringsEmpty() { @Test public void testSetSamplersSingle() { InferenceParameters params = new InferenceParameters("").withSamplers(Sampler.TOP_K); - assertEquals("[\"top_k\"]", params.parameters.get("samplers")); + assertThat(params.parameters.get("samplers"), is("[\"top_k\"]")); } @Test public void testSetSamplersMultiple() { InferenceParameters params = new InferenceParameters("").withSamplers(Sampler.TOP_K, Sampler.TOP_P, Sampler.TEMPERATURE); - assertEquals("[\"top_k\",\"top_p\",\"temperature\"]", params.parameters.get("samplers")); + assertThat(params.parameters.get("samplers"), is("[\"top_k\",\"top_p\",\"temperature\"]")); } @Test public void testSetSamplersMinP() { InferenceParameters params = new InferenceParameters("").withSamplers(Sampler.MIN_P); - assertEquals("[\"min_p\"]", params.parameters.get("samplers")); + assertThat(params.parameters.get("samplers"), is("[\"min_p\"]")); } @Test public void testSetSamplersEmpty() { InferenceParameters params = new InferenceParameters(""); params = params.withSamplers(); - assertFalse(params.parameters.containsKey("samplers")); + assertThat(params.parameters, not(hasKey("samplers"))); } // ------------------------------------------------------------------------- @@ -425,15 +437,15 @@ public void testSetTokenIdBias() { Map bias = Collections.singletonMap(15043, 1.0f); InferenceParameters params = new InferenceParameters("").withTokenIdBias(bias); String value = params.parameters.get("logit_bias"); - assertNotNull(value); - assertTrue(value.contains("15043")); - assertTrue(value.contains("1.0")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("15043")); + assertThat(value, containsString("1.0")); } @Test public void testSetTokenIdBiasEmpty() { InferenceParameters params = new InferenceParameters("").withTokenIdBias(Collections.emptyMap()); - assertFalse(params.parameters.containsKey("logit_bias")); + assertThat(params.parameters, not(hasKey("logit_bias"))); } // ------------------------------------------------------------------------- @@ -445,15 +457,15 @@ public void testSetTokenBias() { Map bias = Collections.singletonMap(" Hello", 1.0f); InferenceParameters params = new InferenceParameters("").withTokenBias(bias); String value = params.parameters.get("logit_bias"); - assertNotNull(value); - assertTrue(value.contains("Hello")); - assertTrue(value.contains("1.0")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("Hello")); + assertThat(value, containsString("1.0")); } @Test public void testSetTokenBiasEmpty() { InferenceParameters params = new InferenceParameters("").withTokenBias(Collections.emptyMap()); - assertFalse(params.parameters.containsKey("logit_bias")); + assertThat(params.parameters, not(hasKey("logit_bias"))); } // ------------------------------------------------------------------------- @@ -464,30 +476,30 @@ public void testSetTokenBiasEmpty() { public void testDisableTokenIds() { InferenceParameters params = new InferenceParameters("").withDisabledTokenIds(Arrays.asList(1, 2, 3)); String value = params.parameters.get("logit_bias"); - assertNotNull(value); - assertTrue(value.contains("false")); - assertTrue(value.contains("1")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("false")); + assertThat(value, containsString("1")); } @Test public void testDisableTokenIdsEmpty() { InferenceParameters params = new InferenceParameters("").withDisabledTokenIds(Collections.emptyList()); - assertFalse(params.parameters.containsKey("logit_bias")); + assertThat(params.parameters, not(hasKey("logit_bias"))); } @Test public void testDisableTokens() { InferenceParameters params = new InferenceParameters("").withDisabledTokens(Arrays.asList("bad", "word")); String value = params.parameters.get("logit_bias"); - assertNotNull(value); - assertTrue(value.contains("false")); - assertTrue(value.contains("bad")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("false")); + assertThat(value, containsString("bad")); } @Test public void testDisableTokensEmpty() { InferenceParameters params = new InferenceParameters("").withDisabledTokens(Collections.emptyList()); - assertFalse(params.parameters.containsKey("logit_bias")); + assertThat(params.parameters, not(hasKey("logit_bias"))); } // ------------------------------------------------------------------------- @@ -497,14 +509,14 @@ public void testDisableTokensEmpty() { @Test public void testSetPenaltyPromptTokenIds() { InferenceParameters params = new InferenceParameters("").withPenaltyPrompt(new int[] {1, 2, 3}); - assertEquals("[1,2,3]", params.parameters.get("penalty_prompt")); + assertThat(params.parameters.get("penalty_prompt"), is("[1,2,3]")); } @Test public void testSetPenaltyPromptTokenIdsEmpty() { InferenceParameters params = new InferenceParameters(""); params = params.withPenaltyPrompt(new int[] {}); - assertFalse(params.parameters.containsKey("penalty_prompt")); + assertThat(params.parameters, not(hasKey("penalty_prompt"))); } // ------------------------------------------------------------------------- @@ -516,11 +528,11 @@ public void testSetMessagesWithSystemAndUserMessages() { List> messages = Collections.singletonList(new Pair<>("user", "Hi")); InferenceParameters params = new InferenceParameters("").withMessages("System msg", messages); String value = params.parameters.get("messages"); - assertNotNull(value); - assertTrue(value.contains("system")); - assertTrue(value.contains("System msg")); - assertTrue(value.contains("user")); - assertTrue(value.contains("Hi")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("system")); + assertThat(value, containsString("System msg")); + assertThat(value, containsString("user")); + assertThat(value, containsString("Hi")); } @Test @@ -529,9 +541,9 @@ public void testSetMessagesWithAssistantRole() { Arrays.asList(new Pair<>("user", "Hello"), new Pair<>("assistant", "Hi there")); InferenceParameters params = new InferenceParameters("").withMessages(null, messages); String value = params.parameters.get("messages"); - assertNotNull(value); - assertTrue(value.contains("assistant")); - assertTrue(value.contains("Hi there")); + assertThat(value, is(notNullValue())); + assertThat(value, containsString("assistant")); + assertThat(value, containsString("Hi there")); } @Test @@ -539,9 +551,9 @@ public void testSetMessagesNoSystemMessage() { List> messages = Collections.singletonList(new Pair<>("user", "Hello")); InferenceParameters params = new InferenceParameters("").withMessages(null, messages); String value = params.parameters.get("messages"); - assertNotNull(value); - assertFalse(value.contains("system")); - assertTrue(value.contains("user")); + assertThat(value, is(notNullValue())); + assertThat(value, not(containsString("system"))); + assertThat(value, containsString("user")); } @Test @@ -549,7 +561,7 @@ public void testSetMessagesEmptySystemMessage() { List> messages = Collections.singletonList(new Pair<>("user", "Hello")); InferenceParameters params = new InferenceParameters("").withMessages("", messages); String value = params.parameters.get("messages"); - assertFalse(value.contains("system")); + assertThat(value, not(containsString("system"))); } @Test @@ -572,10 +584,10 @@ public void testSetMessagesInvalidRoleOther() { public void testToStringContainsPrompt() { InferenceParameters params = new InferenceParameters("test prompt"); String json = params.toString(); - assertTrue(json.startsWith("{")); - assertTrue(json.endsWith("}")); - assertTrue(json.contains("\"prompt\"")); - assertTrue(json.contains("\"test prompt\"")); + assertThat(json, startsWith("{")); + assertThat(json, endsWith("}")); + assertThat(json, containsString("\"prompt\"")); + assertThat(json, containsString("\"test prompt\"")); } @Test @@ -583,8 +595,8 @@ public void testToStringWithMultipleParams() { InferenceParameters params = new InferenceParameters("p").withTemperature(0.7f).withTopK(20); String json = params.toString(); - assertTrue(json.contains("\"temperature\"")); - assertTrue(json.contains("\"top_k\"")); + assertThat(json, containsString("\"temperature\"")); + assertThat(json, containsString("\"top_k\"")); } // ------------------------------------------------------------------------- @@ -594,31 +606,31 @@ public void testToStringWithMultipleParams() { @Test public void testToJsonStringEscapesBackslash() { InferenceParameters params = new InferenceParameters("path\\to\\file"); - assertEquals("\"path\\\\to\\\\file\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"path\\\\to\\\\file\"")); } @Test public void testToJsonStringEscapesDoubleQuote() { InferenceParameters params = new InferenceParameters("say \"hi\""); - assertEquals("\"say \\\"hi\\\"\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"say \\\"hi\\\"\"")); } @Test public void testToJsonStringEscapesNewline() { InferenceParameters params = new InferenceParameters("line1\nline2"); - assertEquals("\"line1\\nline2\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"line1\\nline2\"")); } @Test public void testToJsonStringEscapesTab() { InferenceParameters params = new InferenceParameters("col1\tcol2"); - assertEquals("\"col1\\tcol2\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"col1\\tcol2\"")); } @Test public void testToJsonStringEscapesCarriageReturn() { InferenceParameters params = new InferenceParameters("a\rb"); - assertEquals("\"a\\rb\"", params.parameters.get("prompt")); + assertThat(params.parameters.get("prompt"), is("\"a\\rb\"")); } @Test @@ -626,7 +638,7 @@ public void testToJsonStringNull() { // toJsonString(null) returns null — only used internally but verify via grammar InferenceParameters params = new InferenceParameters(""); params = params.withGrammar(null); - assertNull(params.parameters.get("grammar")); + assertThat(params.parameters.get("grammar"), is(nullValue())); } @Test @@ -634,8 +646,8 @@ public void testToJsonStringSlashNotEscaped() { // Jackson does not escape '/' — forward slashes are passed through verbatim InferenceParameters params = new InferenceParameters(""); String value = params.parameters.get("prompt"); - assertTrue(value.contains("")); - assertFalse(value.contains("<\\/")); + assertThat(value, containsString("")); + assertThat(value, not(containsString("<\\/"))); } // ------------------------------------------------------------------------- @@ -645,9 +657,9 @@ public void testToJsonStringSlashNotEscaped() { @Test public void testBuilderChainingReturnsNewInstance() { InferenceParameters params = new InferenceParameters(""); - assertNotSame(params.withTemperature(0.5f), params); - assertNotSame(params.withTopK(10), params); - assertNotSame(params.withNPredict(5), params); + assertThat(params.withTemperature(0.5f), is(not(sameInstance(params)))); + assertThat(params.withTopK(10), is(not(sameInstance(params)))); + assertThat(params.withNPredict(5), is(not(sameInstance(params)))); } // ------------------------------------------------------------------------- @@ -657,13 +669,13 @@ public void testBuilderChainingReturnsNewInstance() { @Test public void testSetStreamTrue() { InferenceParameters params = new InferenceParameters("").withStream(true); - assertEquals("true", params.parameters.get("stream")); + assertThat(params.parameters.get("stream"), is("true")); } @Test public void testSetStreamFalse() { InferenceParameters params = new InferenceParameters("").withStream(false); - assertEquals("false", params.parameters.get("stream")); + assertThat(params.parameters.get("stream"), is("false")); } // ------------------------------------------------------------------------- @@ -677,10 +689,10 @@ public void testSetTokenIdBiasMultiple() { bias.put(2, -1.0f); InferenceParameters params = new InferenceParameters("").withTokenIdBias(bias); String value = params.parameters.get("logit_bias"); - assertNotNull(value); - assertTrue(value.startsWith("[")); - assertTrue(value.endsWith("]")); - assertTrue(value.contains("1")); - assertTrue(value.contains("2")); + assertThat(value, is(notNullValue())); + assertThat(value, startsWith("[")); + assertThat(value, endsWith("]")); + assertThat(value, containsString("1")); + assertThat(value, containsString("2")); } } diff --git a/src/test/java/net/ladenthin/llama/JsonEndpointParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java similarity index 79% rename from src/test/java/net/ladenthin/llama/JsonEndpointParametersTest.java rename to src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java index 678b67a0..3972d36e 100644 --- a/src/test/java/net/ladenthin/llama/JsonEndpointParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java @@ -3,11 +3,18 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import java.io.File; +import net.ladenthin.llama.ClaudeGenerated; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.TestConstants; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; @@ -75,8 +82,8 @@ public void testDryMultiplierAccepted() { + ",\"dry_multiplier\":0.8,\"dry_base\":1.75,\"dry_allowed_length\":2" + ",\"dry_penalty_last_n\":-1}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\""), "Response should contain 'content' field"); + assertThat(result, is(notNullValue())); + assertThat("Response should contain 'content' field", result, containsString("\"content\"")); } @Test @@ -85,8 +92,8 @@ public void testDrySequenceBreakersAccepted() { + DETERMINISTIC + ",\"dry_multiplier\":0.5,\"dry_sequence_breakers\":[\"\\n\",\":\",\"*\"]}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } @Test @@ -96,8 +103,8 @@ public void testDryDisabledByDefault() { + DETERMINISTIC + ",\"dry_multiplier\":0.0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -110,8 +117,8 @@ public void testXtcParametersAccepted() { + DETERMINISTIC + ",\"xtc_probability\":0.5,\"xtc_threshold\":0.1}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } @Test @@ -120,8 +127,8 @@ public void testXtcDisabled() { + DETERMINISTIC + ",\"xtc_probability\":0.0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -133,8 +140,8 @@ public void testTopNSigmaAccepted() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"top_n_sigma\":2.0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } @Test @@ -142,8 +149,8 @@ public void testTopNSigmaDisabled() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"top_n_sigma\":-1.0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -156,9 +163,9 @@ public void testReturnTokensTrue() { + DETERMINISTIC + ",\"return_tokens\":true}"; String result = model.handleCompletions(json); - assertNotNull(result); + assertThat(result, is(notNullValue())); // When return_tokens is true, the response should include a "tokens" array - assertTrue(result.contains("\"tokens\""), "Response should contain 'tokens' field"); + assertThat("Response should contain 'tokens' field", result, containsString("\"tokens\"")); } @Test @@ -167,8 +174,8 @@ public void testReturnTokensFalse() { + DETERMINISTIC + ",\"return_tokens\":false}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -181,9 +188,9 @@ public void testResponseFieldsFiltering() { + DETERMINISTIC + ",\"response_fields\":[\"content\",\"stop\"]}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); - assertTrue(result.contains("\"stop\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); + assertThat(result, containsString("\"stop\"")); } // ------------------------------------------------------------------------- @@ -196,10 +203,10 @@ public void testTimingsPerTokenTrue() { + DETERMINISTIC + ",\"timings_per_token\":true}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); // timings_per_token enables per-token timing info - assertTrue(result.contains("\"timings\""), "Response should contain timings"); + assertThat("Response should contain timings", result, containsString("\"timings\"")); } // ------------------------------------------------------------------------- @@ -212,11 +219,12 @@ public void testPostSamplingProbsWithNProbs() { + DETERMINISTIC + ",\"n_probs\":3,\"post_sampling_probs\":true}"; String result = model.handleCompletions(json); - assertNotNull(result); + assertThat(result, is(notNullValue())); // post_sampling_probs changes the label from "logprob" to "prob" - assertTrue( - result.contains("\"completion_probabilities\"") || result.contains("\"prob\""), - "Response should contain completion_probabilities"); + assertThat( + "Response should contain completion_probabilities", + result, + anyOf(containsString("\"completion_probabilities\""), containsString("\"prob\""))); } // ------------------------------------------------------------------------- @@ -227,8 +235,8 @@ public void testPostSamplingProbsWithNProbs() { public void testNDiscardAccepted() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"n_discard\":0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -239,9 +247,9 @@ public void testNDiscardAccepted() { public void testIdSlotSelection() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"id_slot\":0}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); - assertTrue(result.contains("\"id_slot\""), "Response should contain 'id_slot' field"); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); + assertThat("Response should contain 'id_slot' field", result, containsString("\"id_slot\"")); } // ------------------------------------------------------------------------- @@ -254,8 +262,8 @@ public void testIgnoreEosAccepted() { String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"ignore_eos\":true}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -271,8 +279,8 @@ public void testCombinedAdvancedSampling() { + ",\"xtc_probability\":0.3,\"xtc_threshold\":0.1" + ",\"repeat_penalty\":1.1,\"frequency_penalty\":0.1,\"presence_penalty\":0.1}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -285,8 +293,8 @@ public void testCustomSamplerChainViaJson() { + DETERMINISTIC + ",\"samplers\":[\"top_k\",\"top_p\",\"temperature\"]}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } // ------------------------------------------------------------------------- @@ -300,7 +308,7 @@ public void testSpeculativeParamsAccepted() { + DETERMINISTIC + ",\"speculative\":{\"n_min\":0,\"n_max\":16,\"p_min\":0.75}}"; String result = model.handleCompletions(json); - assertNotNull(result); - assertTrue(result.contains("\"content\"")); + assertThat(result, is(notNullValue())); + assertThat(result, containsString("\"content\"")); } } diff --git a/src/test/java/net/ladenthin/llama/JsonParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/JsonParametersTest.java similarity index 97% rename from src/test/java/net/ladenthin/llama/JsonParametersTest.java rename to src/test/java/net/ladenthin/llama/parameters/JsonParametersTest.java index b5a0a15d..195ed47b 100644 --- a/src/test/java/net/ladenthin/llama/JsonParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/JsonParametersTest.java @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.parameters; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotSame; @@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Map; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.args.CacheType; import net.ladenthin.llama.args.CliArg; import org.junit.jupiter.api.Test; @@ -110,9 +111,7 @@ public void withScalar_booleanFalse_storesLowercaseFalse() { @Test public void withScalar_overwritesPreviousValue() { - TestBuilder b = new TestBuilder() - .withScalarPublic("--threads", 4) - .withScalarPublic("--threads", 16); + TestBuilder b = new TestBuilder().withScalarPublic("--threads", 4).withScalarPublic("--threads", 16); assertEquals("16", b.parameters.get("--threads")); assertEquals(1, b.parameters.size()); } diff --git a/src/test/java/net/ladenthin/llama/ModelParametersExtendedTest.java b/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java similarity index 64% rename from src/test/java/net/ladenthin/llama/ModelParametersExtendedTest.java rename to src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java index 1f4dc4f2..752c1031 100644 --- a/src/test/java/net/ladenthin/llama/ModelParametersExtendedTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java @@ -3,13 +3,29 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.*; +package net.ladenthin.llama.parameters; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.hamcrest.Matchers.startsWith; import java.util.HashMap; import java.util.Map; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.args.*; +import net.ladenthin.llama.args.CacheType; +import net.ladenthin.llama.args.GpuSplitMode; +import net.ladenthin.llama.args.MiroStat; +import net.ladenthin.llama.args.NumaStrategy; +import net.ladenthin.llama.args.RopeScalingType; import org.junit.jupiter.api.Test; /** @@ -31,55 +47,55 @@ public class ModelParametersExtendedTest { @Test public void testSetCtxSize() { ModelParameters p = new ModelParameters().setCtxSize(2048); - assertEquals("2048", p.parameters.get("--ctx-size")); + assertThat(p.parameters.get("--ctx-size"), is("2048")); } @Test public void testSetCtxSizeZeroUsesModelDefault() { ModelParameters p = new ModelParameters().setCtxSize(0); - assertEquals("0", p.parameters.get("--ctx-size")); + assertThat(p.parameters.get("--ctx-size"), is("0")); } @Test public void testSetBatchSize() { ModelParameters p = new ModelParameters().setBatchSize(512); - assertEquals("512", p.parameters.get("--batch-size")); + assertThat(p.parameters.get("--batch-size"), is("512")); } @Test public void testSetUbatchSize() { ModelParameters p = new ModelParameters().setUbatchSize(256); - assertEquals("256", p.parameters.get("--ubatch-size")); + assertThat(p.parameters.get("--ubatch-size"), is("256")); } @Test public void testSetPredict() { ModelParameters p = new ModelParameters().setPredict(100); - assertEquals("100", p.parameters.get("--predict")); + assertThat(p.parameters.get("--predict"), is("100")); } @Test public void testSetPredictInfinity() { ModelParameters p = new ModelParameters().setPredict(-1); - assertEquals("-1", p.parameters.get("--predict")); + assertThat(p.parameters.get("--predict"), is("-1")); } @Test public void testSetPredictFillContext() { ModelParameters p = new ModelParameters().setPredict(-2); - assertEquals("-2", p.parameters.get("--predict")); + assertThat(p.parameters.get("--predict"), is("-2")); } @Test public void testSetKeep() { ModelParameters p = new ModelParameters().setKeep(64); - assertEquals("64", p.parameters.get("--keep")); + assertThat(p.parameters.get("--keep"), is("64")); } @Test public void testSetKeepAll() { ModelParameters p = new ModelParameters().setKeep(-1); - assertEquals("-1", p.parameters.get("--keep")); + assertThat(p.parameters.get("--keep"), is("-1")); } // ------------------------------------------------------------------------- @@ -89,13 +105,13 @@ public void testSetKeepAll() { @Test public void testSetThreads() { ModelParameters p = new ModelParameters().setThreads(8); - assertEquals("8", p.parameters.get("--threads")); + assertThat(p.parameters.get("--threads"), is("8")); } @Test public void testSetThreadsBatch() { ModelParameters p = new ModelParameters().setThreadsBatch(4); - assertEquals("4", p.parameters.get("--threads-batch")); + assertThat(p.parameters.get("--threads-batch"), is("4")); } // ------------------------------------------------------------------------- @@ -105,49 +121,49 @@ public void testSetThreadsBatch() { @Test public void testSetCpuMask() { ModelParameters p = new ModelParameters().setCpuMask("ff"); - assertEquals("ff", p.parameters.get("--cpu-mask")); + assertThat(p.parameters.get("--cpu-mask"), is("ff")); } @Test public void testSetCpuRange() { ModelParameters p = new ModelParameters().setCpuRange("0-3"); - assertEquals("0-3", p.parameters.get("--cpu-range")); + assertThat(p.parameters.get("--cpu-range"), is("0-3")); } @Test public void testSetCpuStrict() { ModelParameters p = new ModelParameters().setCpuStrict(1); - assertEquals("1", p.parameters.get("--cpu-strict")); + assertThat(p.parameters.get("--cpu-strict"), is("1")); } @Test public void testSetPoll() { ModelParameters p = new ModelParameters().setPoll(50); - assertEquals("50", p.parameters.get("--poll")); + assertThat(p.parameters.get("--poll"), is("50")); } @Test public void testSetCpuMaskBatch() { ModelParameters p = new ModelParameters().setCpuMaskBatch("0f"); - assertEquals("0f", p.parameters.get("--cpu-mask-batch")); + assertThat(p.parameters.get("--cpu-mask-batch"), is("0f")); } @Test public void testSetCpuRangeBatch() { ModelParameters p = new ModelParameters().setCpuRangeBatch("4-7"); - assertEquals("4-7", p.parameters.get("--cpu-range-batch")); + assertThat(p.parameters.get("--cpu-range-batch"), is("4-7")); } @Test public void testSetCpuStrictBatch() { ModelParameters p = new ModelParameters().setCpuStrictBatch(0); - assertEquals("0", p.parameters.get("--cpu-strict-batch")); + assertThat(p.parameters.get("--cpu-strict-batch"), is("0")); } @Test public void testSetPollBatch() { ModelParameters p = new ModelParameters().setPollBatch(100); - assertEquals("100", p.parameters.get("--poll-batch")); + assertThat(p.parameters.get("--poll-batch"), is("100")); } // ------------------------------------------------------------------------- @@ -157,79 +173,79 @@ public void testSetPollBatch() { @Test public void testSetTemp() { ModelParameters p = new ModelParameters().setTemp(0.7f); - assertEquals("0.7", p.parameters.get("--temp")); + assertThat(p.parameters.get("--temp"), is("0.7")); } @Test public void testSetTopK() { ModelParameters p = new ModelParameters().setTopK(50); - assertEquals("50", p.parameters.get("--top-k")); + assertThat(p.parameters.get("--top-k"), is("50")); } @Test public void testSetTopKDisabled() { ModelParameters p = new ModelParameters().setTopK(0); - assertEquals("0", p.parameters.get("--top-k")); + assertThat(p.parameters.get("--top-k"), is("0")); } @Test public void testSetTopP() { ModelParameters p = new ModelParameters().setTopP(0.9f); - assertEquals("0.9", p.parameters.get("--top-p")); + assertThat(p.parameters.get("--top-p"), is("0.9")); } @Test public void testSetMinP() { ModelParameters p = new ModelParameters().setMinP(0.1f); - assertEquals("0.1", p.parameters.get("--min-p")); + assertThat(p.parameters.get("--min-p"), is("0.1")); } @Test public void testSetTypical() { ModelParameters p = new ModelParameters().setTypical(0.95f); - assertEquals("0.95", p.parameters.get("--typical")); + assertThat(p.parameters.get("--typical"), is("0.95")); } @Test public void testSetRepeatPenalty() { ModelParameters p = new ModelParameters().setRepeatPenalty(1.1f); - assertEquals("1.1", p.parameters.get("--repeat-penalty")); + assertThat(p.parameters.get("--repeat-penalty"), is("1.1")); } @Test public void testSetPresencePenalty() { ModelParameters p = new ModelParameters().setPresencePenalty(0.5f); - assertEquals("0.5", p.parameters.get("--presence-penalty")); + assertThat(p.parameters.get("--presence-penalty"), is("0.5")); } @Test public void testSetFrequencyPenalty() { ModelParameters p = new ModelParameters().setFrequencyPenalty(0.3f); - assertEquals("0.3", p.parameters.get("--frequency-penalty")); + assertThat(p.parameters.get("--frequency-penalty"), is("0.3")); } @Test public void testSetMirostatLR() { ModelParameters p = new ModelParameters().setMirostatLR(0.2f); - assertEquals("0.2", p.parameters.get("--mirostat-lr")); + assertThat(p.parameters.get("--mirostat-lr"), is("0.2")); } @Test public void testSetMirostatEnt() { ModelParameters p = new ModelParameters().setMirostatEnt(4.0f); - assertEquals("4.0", p.parameters.get("--mirostat-ent")); + assertThat(p.parameters.get("--mirostat-ent"), is("4.0")); } @Test public void testSetDynatempRange() { ModelParameters p = new ModelParameters().setDynatempRange(0.5f); - assertEquals("0.5", p.parameters.get("--dynatemp-range")); + assertThat(p.parameters.get("--dynatemp-range"), is("0.5")); } @Test public void testSetDynatempExponent() { ModelParameters p = new ModelParameters().setDynatempExponent(2.0f); - assertEquals("2.0", p.parameters.get("--dynatemp-exp")); + assertThat(p.parameters.get("--dynatemp-exp"), is("2.0")); } // ------------------------------------------------------------------------- @@ -239,25 +255,25 @@ public void testSetDynatempExponent() { @Test public void testSetXtcProbability() { ModelParameters p = new ModelParameters().setXtcProbability(0.5f); - assertEquals("0.5", p.parameters.get("--xtc-probability")); + assertThat(p.parameters.get("--xtc-probability"), is("0.5")); } @Test public void testSetXtcProbabilityDisabled() { ModelParameters p = new ModelParameters().setXtcProbability(0.0f); - assertEquals("0.0", p.parameters.get("--xtc-probability")); + assertThat(p.parameters.get("--xtc-probability"), is("0.0")); } @Test public void testSetXtcThreshold() { ModelParameters p = new ModelParameters().setXtcThreshold(0.2f); - assertEquals("0.2", p.parameters.get("--xtc-threshold")); + assertThat(p.parameters.get("--xtc-threshold"), is("0.2")); } @Test public void testSetXtcThresholdDisabled() { ModelParameters p = new ModelParameters().setXtcThreshold(1.0f); - assertEquals("1.0", p.parameters.get("--xtc-threshold")); + assertThat(p.parameters.get("--xtc-threshold"), is("1.0")); } // ------------------------------------------------------------------------- @@ -267,31 +283,31 @@ public void testSetXtcThresholdDisabled() { @Test public void testSetDryMultiplier() { ModelParameters p = new ModelParameters().setDryMultiplier(0.8f); - assertEquals("0.8", p.parameters.get("--dry-multiplier")); + assertThat(p.parameters.get("--dry-multiplier"), is("0.8")); } @Test public void testSetDryMultiplierDisabled() { ModelParameters p = new ModelParameters().setDryMultiplier(0.0f); - assertEquals("0.0", p.parameters.get("--dry-multiplier")); + assertThat(p.parameters.get("--dry-multiplier"), is("0.0")); } @Test public void testSetDryBase() { ModelParameters p = new ModelParameters().setDryBase(2.0f); - assertEquals("2.0", p.parameters.get("--dry-base")); + assertThat(p.parameters.get("--dry-base"), is("2.0")); } @Test public void testSetDryAllowedLength() { ModelParameters p = new ModelParameters().setDryAllowedLength(3); - assertEquals("3", p.parameters.get("--dry-allowed-length")); + assertThat(p.parameters.get("--dry-allowed-length"), is("3")); } @Test public void testSetDrySequenceBreaker() { ModelParameters p = new ModelParameters().setDrySequenceBreaker("\\n"); - assertEquals("\\n", p.parameters.get("--dry-sequence-breaker")); + assertThat(p.parameters.get("--dry-sequence-breaker"), is("\\n")); } // ------------------------------------------------------------------------- @@ -301,19 +317,19 @@ public void testSetDrySequenceBreaker() { @Test public void testSetRopeScale() { ModelParameters p = new ModelParameters().setRopeScale(2.0f); - assertEquals("2.0", p.parameters.get("--rope-scale")); + assertThat(p.parameters.get("--rope-scale"), is("2.0")); } @Test public void testSetRopeFreqBase() { ModelParameters p = new ModelParameters().setRopeFreqBase(10000.0f); - assertEquals("10000.0", p.parameters.get("--rope-freq-base")); + assertThat(p.parameters.get("--rope-freq-base"), is("10000.0")); } @Test public void testSetRopeFreqScale() { ModelParameters p = new ModelParameters().setRopeFreqScale(0.5f); - assertEquals("0.5", p.parameters.get("--rope-freq-scale")); + assertThat(p.parameters.get("--rope-freq-scale"), is("0.5")); } // ------------------------------------------------------------------------- @@ -323,31 +339,31 @@ public void testSetRopeFreqScale() { @Test public void testSetYarnOrigCtx() { ModelParameters p = new ModelParameters().setYarnOrigCtx(4096); - assertEquals("4096", p.parameters.get("--yarn-orig-ctx")); + assertThat(p.parameters.get("--yarn-orig-ctx"), is("4096")); } @Test public void testSetYarnExtFactor() { ModelParameters p = new ModelParameters().setYarnExtFactor(0.5f); - assertEquals("0.5", p.parameters.get("--yarn-ext-factor")); + assertThat(p.parameters.get("--yarn-ext-factor"), is("0.5")); } @Test public void testSetYarnAttnFactor() { ModelParameters p = new ModelParameters().setYarnAttnFactor(1.5f); - assertEquals("1.5", p.parameters.get("--yarn-attn-factor")); + assertThat(p.parameters.get("--yarn-attn-factor"), is("1.5")); } @Test public void testSetYarnBetaSlow() { ModelParameters p = new ModelParameters().setYarnBetaSlow(2.0f); - assertEquals("2.0", p.parameters.get("--yarn-beta-slow")); + assertThat(p.parameters.get("--yarn-beta-slow"), is("2.0")); } @Test public void testSetYarnBetaFast() { ModelParameters p = new ModelParameters().setYarnBetaFast(16.0f); - assertEquals("16.0", p.parameters.get("--yarn-beta-fast")); + assertThat(p.parameters.get("--yarn-beta-fast"), is("16.0")); } // ------------------------------------------------------------------------- @@ -357,13 +373,13 @@ public void testSetYarnBetaFast() { @Test public void testSetGrpAttnN() { ModelParameters p = new ModelParameters().setGrpAttnN(4); - assertEquals("4", p.parameters.get("--grp-attn-n")); + assertThat(p.parameters.get("--grp-attn-n"), is("4")); } @Test public void testSetGrpAttnW() { ModelParameters p = new ModelParameters().setGrpAttnW(1024); - assertEquals("1024", p.parameters.get("--grp-attn-w")); + assertThat(p.parameters.get("--grp-attn-w"), is("1024")); } // ------------------------------------------------------------------------- @@ -374,7 +390,7 @@ public void testSetGrpAttnW() { public void testSetCacheTypeKAllValues() { for (CacheType ct : CacheType.values()) { ModelParameters p = new ModelParameters().setCacheTypeK(ct); - assertEquals(ct.name().toLowerCase(), p.parameters.get("--cache-type-k")); + assertThat(p.parameters.get("--cache-type-k"), is(ct.name().toLowerCase())); } } @@ -382,112 +398,112 @@ public void testSetCacheTypeKAllValues() { public void testSetCacheTypeVAllValues() { for (CacheType ct : CacheType.values()) { ModelParameters p = new ModelParameters().setCacheTypeV(ct); - assertEquals(ct.name().toLowerCase(), p.parameters.get("--cache-type-v")); + assertThat(p.parameters.get("--cache-type-v"), is(ct.name().toLowerCase())); } } @Test public void testSetDefragThold() { ModelParameters p = new ModelParameters().setDefragThold(0.2f); - assertEquals("0.2", p.parameters.get("--defrag-thold")); + assertThat(p.parameters.get("--defrag-thold"), is("0.2")); } @Test public void testSetDefragTholdDisabled() { ModelParameters p = new ModelParameters().setDefragThold(-1.0f); - assertEquals("-1.0", p.parameters.get("--defrag-thold")); + assertThat(p.parameters.get("--defrag-thold"), is("-1.0")); } @Test public void testDisableKvOffload() { ModelParameters p = new ModelParameters().disableKvOffload(); - assertTrue(p.parameters.containsKey("--no-kv-offload")); - assertNull(p.parameters.get("--no-kv-offload")); + assertThat(p.parameters, hasKey("--no-kv-offload")); + assertThat(p.parameters.get("--no-kv-offload"), is(nullValue())); } @Test public void testEnableDumpKvCache() { ModelParameters p = new ModelParameters().enableDumpKvCache(); - assertTrue(p.parameters.containsKey("--dump-kv-cache")); - assertNull(p.parameters.get("--dump-kv-cache")); + assertThat(p.parameters, hasKey("--dump-kv-cache")); + assertThat(p.parameters.get("--dump-kv-cache"), is(nullValue())); } @Test public void testSetKvUnifiedTrue() { ModelParameters p = new ModelParameters().setKvUnified(true); - assertTrue(p.parameters.containsKey("--kv-unified")); - assertNull(p.parameters.get("--kv-unified")); - assertFalse(p.parameters.containsKey("--no-kv-unified")); + assertThat(p.parameters, hasKey("--kv-unified")); + assertThat(p.parameters.get("--kv-unified"), is(nullValue())); + assertThat(p.parameters, not(hasKey("--no-kv-unified"))); } @Test public void testSetKvUnifiedFalse() { ModelParameters p = new ModelParameters().setKvUnified(false); - assertTrue(p.parameters.containsKey("--no-kv-unified")); - assertNull(p.parameters.get("--no-kv-unified")); - assertFalse(p.parameters.containsKey("--kv-unified")); + assertThat(p.parameters, hasKey("--no-kv-unified")); + assertThat(p.parameters.get("--no-kv-unified"), is(nullValue())); + assertThat(p.parameters, not(hasKey("--kv-unified"))); } @Test public void testSetKvUnifiedFlipFromTrueToFalse() { ModelParameters p = new ModelParameters().setKvUnified(true).setKvUnified(false); - assertTrue(p.parameters.containsKey("--no-kv-unified")); - assertFalse(p.parameters.containsKey("--kv-unified")); + assertThat(p.parameters, hasKey("--no-kv-unified")); + assertThat(p.parameters, not(hasKey("--kv-unified"))); } @Test public void testSetKvUnifiedFlipFromFalseToTrue() { ModelParameters p = new ModelParameters().setKvUnified(false).setKvUnified(true); - assertTrue(p.parameters.containsKey("--kv-unified")); - assertFalse(p.parameters.containsKey("--no-kv-unified")); + assertThat(p.parameters, hasKey("--kv-unified")); + assertThat(p.parameters, not(hasKey("--no-kv-unified"))); } @Test public void testSetCacheRamMib() { ModelParameters p = new ModelParameters().setCacheRamMib(4096); - assertEquals("4096", p.parameters.get("--cache-ram")); + assertThat(p.parameters.get("--cache-ram"), is("4096")); } @Test public void testSetCacheRamMibUnlimited() { ModelParameters p = new ModelParameters().setCacheRamMib(-1); - assertEquals("-1", p.parameters.get("--cache-ram")); + assertThat(p.parameters.get("--cache-ram"), is("-1")); } @Test public void testSetCacheRamMibDisabled() { ModelParameters p = new ModelParameters().setCacheRamMib(0); - assertEquals("0", p.parameters.get("--cache-ram")); + assertThat(p.parameters.get("--cache-ram"), is("0")); } @Test public void testSetClearIdleTrue() { ModelParameters p = new ModelParameters().setClearIdle(true); - assertTrue(p.parameters.containsKey("--cache-idle-slots")); - assertNull(p.parameters.get("--cache-idle-slots")); - assertFalse(p.parameters.containsKey("--no-cache-idle-slots")); + assertThat(p.parameters, hasKey("--cache-idle-slots")); + assertThat(p.parameters.get("--cache-idle-slots"), is(nullValue())); + assertThat(p.parameters, not(hasKey("--no-cache-idle-slots"))); } @Test public void testSetClearIdleFalse() { ModelParameters p = new ModelParameters().setClearIdle(false); - assertTrue(p.parameters.containsKey("--no-cache-idle-slots")); - assertNull(p.parameters.get("--no-cache-idle-slots")); - assertFalse(p.parameters.containsKey("--cache-idle-slots")); + assertThat(p.parameters, hasKey("--no-cache-idle-slots")); + assertThat(p.parameters.get("--no-cache-idle-slots"), is(nullValue())); + assertThat(p.parameters, not(hasKey("--cache-idle-slots"))); } @Test public void testSetClearIdleFlipFromTrueToFalse() { ModelParameters p = new ModelParameters().setClearIdle(true).setClearIdle(false); - assertTrue(p.parameters.containsKey("--no-cache-idle-slots")); - assertFalse(p.parameters.containsKey("--cache-idle-slots")); + assertThat(p.parameters, hasKey("--no-cache-idle-slots")); + assertThat(p.parameters, not(hasKey("--cache-idle-slots"))); } @Test public void testSetClearIdleFlipFromFalseToTrue() { ModelParameters p = new ModelParameters().setClearIdle(false).setClearIdle(true); - assertTrue(p.parameters.containsKey("--cache-idle-slots")); - assertFalse(p.parameters.containsKey("--no-cache-idle-slots")); + assertThat(p.parameters, hasKey("--cache-idle-slots")); + assertThat(p.parameters, not(hasKey("--no-cache-idle-slots"))); } @Test @@ -495,30 +511,30 @@ public void testKvUnifiedCacheRamClearIdleChaining() { // All three features wired together as they would be in production use ModelParameters p = new ModelParameters().setKvUnified(true).setCacheRamMib(8192).setClearIdle(true); - assertTrue(p.parameters.containsKey("--kv-unified")); - assertEquals("8192", p.parameters.get("--cache-ram")); - assertTrue(p.parameters.containsKey("--cache-idle-slots")); + assertThat(p.parameters, hasKey("--kv-unified")); + assertThat(p.parameters.get("--cache-ram"), is("8192")); + assertThat(p.parameters, hasKey("--cache-idle-slots")); // Opposite flags must be absent - assertFalse(p.parameters.containsKey("--no-kv-unified")); - assertFalse(p.parameters.containsKey("--no-cache-idle-slots")); + assertThat(p.parameters, not(hasKey("--no-kv-unified"))); + assertThat(p.parameters, not(hasKey("--no-cache-idle-slots"))); } @Test public void testSetKvUnifiedReturnsSameInstance() { ModelParameters p = new ModelParameters(); - assertSame(p.setKvUnified(true), p); + assertThat(p.setKvUnified(true), is(sameInstance(p))); } @Test public void testSetCacheRamMibReturnsSameInstance() { ModelParameters p = new ModelParameters(); - assertSame(p.setCacheRamMib(4096), p); + assertThat(p.setCacheRamMib(4096), is(sameInstance(p))); } @Test public void testSetClearIdleReturnsSameInstance() { ModelParameters p = new ModelParameters(); - assertSame(p.setClearIdle(true), p); + assertThat(p.setClearIdle(true), is(sameInstance(p))); } // ------------------------------------------------------------------------- @@ -528,33 +544,33 @@ public void testSetClearIdleReturnsSameInstance() { @Test public void testSetGpuLayers() { ModelParameters p = new ModelParameters().setGpuLayers(32); - assertEquals("32", p.parameters.get("--gpu-layers")); + assertThat(p.parameters.get("--gpu-layers"), is("32")); } @Test public void testSetSplitModeAllValues() { for (GpuSplitMode mode : GpuSplitMode.values()) { ModelParameters p = new ModelParameters().setSplitMode(mode); - assertEquals(mode.name().toLowerCase(), p.parameters.get("--split-mode")); + assertThat(p.parameters.get("--split-mode"), is(mode.name().toLowerCase())); } } @Test public void testSetTensorSplit() { ModelParameters p = new ModelParameters().setTensorSplit("0.5,0.5"); - assertEquals("0.5,0.5", p.parameters.get("--tensor-split")); + assertThat(p.parameters.get("--tensor-split"), is("0.5,0.5")); } @Test public void testSetMainGpu() { ModelParameters p = new ModelParameters().setMainGpu(1); - assertEquals("1", p.parameters.get("--main-gpu")); + assertThat(p.parameters.get("--main-gpu"), is("1")); } @Test public void testSetDevices() { ModelParameters p = new ModelParameters().setDevices("cuda0,cuda1"); - assertEquals("cuda0,cuda1", p.parameters.get("--device")); + assertThat(p.parameters.get("--device"), is("cuda0,cuda1")); } // ------------------------------------------------------------------------- @@ -564,22 +580,22 @@ public void testSetDevices() { @Test public void testEnableMlock() { ModelParameters p = new ModelParameters().enableMlock(); - assertTrue(p.parameters.containsKey("--mlock")); - assertNull(p.parameters.get("--mlock")); + assertThat(p.parameters, hasKey("--mlock")); + assertThat(p.parameters.get("--mlock"), is(nullValue())); } @Test public void testDisableMmap() { ModelParameters p = new ModelParameters().disableMmap(); - assertTrue(p.parameters.containsKey("--no-mmap")); - assertNull(p.parameters.get("--no-mmap")); + assertThat(p.parameters, hasKey("--no-mmap")); + assertThat(p.parameters.get("--no-mmap"), is(nullValue())); } @Test public void testSetNumaAllValues() { for (NumaStrategy ns : NumaStrategy.values()) { ModelParameters p = new ModelParameters().setNuma(ns); - assertEquals(ns.name().toLowerCase(), p.parameters.get("--numa")); + assertThat(p.parameters.get("--numa"), is(ns.name().toLowerCase())); } } @@ -590,21 +606,21 @@ public void testSetNumaAllValues() { @Test public void testSetParallel() { ModelParameters p = new ModelParameters().setParallel(4); - assertEquals("4", p.parameters.get("--parallel")); + assertThat(p.parameters.get("--parallel"), is("4")); } @Test public void testEnableContBatching() { ModelParameters p = new ModelParameters().enableContBatching(); - assertTrue(p.parameters.containsKey("--cont-batching")); - assertNull(p.parameters.get("--cont-batching")); + assertThat(p.parameters, hasKey("--cont-batching")); + assertThat(p.parameters.get("--cont-batching"), is(nullValue())); } @Test public void testDisableContBatching() { ModelParameters p = new ModelParameters().disableContBatching(); - assertTrue(p.parameters.containsKey("--no-cont-batching")); - assertNull(p.parameters.get("--no-cont-batching")); + assertThat(p.parameters, hasKey("--no-cont-batching")); + assertThat(p.parameters.get("--no-cont-batching"), is(nullValue())); } // ------------------------------------------------------------------------- @@ -614,106 +630,106 @@ public void testDisableContBatching() { @Test public void testDisableContextShift() { ModelParameters p = new ModelParameters().disableContextShift(); - assertTrue(p.parameters.containsKey("--no-context-shift")); - assertNull(p.parameters.get("--no-context-shift")); + assertThat(p.parameters, hasKey("--no-context-shift")); + assertThat(p.parameters.get("--no-context-shift"), is(nullValue())); } @Test public void testEnableFlashAttn() { ModelParameters p = new ModelParameters().enableFlashAttn(); - assertTrue(p.parameters.containsKey("--flash-attn")); - assertNull(p.parameters.get("--flash-attn")); + assertThat(p.parameters, hasKey("--flash-attn")); + assertThat(p.parameters.get("--flash-attn"), is(nullValue())); } @Test public void testDisablePerf() { ModelParameters p = new ModelParameters().disablePerf(); - assertTrue(p.parameters.containsKey("--no-perf")); - assertNull(p.parameters.get("--no-perf")); + assertThat(p.parameters, hasKey("--no-perf")); + assertThat(p.parameters.get("--no-perf"), is(nullValue())); } @Test public void testEnableEscape() { ModelParameters p = new ModelParameters().enableEscape(); - assertTrue(p.parameters.containsKey("--escape")); - assertNull(p.parameters.get("--escape")); + assertThat(p.parameters, hasKey("--escape")); + assertThat(p.parameters.get("--escape"), is(nullValue())); } @Test public void testDisableEscape() { ModelParameters p = new ModelParameters().disableEscape(); - assertTrue(p.parameters.containsKey("--no-escape")); - assertNull(p.parameters.get("--no-escape")); + assertThat(p.parameters, hasKey("--no-escape")); + assertThat(p.parameters.get("--no-escape"), is(nullValue())); } @Test public void testEnableSpecial() { ModelParameters p = new ModelParameters().enableSpecial(); - assertTrue(p.parameters.containsKey("--special")); - assertNull(p.parameters.get("--special")); + assertThat(p.parameters, hasKey("--special")); + assertThat(p.parameters.get("--special"), is(nullValue())); } @Test public void testSkipWarmup() { ModelParameters p = new ModelParameters().skipWarmup(); - assertTrue(p.parameters.containsKey("--no-warmup")); - assertNull(p.parameters.get("--no-warmup")); + assertThat(p.parameters, hasKey("--no-warmup")); + assertThat(p.parameters.get("--no-warmup"), is(nullValue())); } @Test public void testSetSpmInfill() { ModelParameters p = new ModelParameters().setSpmInfill(); - assertTrue(p.parameters.containsKey("--spm-infill")); - assertNull(p.parameters.get("--spm-infill")); + assertThat(p.parameters, hasKey("--spm-infill")); + assertThat(p.parameters.get("--spm-infill"), is(nullValue())); } @Test public void testIgnoreEos() { ModelParameters p = new ModelParameters().ignoreEos(); - assertTrue(p.parameters.containsKey("--ignore-eos")); - assertNull(p.parameters.get("--ignore-eos")); + assertThat(p.parameters, hasKey("--ignore-eos")); + assertThat(p.parameters.get("--ignore-eos"), is(nullValue())); } @Test public void testEnableCheckTensors() { ModelParameters p = new ModelParameters().enableCheckTensors(); - assertTrue(p.parameters.containsKey("--check-tensors")); - assertNull(p.parameters.get("--check-tensors")); + assertThat(p.parameters, hasKey("--check-tensors")); + assertThat(p.parameters.get("--check-tensors"), is(nullValue())); } @Test public void testEnableEmbedding() { ModelParameters p = new ModelParameters().enableEmbedding(); - assertTrue(p.parameters.containsKey("--embedding")); - assertNull(p.parameters.get("--embedding")); + assertThat(p.parameters, hasKey("--embedding")); + assertThat(p.parameters.get("--embedding"), is(nullValue())); } @Test public void testEnableReranking() { ModelParameters p = new ModelParameters().enableReranking(); - assertTrue(p.parameters.containsKey("--reranking")); - assertNull(p.parameters.get("--reranking")); + assertThat(p.parameters, hasKey("--reranking")); + assertThat(p.parameters.get("--reranking"), is(nullValue())); } @Test public void testSetVocabOnly() { ModelParameters p = new ModelParameters().setVocabOnly(); - assertTrue(p.parameters.containsKey("--vocab-only")); - assertNull(p.parameters.get("--vocab-only")); + assertThat(p.parameters, hasKey("--vocab-only")); + assertThat(p.parameters.get("--vocab-only"), is(nullValue())); } @Test public void testEnableJinja() { ModelParameters p = new ModelParameters().enableJinja(); - assertTrue(p.parameters.containsKey("--jinja")); - assertNull(p.parameters.get("--jinja")); + assertThat(p.parameters, hasKey("--jinja")); + assertThat(p.parameters.get("--jinja"), is(nullValue())); } @Test public void testSetLoraInitWithoutApply() { ModelParameters p = new ModelParameters().setLoraInitWithoutApply(); - assertTrue(p.parameters.containsKey("--lora-init-without-apply")); - assertNull(p.parameters.get("--lora-init-without-apply")); + assertThat(p.parameters, hasKey("--lora-init-without-apply")); + assertThat(p.parameters.get("--lora-init-without-apply"), is(nullValue())); } // ------------------------------------------------------------------------- @@ -723,19 +739,19 @@ public void testSetLoraInitWithoutApply() { @Test public void testSetSeed() { ModelParameters p = new ModelParameters().setSeed(42); - assertEquals("42", p.parameters.get("--seed")); + assertThat(p.parameters.get("--seed"), is("42")); } @Test public void testSetSeedRandom() { ModelParameters p = new ModelParameters().setSeed(-1); - assertEquals("-1", p.parameters.get("--seed")); + assertThat(p.parameters.get("--seed"), is("-1")); } @Test public void testSetLogitBias() { ModelParameters p = new ModelParameters().setLogitBias("1+0.5"); - assertEquals("1+0.5", p.parameters.get("--logit-bias")); + assertThat(p.parameters.get("--logit-bias"), is("1+0.5")); } // ------------------------------------------------------------------------- @@ -745,19 +761,19 @@ public void testSetLogitBias() { @Test public void testSetGrammar() { ModelParameters p = new ModelParameters().setGrammar("root ::= \"hello\""); - assertEquals("root ::= \"hello\"", p.parameters.get("--grammar")); + assertThat(p.parameters.get("--grammar"), is("root ::= \"hello\"")); } @Test public void testSetGrammarFile() { ModelParameters p = new ModelParameters().setGrammarFile("grammar.gbnf"); - assertEquals("grammar.gbnf", p.parameters.get("--grammar-file")); + assertThat(p.parameters.get("--grammar-file"), is("grammar.gbnf")); } @Test public void testSetJsonSchema() { ModelParameters p = new ModelParameters().setJsonSchema("{\"type\":\"object\"}"); - assertEquals("{\"type\":\"object\"}", p.parameters.get("--json-schema")); + assertThat(p.parameters.get("--json-schema"), is("{\"type\":\"object\"}")); } // ------------------------------------------------------------------------- @@ -768,7 +784,7 @@ public void testSetJsonSchema() { public void testSetChatTemplate() { ModelParameters p = new ModelParameters().setChatTemplate("{% for msg in messages %}{{ msg.content }}{% endfor %}"); - assertEquals("{% for msg in messages %}{{ msg.content }}{% endfor %}", p.parameters.get("--chat-template")); + assertThat(p.parameters.get("--chat-template"), is("{% for msg in messages %}{{ msg.content }}{% endfor %}")); } @Test @@ -777,8 +793,8 @@ public void testSetChatTemplateKwargs() { kwargs.put("enable_thinking", "true"); ModelParameters p = new ModelParameters().setChatTemplateKwargs(kwargs); String val = p.parameters.get("--chat-template-kwargs"); - assertNotNull(val); - assertTrue(val.contains("\"enable_thinking\":true")); + assertThat(val, is(notNullValue())); + assertThat(val, containsString("\"enable_thinking\":true")); } @Test @@ -788,11 +804,11 @@ public void testSetChatTemplateKwargsMultiple() { kwargs.put("key2", "42"); ModelParameters p = new ModelParameters().setChatTemplateKwargs(kwargs); String val = p.parameters.get("--chat-template-kwargs"); - assertNotNull(val); - assertTrue(val.startsWith("{")); - assertTrue(val.endsWith("}")); - assertTrue(val.contains("\"key1\":\"val1\"")); - assertTrue(val.contains("\"key2\":42")); + assertThat(val, is(notNullValue())); + assertThat(val, startsWith("{")); + assertThat(val, endsWith("}")); + assertThat(val, containsString("\"key1\":\"val1\"")); + assertThat(val, containsString("\"key2\":42")); } // ------------------------------------------------------------------------- @@ -802,43 +818,43 @@ public void testSetChatTemplateKwargsMultiple() { @Test public void testSetModel() { ModelParameters p = new ModelParameters().setModel("/path/to/model.gguf"); - assertEquals("/path/to/model.gguf", p.parameters.get("--model")); + assertThat(p.parameters.get("--model"), is("/path/to/model.gguf")); } @Test public void testSetModelUrl() { ModelParameters p = new ModelParameters().setModelUrl("https://example.com/model.gguf"); - assertEquals("https://example.com/model.gguf", p.parameters.get("--model-url")); + assertThat(p.parameters.get("--model-url"), is("https://example.com/model.gguf")); } @Test public void testSetHfRepo() { ModelParameters p = new ModelParameters().setHfRepo("meta-llama/Llama-2-7b"); - assertEquals("meta-llama/Llama-2-7b", p.parameters.get("--hf-repo")); + assertThat(p.parameters.get("--hf-repo"), is("meta-llama/Llama-2-7b")); } @Test public void testSetHfFile() { ModelParameters p = new ModelParameters().setHfFile("model-q4.gguf"); - assertEquals("model-q4.gguf", p.parameters.get("--hf-file")); + assertThat(p.parameters.get("--hf-file"), is("model-q4.gguf")); } @Test public void testSetHfToken() { ModelParameters p = new ModelParameters().setHfToken("hf_abc123"); - assertEquals("hf_abc123", p.parameters.get("--hf-token")); + assertThat(p.parameters.get("--hf-token"), is("hf_abc123")); } @Test public void testSetHfRepoV() { ModelParameters p = new ModelParameters().setHfRepoV("org/vocoder"); - assertEquals("org/vocoder", p.parameters.get("--hf-repo-v")); + assertThat(p.parameters.get("--hf-repo-v"), is("org/vocoder")); } @Test public void testSetHfFileV() { ModelParameters p = new ModelParameters().setHfFileV("vocoder.gguf"); - assertEquals("vocoder.gguf", p.parameters.get("--hf-file-v")); + assertThat(p.parameters.get("--hf-file-v"), is("vocoder.gguf")); } // ------------------------------------------------------------------------- @@ -848,19 +864,19 @@ public void testSetHfFileV() { @Test public void testSetCacheReuse() { ModelParameters p = new ModelParameters().setCacheReuse(128); - assertEquals("128", p.parameters.get("--cache-reuse")); + assertThat(p.parameters.get("--cache-reuse"), is("128")); } @Test public void testSetSlotSavePath() { ModelParameters p = new ModelParameters().setSlotSavePath("/tmp/slots"); - assertEquals("/tmp/slots", p.parameters.get("--slot-save-path")); + assertThat(p.parameters.get("--slot-save-path"), is("/tmp/slots")); } @Test public void testSetSlotPromptSimilarity() { ModelParameters p = new ModelParameters().setSlotPromptSimilarity(0.8f); - assertEquals("0.8", p.parameters.get("--slot-prompt-similarity")); + assertThat(p.parameters.get("--slot-prompt-similarity"), is("0.8")); } // ------------------------------------------------------------------------- @@ -870,7 +886,7 @@ public void testSetSlotPromptSimilarity() { @Test public void testSetOverrideKv() { ModelParameters p = new ModelParameters().setOverrideKv("tokenizer.ggml.pre=spm"); - assertEquals("tokenizer.ggml.pre=spm", p.parameters.get("--override-kv")); + assertThat(p.parameters.get("--override-kv"), is("tokenizer.ggml.pre=spm")); } // ------------------------------------------------------------------------- @@ -880,13 +896,13 @@ public void testSetOverrideKv() { @Test public void testAddLoraAdapter() { ModelParameters p = new ModelParameters().addLoraAdapter("adapter.bin"); - assertEquals("adapter.bin", p.parameters.get("--lora")); + assertThat(p.parameters.get("--lora"), is("adapter.bin")); } @Test public void testAddControlVector() { ModelParameters p = new ModelParameters().addControlVector("vec.bin"); - assertEquals("vec.bin", p.parameters.get("--control-vector")); + assertThat(p.parameters.get("--control-vector"), is("vec.bin")); } // ------------------------------------------------------------------------- @@ -896,19 +912,19 @@ public void testAddControlVector() { @Test public void testSetModelDraft() { ModelParameters p = new ModelParameters().setModelDraft("/path/to/draft.gguf"); - assertEquals("/path/to/draft.gguf", p.parameters.get("--spec-draft-model")); + assertThat(p.parameters.get("--spec-draft-model"), is("/path/to/draft.gguf")); } @Test public void testSetDeviceDraft() { ModelParameters p = new ModelParameters().setDeviceDraft("cuda0"); - assertEquals("cuda0", p.parameters.get("--spec-draft-device")); + assertThat(p.parameters.get("--spec-draft-device"), is("cuda0")); } @Test public void testSetGpuLayersDraft() { ModelParameters p = new ModelParameters().setGpuLayersDraft(16); - assertEquals("16", p.parameters.get("--spec-draft-ngl")); + assertThat(p.parameters.get("--spec-draft-ngl"), is("16")); } @Test @@ -916,8 +932,8 @@ public void testSetDraftMax() { // Regression: --draft-max was REMOVED in b9016 and now throws std::invalid_argument // at model load. Must use --spec-draft-n-max. ModelParameters p = new ModelParameters().setDraftMax(8); - assertEquals("8", p.parameters.get("--spec-draft-n-max")); - assertFalse(p.parameters.containsKey("--draft-max"), "--draft-max throws on b9016+; must not appear in args"); + assertThat(p.parameters.get("--spec-draft-n-max"), is("8")); + assertThat("--draft-max throws on b9016+; must not appear in args", p.parameters, not(hasKey("--draft-max"))); } @Test @@ -925,14 +941,14 @@ public void testSetDraftMin() { // Regression: --draft-min was REMOVED in b9016 and now throws std::invalid_argument // at model load. Must use --spec-draft-n-min. ModelParameters p = new ModelParameters().setDraftMin(2); - assertEquals("2", p.parameters.get("--spec-draft-n-min")); - assertFalse(p.parameters.containsKey("--draft-min"), "--draft-min throws on b9016+; must not appear in args"); + assertThat(p.parameters.get("--spec-draft-n-min"), is("2")); + assertThat("--draft-min throws on b9016+; must not appear in args", p.parameters, not(hasKey("--draft-min"))); } @Test public void testSetDraftPMin() { ModelParameters p = new ModelParameters().setDraftPMin(0.5f); - assertEquals("0.5", p.parameters.get("--spec-draft-p-min")); + assertThat(p.parameters.get("--spec-draft-p-min"), is("0.5")); } // ------------------------------------------------------------------------- @@ -942,41 +958,41 @@ public void testSetDraftPMin() { @Test public void testDisableLog() { ModelParameters p = new ModelParameters().disableLog(); - assertTrue(p.parameters.containsKey("--log-disable")); - assertNull(p.parameters.get("--log-disable")); + assertThat(p.parameters, hasKey("--log-disable")); + assertThat(p.parameters.get("--log-disable"), is(nullValue())); } @Test public void testSetLogFile() { ModelParameters p = new ModelParameters().setLogFile("/tmp/llama.log"); - assertEquals("/tmp/llama.log", p.parameters.get("--log-file")); + assertThat(p.parameters.get("--log-file"), is("/tmp/llama.log")); } @Test public void testSetVerbose() { ModelParameters p = new ModelParameters().setVerbose(); - assertTrue(p.parameters.containsKey("--verbose")); - assertNull(p.parameters.get("--verbose")); + assertThat(p.parameters, hasKey("--verbose")); + assertThat(p.parameters.get("--verbose"), is(nullValue())); } @Test public void testSetLogVerbosity() { ModelParameters p = new ModelParameters().setLogVerbosity(3); - assertEquals("3", p.parameters.get("--log-verbosity")); + assertThat(p.parameters.get("--log-verbosity"), is("3")); } @Test public void testEnableLogPrefix() { ModelParameters p = new ModelParameters().enableLogPrefix(); - assertTrue(p.parameters.containsKey("--log-prefix")); - assertNull(p.parameters.get("--log-prefix")); + assertThat(p.parameters, hasKey("--log-prefix")); + assertThat(p.parameters.get("--log-prefix"), is(nullValue())); } @Test public void testEnableLogTimestamps() { ModelParameters p = new ModelParameters().enableLogTimestamps(); - assertTrue(p.parameters.containsKey("--log-timestamps")); - assertNull(p.parameters.get("--log-timestamps")); + assertThat(p.parameters, hasKey("--log-timestamps")); + assertThat(p.parameters.get("--log-timestamps"), is(nullValue())); } // ------------------------------------------------------------------------- @@ -986,13 +1002,13 @@ public void testEnableLogTimestamps() { @Test public void testSetFitTrue() { ModelParameters p = new ModelParameters().setFit(true); - assertEquals(ModelParameters.FIT_ON, p.parameters.get("--fit")); + assertThat(p.parameters.get("--fit"), is(ModelParameters.FIT_ON)); } @Test public void testSetFitFalse() { ModelParameters p = new ModelParameters().setFit(false); - assertEquals(ModelParameters.FIT_OFF, p.parameters.get("--fit")); + assertThat(p.parameters.get("--fit"), is(ModelParameters.FIT_OFF)); } // ------------------------------------------------------------------------- @@ -1003,7 +1019,7 @@ public void testSetFitFalse() { public void testSetRopeScalingAllValues() { for (RopeScalingType type : RopeScalingType.values()) { ModelParameters p = new ModelParameters().setRopeScaling(type); - assertEquals(type.getArgValue(), p.parameters.get("--rope-scaling")); + assertThat(p.parameters.get("--rope-scaling"), is(type.getArgValue())); } } @@ -1015,7 +1031,9 @@ public void testSetRopeScalingAllValues() { public void testSetMirostatAllValues() { for (MiroStat m : MiroStat.values()) { ModelParameters p = new ModelParameters().setMirostat(m); - assertEquals(String.valueOf(m.ordinal()), p.parameters.get("--mirostat")); + // Assert against the enum's CLI arg-value contract (what setMirostat + // actually writes), not Enum.ordinal() (Error Prone EnumOrdinal). + assertThat(p.parameters.get("--mirostat"), is(m.getArgValue())); } } @@ -1026,18 +1044,18 @@ public void testSetMirostatAllValues() { @Test public void testExtendedChainingReturnsSameInstance() { ModelParameters p = new ModelParameters(); - assertSame(p.setCtxSize(2048), p); - assertSame(p.setBatchSize(512), p); - assertSame(p.setTemp(0.7f), p); - assertSame(p.setTopK(50), p); - assertSame(p.setDryMultiplier(0.5f), p); - assertSame(p.setXtcProbability(0.3f), p); - assertSame(p.setRopeScale(2.0f), p); - assertSame(p.setGpuLayers(32), p); - assertSame(p.enableFlashAttn(), p); - assertSame(p.disableContextShift(), p); - assertSame(p.setModelDraft("/draft.gguf"), p); - assertSame(p.disableLog(), p); + assertThat(p.setCtxSize(2048), is(sameInstance(p))); + assertThat(p.setBatchSize(512), is(sameInstance(p))); + assertThat(p.setTemp(0.7f), is(sameInstance(p))); + assertThat(p.setTopK(50), is(sameInstance(p))); + assertThat(p.setDryMultiplier(0.5f), is(sameInstance(p))); + assertThat(p.setXtcProbability(0.3f), is(sameInstance(p))); + assertThat(p.setRopeScale(2.0f), is(sameInstance(p))); + assertThat(p.setGpuLayers(32), is(sameInstance(p))); + assertThat(p.enableFlashAttn(), is(sameInstance(p))); + assertThat(p.disableContextShift(), is(sameInstance(p))); + assertThat(p.setModelDraft("/draft.gguf"), is(sameInstance(p))); + assertThat(p.disableLog(), is(sameInstance(p))); } // ------------------------------------------------------------------------- @@ -1053,8 +1071,8 @@ public void testToArrayComplexCombination() { .enableFlashAttn(); String[] arr = p.toArray(); // argv[0]="" + --fit + on + --model + model.gguf + --ctx-size + 2048 + --embedding + --flash-attn = 9 - assertEquals(9, arr.length); - assertEquals("", arr[0]); + assertThat(arr, arrayWithSize(9)); + assertThat(arr[0], is("")); } // ------------------------------------------------------------------------- @@ -1064,16 +1082,16 @@ public void testToArrayComplexCombination() { @Test public void testIsDefaultForCtxSize() { ModelParameters p = new ModelParameters(); - assertTrue(p.isUnset("ctx-size")); + assertThat(p.isUnset("ctx-size"), is(true)); p.setCtxSize(2048); - assertFalse(p.isUnset("ctx-size")); + assertThat(p.isUnset("ctx-size"), is(false)); } @Test public void testIsDefaultForFlagOnly() { ModelParameters p = new ModelParameters(); - assertTrue(p.isUnset("flash-attn")); + assertThat(p.isUnset("flash-attn"), is(true)); p.enableFlashAttn(); - assertFalse(p.isUnset("flash-attn")); + assertThat(p.isUnset("flash-attn"), is(false)); } } diff --git a/src/test/java/net/ladenthin/llama/ModelParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/ModelParametersTest.java similarity index 71% rename from src/test/java/net/ladenthin/llama/ModelParametersTest.java rename to src/test/java/net/ladenthin/llama/parameters/ModelParametersTest.java index 80bccb93..55bc38e6 100644 --- a/src/test/java/net/ladenthin/llama/ModelParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/ModelParametersTest.java @@ -3,12 +3,21 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.*; +package net.ladenthin.llama.parameters; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.Arrays; import java.util.List; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.args.CacheType; import net.ladenthin.llama.args.GpuSplitMode; import net.ladenthin.llama.args.MiroStat; @@ -35,13 +44,13 @@ public class ModelParametersTest { @Test public void testSetPriorityValid0() { ModelParameters p = new ModelParameters().setPriority(0); - assertEquals("0", p.parameters.get("--prio")); + assertThat(p.parameters.get("--prio"), is("0")); } @Test public void testSetPriorityValid3() { ModelParameters p = new ModelParameters().setPriority(3); - assertEquals("3", p.parameters.get("--prio")); + assertThat(p.parameters.get("--prio"), is("3")); } @Test @@ -61,7 +70,7 @@ public void testSetPriorityTooHigh() { @Test public void testSetPriorityBatchValid1() { ModelParameters p = new ModelParameters().setPriorityBatch(1); - assertEquals("1", p.parameters.get("--prio-batch")); + assertThat(p.parameters.get("--prio-batch"), is("1")); } @Test @@ -81,19 +90,19 @@ public void testSetPriorityBatchTooHigh() { @Test public void testSetRepeatLastNValidZero() { ModelParameters p = new ModelParameters().setRepeatLastN(0); - assertEquals("0", p.parameters.get("--repeat-last-n")); + assertThat(p.parameters.get("--repeat-last-n"), is("0")); } @Test public void testSetRepeatLastNValidMinusOne() { ModelParameters p = new ModelParameters().setRepeatLastN(-1); - assertEquals("-1", p.parameters.get("--repeat-last-n")); + assertThat(p.parameters.get("--repeat-last-n"), is("-1")); } @Test public void testSetRepeatLastNValid64() { ModelParameters p = new ModelParameters().setRepeatLastN(64); - assertEquals("64", p.parameters.get("--repeat-last-n")); + assertThat(p.parameters.get("--repeat-last-n"), is("64")); } @Test @@ -108,13 +117,13 @@ public void testSetRepeatLastNTooLow() { @Test public void testSetDryPenaltyLastNValidMinusOne() { ModelParameters p = new ModelParameters().setDryPenaltyLastN(-1); - assertEquals("-1", p.parameters.get("--dry-penalty-last-n")); + assertThat(p.parameters.get("--dry-penalty-last-n"), is("-1")); } @Test public void testSetDryPenaltyLastNValidZero() { ModelParameters p = new ModelParameters().setDryPenaltyLastN(0); - assertEquals("0", p.parameters.get("--dry-penalty-last-n")); + assertThat(p.parameters.get("--dry-penalty-last-n"), is("0")); } @Test @@ -129,26 +138,26 @@ public void testSetDryPenaltyLastNTooLow() { @Test public void testSetSamplersSingle() { ModelParameters p = new ModelParameters().setSamplers(Sampler.TOP_K); - assertEquals("top_k", p.parameters.get("--samplers")); + assertThat(p.parameters.get("--samplers"), is("top_k")); } @Test public void testSetSamplersMultiple() { ModelParameters p = new ModelParameters().setSamplers(Sampler.TOP_K, Sampler.TOP_P, Sampler.TEMPERATURE); - assertEquals("top_k;top_p;temperature", p.parameters.get("--samplers")); + assertThat(p.parameters.get("--samplers"), is("top_k;top_p;temperature")); } @Test public void testSetSamplersEmpty() { ModelParameters p = new ModelParameters().setSamplers(); - assertFalse(p.parameters.containsKey("--samplers")); + assertThat(p.parameters, not(hasKey("--samplers"))); } @Test public void testSetSamplersAllLowercase() { for (Sampler s : Sampler.values()) { ModelParameters p = new ModelParameters().setSamplers(s); - assertEquals(s.name().toLowerCase(), p.parameters.get("--samplers")); + assertThat(p.parameters.get("--samplers"), is(s.name().toLowerCase())); } } @@ -159,13 +168,13 @@ public void testSetSamplersAllLowercase() { @Test public void testAddLoraScaledAdapter() { ModelParameters p = new ModelParameters().addLoraScaledAdapter("adapter.bin", 0.5f); - assertEquals("adapter.bin,0.5", p.parameters.get("--lora-scaled")); + assertThat(p.parameters.get("--lora-scaled"), is("adapter.bin,0.5")); } @Test public void testAddControlVectorScaled() { ModelParameters p = new ModelParameters().addControlVectorScaled("vec.bin", 1.5f); - assertEquals("vec.bin,1.5", p.parameters.get("--control-vector-scaled")); + assertThat(p.parameters.get("--control-vector-scaled"), is("vec.bin,1.5")); } // ------------------------------------------------------------------------- @@ -175,13 +184,13 @@ public void testAddControlVectorScaled() { @Test public void testSetControlVectorLayerRange() { ModelParameters p = new ModelParameters().setControlVectorLayerRange(2, 10); - assertEquals("2,10", p.parameters.get("--control-vector-layer-range")); + assertThat(p.parameters.get("--control-vector-layer-range"), is("2,10")); } @Test public void testSetControlVectorLayerRangeSameStartEnd() { ModelParameters p = new ModelParameters().setControlVectorLayerRange(5, 5); - assertEquals("5,5", p.parameters.get("--control-vector-layer-range")); + assertThat(p.parameters.get("--control-vector-layer-range"), is("5,5")); } // ------------------------------------------------------------------------- @@ -191,19 +200,19 @@ public void testSetControlVectorLayerRangeSameStartEnd() { @Test public void testIsDefaultTrueWhenNotSet() { ModelParameters p = new ModelParameters(); - assertTrue(p.isUnset("threads")); + assertThat(p.isUnset("threads"), is(true)); } @Test public void testIsDefaultFalseWhenSet() { ModelParameters p = new ModelParameters().setThreads(4); - assertFalse(p.isUnset("threads")); + assertThat(p.isUnset("threads"), is(false)); } @Test public void testIsDefaultFalseAfterFlagOnly() { ModelParameters p = new ModelParameters().enableEmbedding(); - assertFalse(p.isUnset("embedding")); + assertThat(p.isUnset("embedding"), is(false)); } // ------------------------------------------------------------------------- @@ -213,85 +222,86 @@ public void testIsDefaultFalseAfterFlagOnly() { @Test public void testSetPoolingTypeMean() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.MEAN); - assertEquals(PoolingType.MEAN.getArgValue(), p.parameters.get(ModelParameters.ARG_POOLING)); + assertThat(p.parameters.get(ModelParameters.ARG_POOLING), is(PoolingType.MEAN.getArgValue())); } @Test public void testSetPoolingTypeNone() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.NONE); - assertEquals(PoolingType.NONE.getArgValue(), p.parameters.get(ModelParameters.ARG_POOLING)); + assertThat(p.parameters.get(ModelParameters.ARG_POOLING), is(PoolingType.NONE.getArgValue())); } @Test public void testSetPoolingTypeCls() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.CLS); - assertEquals(PoolingType.CLS.getArgValue(), p.parameters.get(ModelParameters.ARG_POOLING)); + assertThat(p.parameters.get(ModelParameters.ARG_POOLING), is(PoolingType.CLS.getArgValue())); } @Test public void testSetPoolingTypeLast() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.LAST); - assertEquals(PoolingType.LAST.getArgValue(), p.parameters.get(ModelParameters.ARG_POOLING)); + assertThat(p.parameters.get(ModelParameters.ARG_POOLING), is(PoolingType.LAST.getArgValue())); } @Test public void testSetPoolingTypeRank() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.RANK); - assertEquals(PoolingType.RANK.getArgValue(), p.parameters.get(ModelParameters.ARG_POOLING)); + assertThat(p.parameters.get(ModelParameters.ARG_POOLING), is(PoolingType.RANK.getArgValue())); } @Test public void testSetPoolingTypeUnspecifiedDoesNotSetParam() { ModelParameters p = new ModelParameters().setPoolingType(PoolingType.UNSPECIFIED); - assertFalse( - p.parameters.containsKey(ModelParameters.ARG_POOLING), - "UNSPECIFIED pooling type must not add " + ModelParameters.ARG_POOLING + " to parameters"); + assertThat( + "UNSPECIFIED pooling type must not add " + ModelParameters.ARG_POOLING + " to parameters", + p.parameters, + not(hasKey(ModelParameters.ARG_POOLING))); } @Test public void testSetPoolingTypeUnspecifiedLeavesDefaultUntouched() { // A fresh ModelParameters must not have ARG_POOLING set by default either ModelParameters fresh = new ModelParameters(); - assertFalse(fresh.parameters.containsKey(ModelParameters.ARG_POOLING)); + assertThat(fresh.parameters, not(hasKey(ModelParameters.ARG_POOLING))); // Calling setPoolingType(UNSPECIFIED) must leave that invariant intact fresh.setPoolingType(PoolingType.UNSPECIFIED); - assertFalse(fresh.parameters.containsKey(ModelParameters.ARG_POOLING)); + assertThat(fresh.parameters, not(hasKey(ModelParameters.ARG_POOLING))); } @Test public void testSetRopeScaling() { ModelParameters p = new ModelParameters().setRopeScaling(RopeScalingType.YARN2); - assertEquals("yarn", p.parameters.get("--rope-scaling")); + assertThat(p.parameters.get("--rope-scaling"), is("yarn")); } @Test public void testSetCacheTypeKLowercase() { ModelParameters p = new ModelParameters().setCacheTypeK(CacheType.F16); - assertEquals("f16", p.parameters.get("--cache-type-k")); + assertThat(p.parameters.get("--cache-type-k"), is("f16")); } @Test public void testSetCacheTypeVLowercase() { ModelParameters p = new ModelParameters().setCacheTypeV(CacheType.Q8_0); - assertEquals("q8_0", p.parameters.get("--cache-type-v")); + assertThat(p.parameters.get("--cache-type-v"), is("q8_0")); } @Test public void testSetSplitModeLowercase() { ModelParameters p = new ModelParameters().setSplitMode(GpuSplitMode.LAYER); - assertEquals("layer", p.parameters.get("--split-mode")); + assertThat(p.parameters.get("--split-mode"), is("layer")); } @Test public void testSetNumaLowercase() { ModelParameters p = new ModelParameters().setNuma(NumaStrategy.DISTRIBUTE); - assertEquals("distribute", p.parameters.get("--numa")); + assertThat(p.parameters.get("--numa"), is("distribute")); } @Test public void testSetMirostatOrdinal() { ModelParameters p = new ModelParameters().setMirostat(MiroStat.V2); - assertEquals("2", p.parameters.get("--mirostat")); + assertThat(p.parameters.get("--mirostat"), is("2")); } // ------------------------------------------------------------------------- @@ -301,35 +311,35 @@ public void testSetMirostatOrdinal() { @Test public void testToStringContainsKey() { ModelParameters p = new ModelParameters().setThreads(4); - assertTrue(p.toString().contains("--threads")); - assertTrue(p.toString().contains("4")); + assertThat(p.toString(), containsString("--threads")); + assertThat(p.toString(), containsString("4")); } @Test public void testToStringFlagOnlyNoValue() { ModelParameters p = new ModelParameters().enableEmbedding(); String s = p.toString(); - assertTrue(s.contains("--embedding")); + assertThat(s, containsString("--embedding")); // Flag-only: value is null, so no "null" text should appear - assertFalse(s.contains("null")); + assertThat(s, not(containsString("null"))); } @Test public void testFitValueTrueReturnsFitOn() { - assertEquals(ModelParameters.FIT_ON, ModelParameters.fitValue(true)); + assertThat(ModelParameters.fitValue(true), is(ModelParameters.FIT_ON)); } @Test public void testFitValueFalseReturnsFitOff() { - assertEquals(ModelParameters.FIT_OFF, ModelParameters.fitValue(false)); + assertThat(ModelParameters.fitValue(false), is(ModelParameters.FIT_OFF)); } @Test public void testToStringDefaultContainsFit() { ModelParameters p = new ModelParameters(); String s = p.toString(); - assertTrue(s.contains("--fit")); - assertTrue(s.contains(ModelParameters.DEFAULT_FIT_VALUE)); + assertThat(s, containsString("--fit")); + assertThat(s, containsString(ModelParameters.DEFAULT_FIT_VALUE)); } // ------------------------------------------------------------------------- @@ -341,11 +351,11 @@ public void testToArrayDefaultParametersHasFit() { // toArray() = ["", "--fit", DEFAULT_FIT_VALUE] ModelParameters p = new ModelParameters(); String[] arr = p.toArray(); - assertEquals(3, arr.length); - assertEquals("", arr[0]); + assertThat(arr, arrayWithSize(3)); + assertThat(arr[0], is("")); List list = Arrays.asList(arr); - assertTrue(list.contains("--fit")); - assertTrue(list.contains(ModelParameters.DEFAULT_FIT_VALUE)); + assertThat(list, hasItem("--fit")); + assertThat(list, hasItem(ModelParameters.DEFAULT_FIT_VALUE)); } @Test @@ -353,13 +363,13 @@ public void testToArrayScalarParameterHasFiveElements() { // argv[0]="" + "--fit" + DEFAULT_FIT_VALUE + "--threads" + "4" = 5 ModelParameters p = new ModelParameters().setThreads(4); String[] arr = p.toArray(); - assertEquals(5, arr.length); - assertEquals("", arr[0]); + assertThat(arr, arrayWithSize(5)); + assertThat(arr[0], is("")); List list = Arrays.asList(arr); - assertTrue(list.contains("--threads")); - assertTrue(list.contains("4")); - assertTrue(list.contains("--fit")); - assertTrue(list.contains(ModelParameters.DEFAULT_FIT_VALUE)); + assertThat(list, hasItem("--threads")); + assertThat(list, hasItem("4")); + assertThat(list, hasItem("--fit")); + assertThat(list, hasItem(ModelParameters.DEFAULT_FIT_VALUE)); } @Test @@ -367,12 +377,12 @@ public void testToArrayFlagOnlyHasFourElements() { // argv[0]="" + "--fit" + DEFAULT_FIT_VALUE + "--embedding" (no value) = 4 ModelParameters p = new ModelParameters().enableEmbedding(); String[] arr = p.toArray(); - assertEquals(4, arr.length); - assertEquals("", arr[0]); + assertThat(arr, arrayWithSize(4)); + assertThat(arr[0], is("")); List list = Arrays.asList(arr); - assertTrue(list.contains("--embedding")); - assertTrue(list.contains("--fit")); - assertTrue(list.contains(ModelParameters.DEFAULT_FIT_VALUE)); + assertThat(list, hasItem("--embedding")); + assertThat(list, hasItem("--fit")); + assertThat(list, hasItem(ModelParameters.DEFAULT_FIT_VALUE)); } @Test @@ -380,14 +390,14 @@ public void testToArrayMultipleParameters() { ModelParameters p = new ModelParameters().setThreads(4).enableEmbedding(); String[] arr = p.toArray(); // 1 (argv[0]) + 2 (--fit DEFAULT_FIT_VALUE) + 2 (--threads 4) + 1 (--embedding) = 6 - assertEquals(6, arr.length); - assertEquals("", arr[0]); + assertThat(arr, arrayWithSize(6)); + assertThat(arr[0], is("")); List list = Arrays.asList(arr); - assertTrue(list.contains("--threads")); - assertTrue(list.contains("4")); - assertTrue(list.contains("--embedding")); - assertTrue(list.contains("--fit")); - assertTrue(list.contains(ModelParameters.DEFAULT_FIT_VALUE)); + assertThat(list, hasItem("--threads")); + assertThat(list, hasItem("4")); + assertThat(list, hasItem("--embedding")); + assertThat(list, hasItem("--fit")); + assertThat(list, hasItem(ModelParameters.DEFAULT_FIT_VALUE)); } // ------------------------------------------------------------------------- @@ -397,9 +407,9 @@ public void testToArrayMultipleParameters() { @Test public void testBuilderChainingReturnsSameInstance() { ModelParameters p = new ModelParameters(); - assertSame(p.setThreads(4), p); - assertSame(p.setGpuLayers(10), p); - assertSame(p.enableEmbedding(), p); + assertThat(p.setThreads(4), is(sameInstance(p))); + assertThat(p.setGpuLayers(10), is(sameInstance(p))); + assertThat(p.enableEmbedding(), is(sameInstance(p))); } // ------------------------------------------------------------------------- @@ -409,25 +419,25 @@ public void testBuilderChainingReturnsSameInstance() { @Test public void testSetMmproj() { ModelParameters p = new ModelParameters().setMmproj("/models/mmproj.gguf"); - assertEquals("/models/mmproj.gguf", p.parameters.get("--mmproj")); + assertThat(p.parameters.get("--mmproj"), is("/models/mmproj.gguf")); } @Test public void testSetMmprojUrl() { ModelParameters p = new ModelParameters().setMmprojUrl("https://example.com/mmproj.gguf"); - assertEquals("https://example.com/mmproj.gguf", p.parameters.get("--mmproj-url")); + assertThat(p.parameters.get("--mmproj-url"), is("https://example.com/mmproj.gguf")); } @Test public void testEnableMmprojAuto() { ModelParameters p = new ModelParameters().enableMmprojAuto(); - assertTrue(p.parameters.containsKey("--mmproj-auto")); + assertThat(p.parameters, hasKey("--mmproj-auto")); } @Test public void testEnableMmprojOffload() { ModelParameters p = new ModelParameters().enableMmprojOffload(); - assertTrue(p.parameters.containsKey("--mmproj-offload")); + assertThat(p.parameters, hasKey("--mmproj-offload")); } // ------------------------------------------------------------------------- @@ -437,38 +447,38 @@ public void testEnableMmprojOffload() { @Test public void testSetReasoningFormatNone() { ModelParameters p = new ModelParameters().setReasoningFormat(net.ladenthin.llama.args.ReasoningFormat.NONE); - assertEquals("none", p.parameters.get("--reasoning-format")); + assertThat(p.parameters.get("--reasoning-format"), is("none")); } @Test public void testSetReasoningFormatAuto() { ModelParameters p = new ModelParameters().setReasoningFormat(net.ladenthin.llama.args.ReasoningFormat.AUTO); - assertEquals("auto", p.parameters.get("--reasoning-format")); + assertThat(p.parameters.get("--reasoning-format"), is("auto")); } @Test public void testSetReasoningFormatDeepseek() { ModelParameters p = new ModelParameters().setReasoningFormat(net.ladenthin.llama.args.ReasoningFormat.DEEPSEEK); - assertEquals("deepseek", p.parameters.get("--reasoning-format")); + assertThat(p.parameters.get("--reasoning-format"), is("deepseek")); } @Test public void testSetReasoningFormatDeepseekLegacy() { ModelParameters p = new ModelParameters().setReasoningFormat(net.ladenthin.llama.args.ReasoningFormat.DEEPSEEK_LEGACY); - assertEquals("deepseek-legacy", p.parameters.get("--reasoning-format")); + assertThat(p.parameters.get("--reasoning-format"), is("deepseek-legacy")); } @Test public void testSetReasoningBudgetPositive() { ModelParameters p = new ModelParameters().setReasoningBudget(1024); - assertEquals("1024", p.parameters.get("--reasoning-budget")); + assertThat(p.parameters.get("--reasoning-budget"), is("1024")); } @Test public void testSetReasoningBudgetDisabled() { ModelParameters p = new ModelParameters().setReasoningBudget(-1); - assertEquals("-1", p.parameters.get("--reasoning-budget")); + assertThat(p.parameters.get("--reasoning-budget"), is("-1")); } // ------------------------------------------------------------------------- @@ -478,13 +488,13 @@ public void testSetReasoningBudgetDisabled() { @Test public void testSetSleepIdleSeconds() { ModelParameters p = new ModelParameters().setSleepIdleSeconds(60); - assertEquals("60", p.parameters.get("--sleep-idle-seconds")); + assertThat(p.parameters.get("--sleep-idle-seconds"), is("60")); } @Test public void testSetSleepIdleSecondsZero() { ModelParameters p = new ModelParameters().setSleepIdleSeconds(0); - assertEquals("0", p.parameters.get("--sleep-idle-seconds")); + assertThat(p.parameters.get("--sleep-idle-seconds"), is("0")); } // ------------------------------------------------------------------------- @@ -494,14 +504,14 @@ public void testSetSleepIdleSecondsZero() { @Test public void testSetClearIdleTrue_usesCacheIdleSlotsFlag() { ModelParameters p = new ModelParameters().setClearIdle(true); - assertTrue(p.parameters.containsKey("--cache-idle-slots")); - assertFalse(p.parameters.containsKey("--no-cache-idle-slots")); + assertThat(p.parameters, hasKey("--cache-idle-slots")); + assertThat(p.parameters, not(hasKey("--no-cache-idle-slots"))); } @Test public void testSetClearIdleFalse_usesNoCacheIdleSlotsFlag() { ModelParameters p = new ModelParameters().setClearIdle(false); - assertTrue(p.parameters.containsKey("--no-cache-idle-slots")); - assertFalse(p.parameters.containsKey("--cache-idle-slots")); + assertThat(p.parameters, hasKey("--no-cache-idle-slots")); + assertThat(p.parameters, not(hasKey("--cache-idle-slots"))); } } diff --git a/src/test/java/net/ladenthin/llama/value/ChatChoiceTest.java b/src/test/java/net/ladenthin/llama/value/ChatChoiceTest.java new file mode 100644 index 00000000..91dd66d8 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/value/ChatChoiceTest.java @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.value; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +@ClaudeGenerated( + purpose = "Pin every ChatChoice accessor to a distinct non-default value so the index/message/" + + "finishReason getters and the value-semantics of equals/hashCode/toString are all " + + "mutation-covered.") +public class ChatChoiceTest { + + private static ChatChoice choice(int index, String role, String content, String finish) { + return new ChatChoice(index, new ChatMessage(role, content), finish); + } + + @Test + public void accessorsReturnConstructorValues() { + ChatMessage msg = new ChatMessage("assistant", "hello"); + ChatChoice c = new ChatChoice(7, msg, "stop"); + // index getter — a non-zero value kills the "return 0" primitive mutant. + assertThat(c.getIndex(), is(7)); + assertThat(c.getMessage(), is(sameInstance(msg))); + assertThat(c.getFinishReason(), is("stop")); + } + + @Test + public void toStringRendersAllFields() { + ChatChoice c = choice(3, "assistant", "hi there", "length"); + String s = c.toString(); + assertThat(s, containsString("3")); + assertThat(s, containsString("length")); + assertThat(s, containsString("hi there")); + } + + @Test + public void equalsAndHashCodeAreValueBased() { + ChatChoice a = choice(1, "assistant", "x", "stop"); + ChatChoice b = choice(1, "assistant", "x", "stop"); + assertThat(a, is(b)); + assertThat(a.hashCode(), is(b.hashCode())); + } + + @Test + public void differingIndexBreaksEquality() { + assertThat(choice(1, "assistant", "x", "stop"), is(not(choice(2, "assistant", "x", "stop")))); + } + + @Test + public void differingFinishReasonBreaksEquality() { + assertThat(choice(1, "assistant", "x", "stop"), is(not(choice(1, "assistant", "x", "length")))); + } +} diff --git a/src/test/java/net/ladenthin/llama/value/ChatMessageTest.java b/src/test/java/net/ladenthin/llama/value/ChatMessageTest.java new file mode 100644 index 00000000..3a1e33a2 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/value/ChatMessageTest.java @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.value; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +@ClaudeGenerated( + purpose = "Pin every ChatMessage path: plain/tool/multimodal constructors and factories, the " + + "concatText text-joining (newline-joined, image parts skipped, no leading newline), the " + + "parts validation helpers (null/empty rejection), the three toString branches " + + "(plain, tool_calls, tool_call_id), and value-based equals/hashCode — full mutation coverage.") +public class ChatMessageTest { + + @Test + public void plainMessageAccessors() { + ChatMessage m = new ChatMessage("user", "hi"); + assertThat(m.getRole(), is("user")); + assertThat(m.getContent(), is("hi")); + assertThat(m.hasParts(), is(false)); + assertThat(m.getParts().isPresent(), is(false)); + assertThat(m.getToolCalls(), is(empty())); + assertThat(m.getToolCallId().isPresent(), is(false)); + } + + @Test + public void toStringPlainBranch() { + assertThat(new ChatMessage("assistant", "hello").toString(), is("assistant: hello")); + } + + @Test + public void toStringToolCallsBranch() { + ChatMessage m = + ChatMessage.assistantToolCalls("thinking", Collections.singletonList(new ToolCall("c1", "f", "{}"))); + assertThat(m.toString(), is("assistant (tool_calls=1): thinking")); + } + + @Test + public void toStringToolCallIdBranch() { + ChatMessage m = ChatMessage.toolResult("c1", "42"); + assertThat(m.getRole(), is("tool")); + assertThat(m.getToolCallId().orElseThrow(), is("c1")); + assertThat(m.toString(), is("tool (tool_call_id=c1): 42")); + } + + @Test + public void assistantToolCallsNullContentBecomesEmpty() { + // L144 ternary: content == null ? "" : content + ChatMessage m = ChatMessage.assistantToolCalls(null, Collections.singletonList(new ToolCall("c1", "f", "{}"))); + assertThat(m.getContent(), is("")); + assertThat(m.getToolCalls(), hasSize(1)); + } + + @Test + public void assistantToolCallsKeepsNonNullContent() { + ChatMessage m = + ChatMessage.assistantToolCalls("reason", Collections.singletonList(new ToolCall("c1", "f", "{}"))); + assertThat(m.getContent(), is("reason")); + } + + @Test + public void multimodalConcatenatesTextPartsSkippingImagesNoLeadingNewline() { + // concatText: text parts newline-joined, image parts skipped, first part not prefixed with '\n'. + ChatMessage m = new ChatMessage( + "user", + Arrays.asList( + ContentPart.text("describe"), + ContentPart.imageUrl("data:image/png;base64,X"), + ContentPart.text("please"))); + assertThat(m.getContent(), is("describe\nplease")); + assertThat(m.hasParts(), is(true)); + assertThat(m.getParts().orElseThrow(), hasSize(3)); + } + + @Test + public void userMultimodalFactoryBuildsUserMessageWithParts() { + // L155: factory must return a real instance (not null) carrying the parts. + ChatMessage m = + ChatMessage.userMultimodal(ContentPart.text("a"), ContentPart.imageUrl("data:image/png;base64,Y")); + assertThat(m.getRole(), is("user")); + assertThat(m.hasParts(), is(true)); + assertThat(m.getParts().orElseThrow(), hasSize(2)); + } + + @Test + public void nullPartsRejected() { + assertThrows(IllegalArgumentException.class, () -> new ChatMessage("user", (List) null)); + } + + @Test + public void emptyPartsRejected() { + assertThrows( + IllegalArgumentException.class, () -> new ChatMessage("user", Collections.emptyList())); + } + + @Test + public void equalsAndHashCodeAreValueBased() { + assertThat(new ChatMessage("user", "hi"), is(new ChatMessage("user", "hi"))); + assertThat(new ChatMessage("user", "hi").hashCode(), is(new ChatMessage("user", "hi").hashCode())); + } + + @Test + public void differingContentBreaksEquality() { + assertThat(new ChatMessage("user", "hi"), is(not(new ChatMessage("user", "bye")))); + } +} diff --git a/src/test/java/net/ladenthin/llama/ChatResponseTest.java b/src/test/java/net/ladenthin/llama/value/ChatResponseTest.java similarity index 61% rename from src/test/java/net/ladenthin/llama/ChatResponseTest.java rename to src/test/java/net/ladenthin/llama/value/ChatResponseTest.java index b35611c3..def6c9cc 100644 --- a/src/test/java/net/ladenthin/llama/ChatResponseTest.java +++ b/src/test/java/net/ladenthin/llama/value/ChatResponseTest.java @@ -2,13 +2,19 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.json.ChatResponseParser; +import net.ladenthin.llama.parameters.ChatRequest; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -29,24 +35,24 @@ public void parsesPlainAssistantReply() { + "\"predicted_n\":5,\"predicted_ms\":50.0,\"predicted_per_second\":100.0}}"; ChatResponse r = parser.parseResponse(json); - assertEquals("chatcmpl-1", r.getId()); - assertEquals(1, r.getChoices().size()); + assertThat(r.getId(), is("chatcmpl-1")); + assertThat(r.getChoices(), hasSize(1)); ChatChoice c = r.getChoices().get(0); - assertEquals(0, c.getIndex()); - assertEquals("assistant", c.getMessage().getRole()); - assertEquals("Hello!", c.getMessage().getContent()); - assertEquals("stop", c.getFinishReason()); - assertTrue(c.getMessage().getToolCalls().isEmpty()); + assertThat(c.getIndex(), is(0)); + assertThat(c.getMessage().getRole(), is("assistant")); + assertThat(c.getMessage().getContent(), is("Hello!")); + assertThat(c.getFinishReason(), is("stop")); + assertThat(c.getMessage().getToolCalls(), is(empty())); - assertEquals(12L, r.getUsage().getPromptTokens()); - assertEquals(5L, r.getUsage().getCompletionTokens()); - assertEquals(17L, r.getUsage().getTotalTokens()); + assertThat(r.getUsage().getPromptTokens(), is(12L)); + assertThat(r.getUsage().getCompletionTokens(), is(5L)); + assertThat(r.getUsage().getTotalTokens(), is(17L)); - assertEquals(12, r.getTimings().getPromptN()); + assertThat(r.getTimings().getPromptN(), is(12)); assertEquals(100.0, r.getTimings().getPromptMs(), 1e-9); assertEquals(100.0, r.getTimings().getPredictedPerSecond(), 1e-9); - assertEquals("Hello!", r.getFirstContent()); + assertThat(r.getFirstContent(), is("Hello!")); } @Test @@ -61,14 +67,14 @@ public void parsesToolCalls() { + "\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":7}}"; ChatResponse r = parser.parseResponse(json); ChatMessage m = r.getFirstMessage().orElseThrow(); - assertEquals("assistant", m.getRole()); + assertThat(m.getRole(), is("assistant")); List tc = m.getToolCalls(); - assertEquals(2, tc.size()); - assertEquals("call_a", tc.get(0).getId()); - assertEquals("get_weather", tc.get(0).getName()); - assertEquals("{\"city\":\"Berlin\"}", tc.get(0).getArgumentsJson()); - assertEquals("get_time", tc.get(1).getName()); - assertEquals("tool_calls", r.getChoices().get(0).getFinishReason()); + assertThat(tc, hasSize(2)); + assertThat(tc.get(0).getId(), is("call_a")); + assertThat(tc.get(0).getName(), is("get_weather")); + assertThat(tc.get(0).getArgumentsJson(), is("{\"city\":\"Berlin\"}")); + assertThat(tc.get(1).getName(), is("get_time")); + assertThat(r.getChoices().get(0).getFinishReason(), is("tool_calls")); } @Test @@ -81,16 +87,24 @@ public void parsesObjectShapedArguments() { ChatResponse r = parser.parseResponse(json); String args = r.getFirstMessage().orElseThrow().getToolCalls().get(0).getArgumentsJson(); // exact text isn't guaranteed, but must contain both fields - assertTrue(args.contains("\"a\":1"), "expected serialized object, got: " + args); - assertTrue(args.contains("\"b\":2")); + assertThat("expected serialized object, got: " + args, args, containsString("\"a\":1")); + assertThat(args, containsString("\"b\":2")); } @Test public void malformedInputYieldsEmptyResponse() { ChatResponse r = parser.parseResponse("{not json"); - assertEquals("", r.getId()); - assertTrue(r.getChoices().isEmpty()); - assertEquals(0L, r.getUsage().getTotalTokens()); + assertThat(r.getId(), is("")); + assertThat(r.getChoices(), is(empty())); + assertThat(r.getUsage().getTotalTokens(), is(0L)); + } + + @Test + public void rawJsonIsPreserved() { + String json = "{\"id\":\"chatcmpl-raw\",\"choices\":[]}"; + ChatResponse r = parser.parseResponse(json); + // Assert on content (not just non-null) so the empty-string return mutant is killed. + assertThat(r.getRawJson(), containsString("chatcmpl-raw")); } @Test @@ -103,15 +117,15 @@ public void buildMessagesJsonRoundTripsToolTurns() { .appendMessage(ChatMessage.toolResult("c1", "4")); String msgs = req.buildMessagesJson(); - assertTrue(msgs.contains("\"tool_calls\""), msgs); - assertTrue(msgs.contains("\"tool_call_id\":\"c1\""), msgs); - assertTrue(msgs.contains("\"name\":\"add\""), msgs); + assertThat(msgs, msgs, containsString("\"tool_calls\"")); + assertThat(msgs, msgs, containsString("\"tool_call_id\":\"c1\"")); + assertThat(msgs, msgs, containsString("\"name\":\"add\"")); } @Test public void buildToolsJsonEmptyWhenNoTools() { ChatRequest req = ChatRequest.empty().appendMessage("user", "hi"); - assertTrue(req.buildToolsJson().isEmpty()); + assertThat(req.buildToolsJson().isPresent(), is(false)); } @Test @@ -120,8 +134,8 @@ public void buildToolsJsonInlinesParameterSchema() { .appendTool(new ToolDefinition( "echo", "Echo a string", "{\"type\":\"object\",\"properties\":{\"s\":{\"type\":\"string\"}}}")); String tools = req.buildToolsJson().orElseThrow(); - assertTrue(tools.contains("\"type\":\"function\""), tools); - assertTrue(tools.contains("\"name\":\"echo\""), tools); - assertTrue(tools.contains("\"properties\""), tools); + assertThat(tools, tools, containsString("\"type\":\"function\"")); + assertThat(tools, tools, containsString("\"name\":\"echo\"")); + assertThat(tools, tools, containsString("\"properties\"")); } } diff --git a/src/test/java/net/ladenthin/llama/ChatTranscriptTest.java b/src/test/java/net/ladenthin/llama/value/ChatTranscriptTest.java similarity index 74% rename from src/test/java/net/ladenthin/llama/ChatTranscriptTest.java rename to src/test/java/net/ladenthin/llama/value/ChatTranscriptTest.java index b9600bbd..7051c6a4 100644 --- a/src/test/java/net/ladenthin/llama/ChatTranscriptTest.java +++ b/src/test/java/net/ladenthin/llama/value/ChatTranscriptTest.java @@ -2,15 +2,21 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertNull; +package net.ladenthin.llama.value; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import net.ladenthin.llama.Session; +import net.ladenthin.llama.exception.LlamaException; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -51,9 +57,7 @@ private static void simulateSend(ChatTranscript t, String userMessage, String as // Phase 1: build wire-format (model would see this). List> wire = t.messagesWithPendingUserTurn(userMessage); // The wire format must contain the pending turn the model is about to answer. - assertTrue( - wire.stream().anyMatch(p -> "user".equals(p.getKey()) && userMessage.equals(p.getValue())), - "wire-format must carry the pending user turn"); + assertThat("wire-format must carry the pending user turn", wire, hasItem(new Pair<>("user", userMessage))); // Phase 2: model returned successfully — commit both turns atomically. t.appendRound(userMessage, assistantReply); } @@ -83,13 +87,13 @@ void appendRoundCommitsBothTurnsAtomically() { t.appendRound("hi", "hello back"); - assertEquals(2, t.size()); + assertThat(t.size(), is(2)); List snapshot = t.snapshot(); - assertEquals(2, snapshot.size()); - assertEquals("user", snapshot.get(0).getRole()); - assertEquals("hi", snapshot.get(0).getContent()); - assertEquals("assistant", snapshot.get(1).getRole()); - assertEquals("hello back", snapshot.get(1).getContent()); + assertThat(snapshot, hasSize(2)); + assertThat(snapshot.get(0).getRole(), is("user")); + assertThat(snapshot.get(0).getContent(), is("hi")); + assertThat(snapshot.get(1).getRole(), is("assistant")); + assertThat(snapshot.get(1).getContent(), is("hello back")); } @Test @@ -102,7 +106,7 @@ void appendUserAndAssistantSeparatelyMatchAppendRound() { b.appendUserTurn("hi"); b.appendAssistantTurn("hello back"); - assertEquals(a.snapshot(), b.snapshot(), "atomic-round and split-commit must converge"); + assertThat("atomic-round and split-commit must converge", b.snapshot(), is(a.snapshot())); } @Test @@ -116,13 +120,13 @@ void messagesWithPendingUserTurnDoesNotMutate() { List> wire = t.messagesWithPendingUserTurn("pending"); // Build a wire-format containing committed turns + pending user. - assertEquals(3, wire.size(), "1 user + 1 assistant + 1 pending user"); - assertEquals("user", wire.get(2).getKey()); - assertEquals("pending", wire.get(2).getValue()); + assertThat("1 user + 1 assistant + 1 pending user", wire, hasSize(3)); + assertThat(wire.get(2).getKey(), is("user")); + assertThat(wire.get(2).getValue(), is("pending")); // The transcript itself MUST be unchanged. - assertEquals(sizeBefore, t.size(), "transcript size unchanged"); - assertEquals(snapshotBefore, t.snapshot(), "transcript snapshot unchanged"); + assertThat("transcript size unchanged", t.size(), is(sizeBefore)); + assertThat("transcript snapshot unchanged", t.snapshot(), is(snapshotBefore)); } @Test @@ -131,10 +135,10 @@ void messagesWithPendingUserTurnReturnsFreshList() { ChatTranscript t = new ChatTranscript(null); List> first = t.messagesWithPendingUserTurn("hi"); List> second = t.messagesWithPendingUserTurn("hi"); - assertNotSame( + assertThat( + "each wire-format build returns a fresh list — callers may mutate without affecting peers", first, - second, - "each wire-format build returns a fresh list — callers may mutate without affecting peers"); + is(not(sameInstance(second)))); } @Test @@ -145,16 +149,16 @@ void snapshotIncludesSystemMessage() { List snap = t.snapshot(); - assertEquals(3, snap.size()); - assertEquals("system", snap.get(0).getRole()); - assertEquals("you are an assistant", snap.get(0).getContent()); + assertThat(snap, hasSize(3)); + assertThat(snap.get(0).getRole(), is("system")); + assertThat(snap.get(0).getContent(), is("you are an assistant")); } @Test @DisplayName("snapshot omits system message when null or empty") void snapshotOmitsSystemMessageWhenAbsent() { - assertEquals(0, new ChatTranscript(null).snapshot().size()); - assertEquals(0, new ChatTranscript("").snapshot().size()); + assertThat(new ChatTranscript(null).snapshot(), is(empty())); + assertThat(new ChatTranscript("").snapshot(), is(empty())); } @Test @@ -169,7 +173,7 @@ void snapshotIsUnmodifiable() { @Test @DisplayName("getSystemMessage returns null when absent") void getSystemMessageNullWhenAbsent() { - assertNull(new ChatTranscript(null).getSystemMessage()); + assertThat(new ChatTranscript(null).getSystemMessage(), is(nullValue())); } } @@ -181,7 +185,7 @@ class TwoPhaseCommit { @DisplayName("simulated model failure leaves a FRESH transcript untouched") void freshTranscriptUntouchedWhenModelThrows() { ChatTranscript t = new ChatTranscript("system"); - assertEquals(0, t.size(), "precondition: fresh transcript has no turns"); + assertThat("precondition: fresh transcript has no turns", t.size(), is(0)); int snapshotSizeBefore = t.snapshot().size(); // Caller simulates Session.send where the model rejects the request. @@ -192,11 +196,9 @@ void freshTranscriptUntouchedWhenModelThrows() { // Two-phase commit: the pending user turn never landed in the transcript. // (The system message snapshot entry was there before and is still there.) - assertEquals(0, t.size(), "transcript MUST NOT contain the pending user turn after model failure"); - assertEquals( - snapshotSizeBefore, - t.snapshot().size(), - "snapshot size unchanged by the failed call"); + assertThat("transcript MUST NOT contain the pending user turn after model failure", t.size(), is(0)); + assertThat( + "snapshot size unchanged by the failed call", t.snapshot().size(), is(snapshotSizeBefore)); } @Test @@ -207,7 +209,7 @@ void existingTranscriptUntouchedWhenModelThrows() { simulateSend(t, "how are you", "i'm fine"); List before = t.snapshot(); - assertEquals(5, before.size(), "precondition: 1 system + 2 user + 2 assistant"); + assertThat("precondition: 1 system + 2 user + 2 assistant", before, hasSize(5)); // Now the model rejects a third call. assertThrows( @@ -217,7 +219,7 @@ void existingTranscriptUntouchedWhenModelThrows() { // Two-phase commit: existing transcript is byte-for-byte unchanged. List after = t.snapshot(); - assertEquals(before, after, "failed call must leave the transcript byte-for-byte unchanged"); + assertThat("failed call must leave the transcript byte-for-byte unchanged", after, is(before)); } @Test @@ -227,14 +229,14 @@ void successCommitsBothTurnsAtomically() { simulateSend(t, "hi", "hello"); - assertEquals(2, t.size(), "both turns committed"); + assertThat("both turns committed", t.size(), is(2)); // The shape is invariant: there is no API to commit only one half via appendRound. // Spot-check that the turn pair is well-formed. List snap = t.snapshot(); - assertEquals("user", snap.get(0).getRole()); - assertEquals("hi", snap.get(0).getContent()); - assertEquals("assistant", snap.get(1).getRole()); - assertEquals("hello", snap.get(1).getContent()); + assertThat(snap.get(0).getRole(), is("user")); + assertThat(snap.get(0).getContent(), is("hi")); + assertThat(snap.get(1).getRole(), is("assistant")); + assertThat(snap.get(1).getContent(), is("hello")); } @Test @@ -244,16 +246,16 @@ void streamShape() { // Phase 1: build wire format (would be passed to model.generateChat). List> wire = t.messagesWithPendingUserTurn("tell me a joke"); - assertEquals(1, wire.size(), "wire contains the pending user turn"); + assertThat("wire contains the pending user turn", wire, hasSize(1)); // Phase 2: model returned an iterable successfully — commit only the user turn. t.appendUserTurn("tell me a joke"); - assertEquals(1, t.size(), "user turn committed; assistant follows later"); + assertThat("user turn committed; assistant follows later", t.size(), is(1)); // Later: caller invoked commitStreamedReply with the accumulated text. t.appendAssistantTurn("knock knock"); - assertEquals(2, t.size(), "round closes with the assistant turn"); - assertEquals("assistant", t.snapshot().get(1).getRole()); + assertThat("round closes with the assistant turn", t.size(), is(2)); + assertThat(t.snapshot().get(1).getRole(), is("assistant")); } } } diff --git a/src/test/java/net/ladenthin/llama/CompletionResultTest.java b/src/test/java/net/ladenthin/llama/value/CompletionResultTest.java similarity index 86% rename from src/test/java/net/ladenthin/llama/CompletionResultTest.java rename to src/test/java/net/ladenthin/llama/value/CompletionResultTest.java index e361e105..e80c1cef 100644 --- a/src/test/java/net/ladenthin/llama/CompletionResultTest.java +++ b/src/test/java/net/ladenthin/llama/value/CompletionResultTest.java @@ -2,12 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.json.CompletionResponseParser; import org.junit.jupiter.api.Test; @@ -84,4 +85,14 @@ public void malformedInputYieldsEmptyResult() { assertEquals(StopReason.NONE, r.getStopReason()); assertTrue(r.getLogprobs().isEmpty()); } + + @Test + public void rawJsonAndToStringExposeContent() { + CompletionResult r = + parser.parseCompletionResult("{\"content\":\"hello world\",\"stop\":true,\"stop_type\":\"eos\"}"); + // Assert content (not just non-null) so the empty-string return mutant on getRawJson is killed. + assertTrue(r.getRawJson().contains("hello world")); + // toString() returns the generated text; pin it so the empty-string return mutant is killed. + assertEquals("hello world", r.toString()); + } } diff --git a/src/test/java/net/ladenthin/llama/ContentPartTest.java b/src/test/java/net/ladenthin/llama/value/ContentPartTest.java similarity index 75% rename from src/test/java/net/ladenthin/llama/ContentPartTest.java rename to src/test/java/net/ladenthin/llama/value/ContentPartTest.java index 8a66be70..0c0bcd84 100644 --- a/src/test/java/net/ladenthin/llama/ContentPartTest.java +++ b/src/test/java/net/ladenthin/llama/value/ContentPartTest.java @@ -2,19 +2,21 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.Base64; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -27,17 +29,17 @@ public class ContentPartTest { @Test public void textPartCarriesText() { ContentPart p = ContentPart.text("hello"); - assertEquals(ContentPart.Type.TEXT, p.getType()); - assertEquals("hello", p.getText()); - assertNull(p.getImageUrl()); + assertThat(p.getType(), is(ContentPart.Type.TEXT)); + assertThat(p.getText(), is("hello")); + assertThat(p.getImageUrl(), is(nullValue())); } @Test public void imageUrlPartCarriesUrl() { ContentPart p = ContentPart.imageUrl("https://example.com/a.png"); - assertEquals(ContentPart.Type.IMAGE_URL, p.getType()); - assertEquals("https://example.com/a.png", p.getImageUrl()); - assertNull(p.getText()); + assertThat(p.getType(), is(ContentPart.Type.IMAGE_URL)); + assertThat(p.getImageUrl(), is("https://example.com/a.png")); + assertThat(p.getText(), is(nullValue())); } @Test @@ -45,7 +47,7 @@ public void imageBytesProducesDataUri() { byte[] bytes = {1, 2, 3, 4, 5}; ContentPart p = ContentPart.imageBytes(bytes, "image/png"); String expected = "data:image/png;base64," + Base64.getEncoder().encodeToString(bytes); - assertEquals(expected, p.getImageUrl()); + assertThat(p.getImageUrl(), is(expected)); } @Test @@ -78,7 +80,7 @@ public void imageFileDetectsPngMime() throws IOException { Path file = tmp.resolve("logo.PNG"); Files.write(file, new byte[] {(byte) 0x89, 0x50, 0x4E, 0x47}); ContentPart p = ContentPart.imageFile(file); - assertTrue(p.getImageUrl().startsWith("data:image/png;base64,")); + assertThat(p.getImageUrl(), startsWith("data:image/png;base64,")); } @Test @@ -86,7 +88,7 @@ public void imageFileDetectsJpegFromJpgExtension() throws IOException { Path file = tmp.resolve("photo.jpg"); Files.write(file, new byte[] {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF}); ContentPart p = ContentPart.imageFile(file); - assertTrue(p.getImageUrl().startsWith("data:image/jpeg;base64,")); + assertThat(p.getImageUrl(), startsWith("data:image/jpeg;base64,")); } @Test @@ -94,7 +96,7 @@ public void imageFileDetectsJpegFromJpegExtension() throws IOException { Path file = tmp.resolve("photo.jpeg"); Files.write(file, new byte[] {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF}); ContentPart p = ContentPart.imageFile(file); - assertTrue(p.getImageUrl().startsWith("data:image/jpeg;base64,")); + assertThat(p.getImageUrl(), startsWith("data:image/jpeg;base64,")); } @Test @@ -102,7 +104,7 @@ public void imageFileDetectsWebp() throws IOException { Path file = tmp.resolve("img.webp"); Files.write(file, new byte[] {0x52, 0x49, 0x46, 0x46}); ContentPart p = ContentPart.imageFile(file); - assertTrue(p.getImageUrl().startsWith("data:image/webp;base64,")); + assertThat(p.getImageUrl(), startsWith("data:image/webp;base64,")); } @Test @@ -110,7 +112,7 @@ public void imageFileDetectsGif() throws IOException { Path file = tmp.resolve("anim.gif"); Files.write(file, new byte[] {0x47, 0x49, 0x46, 0x38}); ContentPart p = ContentPart.imageFile(file); - assertTrue(p.getImageUrl().startsWith("data:image/gif;base64,")); + assertThat(p.getImageUrl(), startsWith("data:image/gif;base64,")); } @Test @@ -121,7 +123,7 @@ public void imageFileRejectsUnknownExtension() throws IOException { ContentPart.imageFile(file); fail("expected IllegalArgumentException for unknown extension"); } catch (IllegalArgumentException expected) { - assertNotNull(expected.getMessage()); + assertThat(expected.getMessage(), is(notNullValue())); } } } diff --git a/src/test/java/net/ladenthin/llama/LlamaOutputTest.java b/src/test/java/net/ladenthin/llama/value/LlamaOutputTest.java similarity index 99% rename from src/test/java/net/ladenthin/llama/LlamaOutputTest.java rename to src/test/java/net/ladenthin/llama/value/LlamaOutputTest.java index 744be815..bcc79cdc 100644 --- a/src/test/java/net/ladenthin/llama/LlamaOutputTest.java +++ b/src/test/java/net/ladenthin/llama/value/LlamaOutputTest.java @@ -3,13 +3,14 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.*; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.json.CompletionResponseParser; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/ladenthin/llama/LogLevelTest.java b/src/test/java/net/ladenthin/llama/value/LogLevelTest.java similarity index 69% rename from src/test/java/net/ladenthin/llama/LogLevelTest.java rename to src/test/java/net/ladenthin/llama/value/LogLevelTest.java index ee8c9a97..94f7cdbf 100644 --- a/src/test/java/net/ladenthin/llama/LogLevelTest.java +++ b/src/test/java/net/ladenthin/llama/value/LogLevelTest.java @@ -3,10 +3,11 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.*; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -40,11 +41,13 @@ public void testError() { } @Test - public void testOrdinalOrder() { - // Log levels must be ordered from least to most severe - assertTrue(LogLevel.DEBUG.ordinal() < LogLevel.INFO.ordinal()); - assertTrue(LogLevel.INFO.ordinal() < LogLevel.WARN.ordinal()); - assertTrue(LogLevel.WARN.ordinal() < LogLevel.ERROR.ordinal()); + public void testDeclarationOrder() { + // Declared from least to most severe; the order is part of the contract + // (mirrors llama.cpp's native log-level severity). values() returns the + // constants in declaration order, so this pins the full order without + // depending on Enum.ordinal() (Error Prone EnumOrdinal). + assertArrayEquals( + new LogLevel[] {LogLevel.DEBUG, LogLevel.INFO, LogLevel.WARN, LogLevel.ERROR}, LogLevel.values()); } @Test diff --git a/src/test/java/net/ladenthin/llama/ModelMetaTest.java b/src/test/java/net/ladenthin/llama/value/ModelMetaTest.java similarity index 70% rename from src/test/java/net/ladenthin/llama/ModelMetaTest.java rename to src/test/java/net/ladenthin/llama/value/ModelMetaTest.java index bd733de4..12ad7aca 100644 --- a/src/test/java/net/ladenthin/llama/ModelMetaTest.java +++ b/src/test/java/net/ladenthin/llama/value/ModelMetaTest.java @@ -3,11 +3,14 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; -import static org.junit.jupiter.api.Assertions.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; import com.fasterxml.jackson.databind.ObjectMapper; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; /** @@ -32,12 +35,12 @@ public void testNumericGetters() throws Exception { + "\"modalities\":{\"vision\":false,\"audio\":false}," + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); - assertEquals(1, meta.getVocabType()); - assertEquals(32016, meta.getNVocab()); - assertEquals(16384, meta.getNCtxTrain()); - assertEquals(4096, meta.getNEmbd()); - assertEquals(6738546688L, meta.getNParams()); - assertEquals(2825274880L, meta.getSize()); + assertThat(meta.getVocabType(), is(1)); + assertThat(meta.getNVocab(), is(32016)); + assertThat(meta.getNCtxTrain(), is(16384)); + assertThat(meta.getNEmbd(), is(4096)); + assertThat(meta.getNParams(), is(6738546688L)); + assertThat(meta.getSize(), is(2825274880L)); } @Test @@ -46,15 +49,15 @@ public void testModalityGetters() throws Exception { + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + "\"modalities\":{\"vision\":false,\"audio\":false}," + "\"architecture\":\"llama\",\"name\":\"\"}"); - assertFalse(textOnly.supportsVision()); - assertFalse(textOnly.supportsAudio()); + assertThat(textOnly.supportsVision(), is(false)); + assertThat(textOnly.supportsAudio(), is(false)); ModelMeta multimodal = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + "\"modalities\":{\"vision\":true,\"audio\":true}," + "\"architecture\":\"gemma3\",\"name\":\"Gemma-3\"}"); - assertTrue(multimodal.supportsVision()); - assertTrue(multimodal.supportsAudio()); + assertThat(multimodal.supportsVision(), is(true)); + assertThat(multimodal.supportsAudio(), is(true)); } @Test @@ -64,7 +67,7 @@ public void testGetArchitecture() throws Exception { + "\"modalities\":{\"vision\":false,\"audio\":false}," + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); - assertEquals("llama", meta.getArchitecture()); + assertThat(meta.getArchitecture(), is("llama")); } @Test @@ -74,7 +77,7 @@ public void testGetModelName() throws Exception { + "\"modalities\":{\"vision\":false,\"audio\":false}," + "\"architecture\":\"mistral\",\"name\":\"Mistral-7B-v0.1\"}"); - assertEquals("Mistral-7B-v0.1", meta.getModelName()); + assertThat(meta.getModelName(), is("Mistral-7B-v0.1")); } @Test @@ -83,7 +86,7 @@ public void testGetArchitectureEmptyWhenAbsent() throws Exception { + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + "\"modalities\":{\"vision\":false,\"audio\":false}}"); - assertEquals("", meta.getArchitecture()); + assertThat(meta.getArchitecture(), is("")); } @Test @@ -92,7 +95,7 @@ public void testGetModelNameEmptyWhenAbsent() throws Exception { + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + "\"modalities\":{\"vision\":false,\"audio\":false}}"); - assertEquals("", meta.getModelName()); + assertThat(meta.getModelName(), is("")); } @Test @@ -103,10 +106,20 @@ public void testGetArchitectureVariousModels() throws Exception { + "\"modalities\":{\"vision\":false,\"audio\":false}," + "\"architecture\":\"" + arch + "\",\"name\":\"\"}"); - assertEquals(arch, meta.getArchitecture()); + assertThat(meta.getArchitecture(), is(arch)); } } + @Test + public void testAsJsonReturnsBackingNode() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," + + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); + // Dereferencing the returned node kills the "return null" mutant on asJson(). + assertThat(meta.asJson().get("architecture").asText(), is("llama")); + } + @Test public void testToStringContainsNewFields() throws Exception { ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," @@ -115,9 +128,9 @@ public void testToStringContainsNewFields() throws Exception { + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); String json = meta.toString(); - assertTrue(json.contains("\"architecture\"")); - assertTrue(json.contains("\"name\"")); - assertTrue(json.contains("\"llama\"")); - assertTrue(json.contains("\"CodeLlama-7B\"")); + assertThat(json, containsString("\"architecture\"")); + assertThat(json, containsString("\"name\"")); + assertThat(json, containsString("\"llama\"")); + assertThat(json, containsString("\"CodeLlama-7B\"")); } } diff --git a/src/test/java/net/ladenthin/llama/PairTest.java b/src/test/java/net/ladenthin/llama/value/PairTest.java similarity index 99% rename from src/test/java/net/ladenthin/llama/PairTest.java rename to src/test/java/net/ladenthin/llama/value/PairTest.java index fd31efc0..fdd0a66e 100644 --- a/src/test/java/net/ladenthin/llama/PairTest.java +++ b/src/test/java/net/ladenthin/llama/value/PairTest.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.*; diff --git a/src/test/java/net/ladenthin/llama/ServerMetricsTest.java b/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java similarity index 77% rename from src/test/java/net/ladenthin/llama/ServerMetricsTest.java rename to src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java index dfff3b96..6d5989af 100644 --- a/src/test/java/net/ladenthin/llama/ServerMetricsTest.java +++ b/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java @@ -2,12 +2,13 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import com.fasterxml.jackson.databind.ObjectMapper; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated( @@ -96,4 +97,26 @@ public void missingFieldsDefaultToZero() throws Exception { assertEquals(0, m.getTokensMax()); assertEquals(0L, m.getCumulativeUsage().getTotalTokens()); } + + @Test + public void cumulativeTimingsZeroPredictedMsYieldsZeroRate() throws Exception { + // Pins the predictedMs > 0.0 boundary: with predictedN>0 but predictedMs=0 the rate must be 0.0 + // (a >= boundary mutant would divide by zero and produce a non-zero / NaN rate). + ServerMetrics m = parse("{\"n_tokens_predicted_total\":5,\"t_tokens_generation_total\":0}"); + assertEquals(0.0, m.getCumulativeTimings().getPredictedPerSecond(), 1e-9); + } + + @Test + public void asJsonExposesBackingNode() throws Exception { + ServerMetrics m = parse(SAMPLE); + // Dereferencing the returned node kills the "return null" mutant on asJson(). + assertEquals(2, m.asJson().get("idle").asInt()); + } + + @Test + public void toStringSerializesNode() throws Exception { + ServerMetrics m = parse(SAMPLE); + // Assert content (not just non-null) so the empty-string return mutant on toString is killed. + assertTrue(m.toString().contains("idle")); + } } diff --git a/src/test/java/net/ladenthin/llama/StopReasonTest.java b/src/test/java/net/ladenthin/llama/value/StopReasonTest.java similarity index 98% rename from src/test/java/net/ladenthin/llama/StopReasonTest.java rename to src/test/java/net/ladenthin/llama/value/StopReasonTest.java index 9849598a..537ef3cb 100644 --- a/src/test/java/net/ladenthin/llama/StopReasonTest.java +++ b/src/test/java/net/ladenthin/llama/value/StopReasonTest.java @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.*; diff --git a/src/test/java/net/ladenthin/llama/TimingsTest.java b/src/test/java/net/ladenthin/llama/value/TimingsTest.java similarity index 96% rename from src/test/java/net/ladenthin/llama/TimingsTest.java rename to src/test/java/net/ladenthin/llama/value/TimingsTest.java index e279163f..2470e25f 100644 --- a/src/test/java/net/ladenthin/llama/TimingsTest.java +++ b/src/test/java/net/ladenthin/llama/value/TimingsTest.java @@ -2,11 +2,12 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.assertEquals; import com.fasterxml.jackson.databind.ObjectMapper; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated(purpose = "Verify Timings.fromJson maps every result_timings field and treats missing nodes as zero.") diff --git a/src/test/java/net/ladenthin/llama/TokenLogprobTest.java b/src/test/java/net/ladenthin/llama/value/TokenLogprobTest.java similarity index 63% rename from src/test/java/net/ladenthin/llama/TokenLogprobTest.java rename to src/test/java/net/ladenthin/llama/value/TokenLogprobTest.java index 2ef36a59..ba2ca362 100644 --- a/src/test/java/net/ladenthin/llama/TokenLogprobTest.java +++ b/src/test/java/net/ladenthin/llama/value/TokenLogprobTest.java @@ -2,13 +2,20 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.json.CompletionResponseParser; import org.junit.jupiter.api.Test; @@ -23,7 +30,7 @@ public class TokenLogprobTest { @Test public void emptyWhenAbsent() { LlamaOutput out = parser.parse("{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"}"); - assertTrue(out.logprobs.isEmpty()); + assertThat(out.logprobs, is(empty())); } @Test @@ -35,14 +42,14 @@ public void parsesPostSamplingWithTopProbs() { + " {\"token\":\"Hey\",\"id\":12,\"prob\":0.05}]}" + "]}"; LlamaOutput out = parser.parse(json); - assertEquals(1, out.logprobs.size()); + assertThat(out.logprobs, hasSize(1)); TokenLogprob first = out.logprobs.get(0); - assertEquals("Hello", first.getToken()); - assertEquals(15043, first.getTokenId()); + assertThat(first.getToken(), is("Hello")); + assertThat(first.getTokenId(), is(15043)); assertEquals(0.82f, first.getLogprob(), 1e-4f); - assertEquals(2, first.getTopLogprobs().size()); - assertEquals("Hi", first.getTopLogprobs().get(0).getToken()); - assertEquals(9932, first.getTopLogprobs().get(0).getTokenId()); + assertThat(first.getTopLogprobs(), hasSize(2)); + assertThat(first.getTopLogprobs().get(0).getToken(), is("Hi")); + assertThat(first.getTopLogprobs().get(0).getTokenId(), is(9932)); assertEquals(0.10f, first.getTopLogprobs().get(0).getLogprob(), 1e-4f); } @@ -54,10 +61,10 @@ public void parsesPreSamplingWithTopLogprobs() { + "\"top_logprobs\":[{\"token\":\"Hi\",\"id\":9932,\"logprob\":-2.3}]}" + "]}"; LlamaOutput out = parser.parse(json); - assertEquals(1, out.logprobs.size()); + assertThat(out.logprobs, hasSize(1)); TokenLogprob first = out.logprobs.get(0); assertEquals(-0.20f, first.getLogprob(), 1e-4f); - assertEquals(1, first.getTopLogprobs().size()); + assertThat(first.getTopLogprobs(), hasSize(1)); assertEquals(-2.3f, first.getTopLogprobs().get(0).getLogprob(), 1e-4f); } @@ -70,9 +77,9 @@ public void preservesOrder() { + "{\"token\":\"C\",\"id\":3,\"prob\":0.1}" + "]}"; List lp = parser.parse(json).logprobs; - assertEquals("A", lp.get(0).getToken()); - assertEquals("B", lp.get(1).getToken()); - assertEquals("C", lp.get(2).getToken()); + assertThat(lp.get(0).getToken(), is("A")); + assertThat(lp.get(1).getToken(), is("B")); + assertThat(lp.get(2).getToken(), is("C")); } @Test @@ -82,15 +89,29 @@ public void mapAndListBothPopulated() { + "{\"token\":\"hello\",\"id\":1,\"prob\":0.9}" + "]}"; LlamaOutput out = parser.parse(json); - assertEquals(1, out.logprobs.size()); + assertThat(out.logprobs, hasSize(1)); assertEquals(0.9f, out.probabilities.get("hello"), 1e-4f); } + @Test + public void toStringIncludesTopLogprobCount() { + // The private @ToString.Include topLogprobsSize() is only reachable through toString(); + // rendering "top=2" kills the "return 0" primitive mutant on that helper. + TokenLogprob tl = new TokenLogprob( + "t", + 1, + 0.5f, + Arrays.asList( + new TokenLogprob("a", 2, 0.1f, Collections.emptyList()), + new TokenLogprob("b", 3, 0.2f, Collections.emptyList()))); + assertThat(tl.toString(), containsString("top=2")); + } + @Test public void backwardsCompatibleConstructor() { LlamaOutput out = new LlamaOutput("hi", java.util.Collections.emptyMap(), false, StopReason.NONE); - assertNotNull(out.logprobs); - assertTrue(out.logprobs.isEmpty()); + assertThat(out.logprobs, is(notNullValue())); + assertThat(out.logprobs, is(empty())); } } diff --git a/src/test/java/net/ladenthin/llama/value/ToolCallTest.java b/src/test/java/net/ladenthin/llama/value/ToolCallTest.java new file mode 100644 index 00000000..cbdf28e8 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/value/ToolCallTest.java @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.value; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +@ClaudeGenerated( + purpose = "Pin ToolCall's id/name/argumentsJson accessors, its hand-written function-call toString " + + "(name(args)[id]), and its Lombok value-equality so every mutation is covered.") +public class ToolCallTest { + + @Test + public void accessorsReturnConstructorValues() { + ToolCall tc = new ToolCall("call_1", "get_weather", "{\"city\":\"Berlin\"}"); + assertThat(tc.getId(), is("call_1")); + assertThat(tc.getName(), is("get_weather")); + assertThat(tc.getArgumentsJson(), is("{\"city\":\"Berlin\"}")); + } + + @Test + public void toStringRendersFunctionCallSyntax() { + // Hand-written toString: name(argsJson)[id] — assert the exact string so the + // empty-return mutant ("") and any field-omission mutant are killed. + ToolCall tc = new ToolCall("c1", "add", "{\"a\":2}"); + assertThat(tc.toString(), is("add({\"a\":2})[c1]")); + } + + @Test + public void equalsAndHashCodeAreValueBased() { + ToolCall a = new ToolCall("c1", "add", "{}"); + ToolCall b = new ToolCall("c1", "add", "{}"); + assertThat(a, is(b)); + assertThat(a.hashCode(), is(b.hashCode())); + } + + @Test + public void differingNameBreaksEquality() { + assertThat(new ToolCall("c1", "add", "{}"), is(not(new ToolCall("c1", "sub", "{}")))); + } +} diff --git a/src/test/java/net/ladenthin/llama/value/ToolDefinitionTest.java b/src/test/java/net/ladenthin/llama/value/ToolDefinitionTest.java new file mode 100644 index 00000000..8edd1577 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/value/ToolDefinitionTest.java @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.value; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +import net.ladenthin.llama.ClaudeGenerated; +import org.junit.jupiter.api.Test; + +@ClaudeGenerated( + purpose = "Pin ToolDefinition's name/description/parametersSchemaJson accessors to distinct non-empty " + + "values plus its Lombok toString/equals so every getter mutation is covered.") +public class ToolDefinitionTest { + + private static final String SCHEMA = "{\"type\":\"object\",\"properties\":{\"s\":{\"type\":\"string\"}}}"; + + @Test + public void accessorsReturnConstructorValues() { + ToolDefinition d = new ToolDefinition("echo", "Echo a string", SCHEMA); + assertThat(d.getName(), is("echo")); + // A distinct non-empty value kills the empty-string return mutant on getDescription. + assertThat(d.getDescription(), is("Echo a string")); + assertThat(d.getParametersSchemaJson(), is(SCHEMA)); + } + + @Test + public void toStringRendersAllFields() { + ToolDefinition d = new ToolDefinition("echo", "Echo a string", SCHEMA); + String s = d.toString(); + assertThat(s, containsString("echo")); + assertThat(s, containsString("Echo a string")); + } + + @Test + public void equalsAndHashCodeAreValueBased() { + ToolDefinition a = new ToolDefinition("echo", "d", "{}"); + ToolDefinition b = new ToolDefinition("echo", "d", "{}"); + assertThat(a, is(b)); + assertThat(a.hashCode(), is(b.hashCode())); + } + + @Test + public void differingDescriptionBreaksEquality() { + assertThat(new ToolDefinition("echo", "d1", "{}"), is(not(new ToolDefinition("echo", "d2", "{}")))); + } +} diff --git a/src/test/java/net/ladenthin/llama/UsageTest.java b/src/test/java/net/ladenthin/llama/value/UsageTest.java similarity index 92% rename from src/test/java/net/ladenthin/llama/UsageTest.java rename to src/test/java/net/ladenthin/llama/value/UsageTest.java index 04a7e03d..fd430809 100644 --- a/src/test/java/net/ladenthin/llama/UsageTest.java +++ b/src/test/java/net/ladenthin/llama/value/UsageTest.java @@ -2,11 +2,12 @@ // // SPDX-License-Identifier: MIT -package net.ladenthin.llama; +package net.ladenthin.llama.value; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import net.ladenthin.llama.ClaudeGenerated; import org.junit.jupiter.api.Test; @ClaudeGenerated(purpose = "Verify Usage records prompt/completion totals correctly and derives totalTokens.")