From 65a390d4f2e8045c2957fcf57291f16d379fe1ed Mon Sep 17 00:00:00 2001 From: wisdgod Date: Tue, 23 Dec 2025 11:18:28 +0800 Subject: [PATCH] 0.4.0-pre.14 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a special version (since the repository hasn't been updated for a while). It includes partial updates from 0.3 to 0.4, along with several fixes for 0.4.0-pre.13. 这是一个特殊版本(因为一段时间没有更新存储库),它包含0.3至0.4的部分更新以及对0.4.0-pre.13的几处修复。 --- .env.example | 80 +- .gitignore | 6 + .rust-toolchain.toml | 2 + .rustfmt.toml | 92 +- Cargo.toml | 242 +- Dockerfile | 2 +- LICENSE | 6 - LICENSE-APACHE | 202 -- LICENSE-MIT | 7 - LICENSE.md | 20 + README.md | 1687 +++++++----- VERSION | 2 +- build.rs | 178 +- build_info.rs | 201 +- crates/grpc-stream/Cargo.toml | 13 + crates/grpc-stream/src/buffer.rs | 195 ++ crates/grpc-stream/src/compression.rs | 154 ++ crates/grpc-stream/src/decoder.rs | 135 + crates/grpc-stream/src/frame.rs | 54 + crates/grpc-stream/src/lib.rs | 46 + crates/interned/Cargo.toml | 54 + crates/interned/src/arc_str.rs | 1119 ++++++++ crates/interned/src/lib.rs | 29 + crates/interned/src/str.rs | 1248 +++++++++ crates/manually_init/Cargo.toml | 17 + crates/manually_init/src/lib.rs | 363 +++ crates/rep_move/Cargo.toml | 10 + crates/rep_move/src/lib.rs | 246 ++ patch/chrono-0.4.41/tests/dateutils.rs | 165 -- patch/chrono-0.4.41/tests/wasm.rs | 89 - patch/chrono-0.4.41/tests/win_bindings.rs | 28 - .../CITATION.cff | 0 .../Cargo.toml | 12 +- .../LICENSE.txt | 0 .../README.md | 0 .../src/date.rs | 6 +- .../src/datetime/mod.rs | 65 +- .../src/datetime/serde.rs | 20 +- .../src/datetime/tests.rs | 37 +- .../src/format/formatting.rs | 29 +- .../src/format/locales.rs | 0 .../src/format/mod.rs | 6 +- .../src/format/parse.rs | 2 +- .../src/format/parsed.rs | 2 +- .../src/format/scan.rs | 0 .../src/format/strftime.rs | 621 +++-- .../src/lib.rs | 9 +- .../src/month.rs | 7 +- .../src/naive/date/mod.rs | 54 +- .../src/naive/date/tests.rs | 45 + 
.../src/naive/datetime/mod.rs | 4 +- .../src/naive/datetime/serde.rs | 8 +- .../src/naive/datetime/tests.rs | 0 .../src/naive/internals.rs | 0 .../src/naive/isoweek.rs | 8 +- .../src/naive/mod.rs | 0 .../src/naive/time/mod.rs | 4 +- .../src/naive/time/serde.rs | 0 .../src/naive/time/tests.rs | 17 +- .../src/offset/fixed.rs | 4 +- .../src/offset/local/mod.rs | 16 +- .../src/offset/local/tz_data.rs | 0 .../src/offset/local/tz_info/mod.rs | 0 .../src/offset/local/tz_info/parser.rs | 0 .../src/offset/local/tz_info/rule.rs | 0 .../src/offset/local/tz_info/timezone.rs | 2 +- .../src/offset/local/unix.rs | 0 .../src/offset/local/win_bindings.rs | 0 .../src/offset/local/win_bindings.txt | 0 .../src/offset/local/windows.rs | 0 .../src/offset/mod.rs | 2 +- .../src/offset/utc.rs | 4 +- .../src/round.rs | 14 +- .../src/time_delta.rs | 8 +- .../src/traits.rs | 12 +- .../src/weekday.rs | 7 +- .../src/weekday_set.rs | 0 patch/dotenvy-0.15.7/src/iter.rs | 15 +- patch/macros/Cargo.toml | 10 + patch/macros/src/lib.rs | 105 + patch/prost-0.14.1/Cargo.toml | 87 + patch/prost-0.14.1/Cargo.toml.orig | 35 + patch/prost-0.14.1/LICENSE | 201 ++ patch/prost-0.14.1/README.md | 507 ++++ patch/prost-0.14.1/src/byte_str.rs | 366 +++ patch/prost-0.14.1/src/encoding.rs | 1471 +++++++++++ .../prost-0.14.1/src/encoding/fixed_width.rs | 31 + .../src/encoding/length_delimiter.rs | 46 + patch/prost-0.14.1/src/encoding/utf8.rs | 216 ++ patch/prost-0.14.1/src/encoding/utf8/ascii.rs | 1847 +++++++++++++ .../src/encoding/utf8/simd_funcs.rs | 347 +++ patch/prost-0.14.1/src/encoding/varint.rs | 667 +++++ patch/prost-0.14.1/src/encoding/wire_type.rs | 70 + patch/prost-0.14.1/src/error.rs | 180 ++ patch/prost-0.14.1/src/lib.rs | 54 + patch/prost-0.14.1/src/message.rs | 184 ++ patch/prost-0.14.1/src/name.rs | 34 + patch/prost-0.14.1/src/types.rs | 573 ++++ patch/prost-derive/Cargo.toml | 25 + patch/prost-derive/LICENSE | 201 ++ patch/prost-derive/README.md | 16 + patch/prost-derive/src/field/group.rs | 137 + 
patch/prost-derive/src/field/map.rs | 411 +++ patch/prost-derive/src/field/message.rs | 134 + patch/prost-derive/src/field/mod.rs | 356 +++ patch/prost-derive/src/field/oneof.rs | 90 + patch/prost-derive/src/field/scalar.rs | 842 ++++++ patch/prost-derive/src/lib.rs | 691 +++++ patch/prost-types/Cargo.toml | 34 + patch/prost-types/Cargo.toml.orig | 29 + patch/prost-types/LICENSE | 0 patch/prost-types/README.md | 21 + patch/prost-types/src/any.rs | 69 + patch/prost-types/src/compiler.rs | 175 ++ patch/prost-types/src/conversions.rs | 62 + patch/prost-types/src/datetime.rs | 1024 ++++++++ patch/prost-types/src/duration.rs | 481 ++++ patch/prost-types/src/lib.rs | 84 + patch/prost-types/src/protobuf.rs | 2319 +++++++++++++++++ patch/prost-types/src/timestamp.rs | 431 +++ patch/prost-types/src/type_url.rs | 70 + patch/reqwest-0.12.18/CHANGELOG.md | 17 + patch/reqwest-0.12.18/Cargo.toml | 9 +- patch/reqwest-0.12.18/README.md | 2 +- patch/reqwest-0.12.18/src/async_impl/body.rs | 38 +- .../reqwest-0.12.18/src/async_impl/client.rs | 645 ++--- .../src/async_impl/h3_client/mod.rs | 91 +- .../src/async_impl/multipart.rs | 184 +- .../reqwest-0.12.18/src/async_impl/request.rs | 222 +- .../src/async_impl/response.rs | 16 +- patch/reqwest-0.12.18/src/blocking/client.rs | 221 +- .../reqwest-0.12.18/src/blocking/multipart.rs | 140 +- patch/reqwest-0.12.18/src/config.rs | 19 +- patch/reqwest-0.12.18/src/connect.rs | 459 +++- patch/reqwest-0.12.18/src/cookie.rs | 154 +- patch/reqwest-0.12.18/src/dns/hickory.rs | 14 +- patch/reqwest-0.12.18/src/dns/mod.rs | 3 + patch/reqwest-0.12.18/src/dns/resolve.rs | 59 +- patch/reqwest-0.12.18/src/lib.rs | 11 +- patch/reqwest-0.12.18/src/proxy.rs | 97 +- patch/reqwest-0.12.18/src/retry.rs | 447 ++++ patch/reqwest-0.12.18/src/util.rs | 47 +- patch/reqwest-0.12.18/src/wasm/client.rs | 85 +- patch/reqwest-0.12.18/src/wasm/request.rs | 136 +- patch/rkyv-0.8.12/Cargo.toml | 257 ++ patch/rkyv-0.8.12/Cargo.toml.orig | 87 + patch/rkyv-0.8.12/LICENSE | 7 + 
patch/rkyv-0.8.12/README.md | 98 + .../rkyv-0.8.12/examples/backwards_compat.rs | 85 + .../examples/complex_wrapper_types.rs | 195 ++ .../examples/derive_partial_ord.rs | 42 + .../examples/explicit_enum_discriminants.rs | 14 + .../rkyv-0.8.12/examples/json_like_schema.rs | 115 + patch/rkyv-0.8.12/examples/niching.rs | 136 + patch/rkyv-0.8.12/examples/readme.rs | 47 + patch/rkyv-0.8.12/examples/remote_types.rs | 147 ++ patch/rkyv-0.8.12/src/_macros.rs | 22 + patch/rkyv-0.8.12/src/alias.rs | 32 + patch/rkyv-0.8.12/src/api/checked.rs | 192 ++ patch/rkyv-0.8.12/src/api/high/checked.rs | 258 ++ patch/rkyv-0.8.12/src/api/high/mod.rs | 313 +++ patch/rkyv-0.8.12/src/api/low/checked.rs | 313 +++ patch/rkyv-0.8.12/src/api/low/mod.rs | 193 ++ patch/rkyv-0.8.12/src/api/mod.rs | 416 +++ .../rkyv-0.8.12/src/api/test/inner_checked.rs | 65 + .../src/api/test/inner_unchecked.rs | 54 + patch/rkyv-0.8.12/src/api/test/mod.rs | 31 + patch/rkyv-0.8.12/src/api/test/outer_high.rs | 32 + patch/rkyv-0.8.12/src/api/test/outer_low.rs | 41 + patch/rkyv-0.8.12/src/boxed.rs | 192 ++ .../src/collections/btree/map/iter.rs | 695 +++++ .../src/collections/btree/map/mod.rs | 1219 +++++++++ .../rkyv-0.8.12/src/collections/btree/mod.rs | 2 + .../rkyv-0.8.12/src/collections/btree/set.rs | 151 ++ patch/rkyv-0.8.12/src/collections/mod.rs | 6 + .../src/collections/swiss_table/index_map.rs | 528 ++++ .../src/collections/swiss_table/index_set.rs | 142 + .../src/collections/swiss_table/map.rs | 421 +++ .../src/collections/swiss_table/mod.rs | 13 + .../src/collections/swiss_table/set.rs | 117 + .../src/collections/swiss_table/table.rs | 712 +++++ patch/rkyv-0.8.12/src/collections/util.rs | 108 + patch/rkyv-0.8.12/src/de/mod.rs | 6 + patch/rkyv-0.8.12/src/de/pooling/alloc.rs | 116 + patch/rkyv-0.8.12/src/de/pooling/core.rs | 21 + patch/rkyv-0.8.12/src/de/pooling/mod.rs | 281 ++ patch/rkyv-0.8.12/src/ffi.rs | 197 ++ patch/rkyv-0.8.12/src/hash.rs | 125 + patch/rkyv-0.8.12/src/impls/alloc/boxed.rs | 171 ++ 
.../src/impls/alloc/collections/btree_map.rs | 312 +++ .../src/impls/alloc/collections/btree_set.rs | 134 + .../src/impls/alloc/collections/mod.rs | 3 + .../src/impls/alloc/collections/vec_deque.rs | 117 + patch/rkyv-0.8.12/src/impls/alloc/ffi.rs | 87 + patch/rkyv-0.8.12/src/impls/alloc/mod.rs | 7 + .../rkyv-0.8.12/src/impls/alloc/rc/atomic.rs | 154 ++ patch/rkyv-0.8.12/src/impls/alloc/rc/mod.rs | 465 ++++ patch/rkyv-0.8.12/src/impls/alloc/string.rs | 96 + patch/rkyv-0.8.12/src/impls/alloc/vec.rs | 105 + patch/rkyv-0.8.12/src/impls/alloc/with.rs | 834 ++++++ patch/rkyv-0.8.12/src/impls/core/ffi.rs | 88 + patch/rkyv-0.8.12/src/impls/core/mod.rs | 574 ++++ patch/rkyv-0.8.12/src/impls/core/net.rs | 520 ++++ patch/rkyv-0.8.12/src/impls/core/ops.rs | 443 ++++ patch/rkyv-0.8.12/src/impls/core/option.rs | 99 + patch/rkyv-0.8.12/src/impls/core/primitive.rs | 458 ++++ patch/rkyv-0.8.12/src/impls/core/result.rs | 111 + patch/rkyv-0.8.12/src/impls/core/time.rs | 89 + .../src/impls/core/with/atomic/_macros.rs | 17 + .../src/impls/core/with/atomic/mod.rs | 83 + .../src/impls/core/with/atomic/multibyte.rs | 169 ++ patch/rkyv-0.8.12/src/impls/core/with/mod.rs | 1257 +++++++++ .../src/impls/core/with/niching.rs | 373 +++ .../rkyv-0.8.12/src/impls/ext/arrayvec_0_7.rs | 89 + patch/rkyv-0.8.12/src/impls/ext/bytes_1.rs | 55 + .../src/impls/ext/hashbrown_0_14/hash_map.rs | 122 + .../src/impls/ext/hashbrown_0_14/hash_set.rs | 117 + .../src/impls/ext/hashbrown_0_14/mod.rs | 3 + .../src/impls/ext/hashbrown_0_14/with.rs | 136 + .../src/impls/ext/hashbrown_0_16/hash_map.rs | 122 + .../src/impls/ext/hashbrown_0_16/hash_set.rs | 117 + .../src/impls/ext/hashbrown_0_16/mod.rs | 6 + .../src/impls/ext/hashbrown_0_16/with.rs | 136 + .../src/impls/ext/indexmap_2/index_map.rs | 110 + .../src/impls/ext/indexmap_2/index_set.rs | 94 + .../src/impls/ext/indexmap_2/mod.rs | 2 + patch/rkyv-0.8.12/src/impls/ext/mod.rs | 32 + patch/rkyv-0.8.12/src/impls/ext/smallvec_1.rs | 92 + 
.../rkyv-0.8.12/src/impls/ext/smolstr_0_2.rs | 60 + .../rkyv-0.8.12/src/impls/ext/smolstr_0_3.rs | 60 + .../rkyv-0.8.12/src/impls/ext/thin_vec_0_2.rs | 104 + patch/rkyv-0.8.12/src/impls/ext/tinyvec_1.rs | 223 ++ .../rkyv-0.8.12/src/impls/ext/triomphe_0_1.rs | 93 + patch/rkyv-0.8.12/src/impls/ext/uuid_1.rs | 51 + patch/rkyv-0.8.12/src/impls/mod.rs | 807 ++++++ patch/rkyv-0.8.12/src/impls/rend.rs | 183 ++ .../src/impls/std/collections/hash_map.rs | 299 +++ .../src/impls/std/collections/hash_set.rs | 125 + .../src/impls/std/collections/mod.rs | 2 + patch/rkyv-0.8.12/src/impls/std/mod.rs | 3 + patch/rkyv-0.8.12/src/impls/std/net.rs | 30 + patch/rkyv-0.8.12/src/impls/std/with.rs | 700 +++++ patch/rkyv-0.8.12/src/lib.rs | 327 +++ patch/rkyv-0.8.12/src/net.rs | 459 ++++ patch/rkyv-0.8.12/src/niche/mod.rs | 6 + patch/rkyv-0.8.12/src/niche/niched_option.rs | 218 ++ patch/rkyv-0.8.12/src/niche/niching.rs | 101 + patch/rkyv-0.8.12/src/niche/option_box.rs | 256 ++ patch/rkyv-0.8.12/src/niche/option_nonzero.rs | 242 ++ patch/rkyv-0.8.12/src/ops.rs | 286 ++ patch/rkyv-0.8.12/src/option.rs | 316 +++ patch/rkyv-0.8.12/src/place.rs | 177 ++ patch/rkyv-0.8.12/src/polyfill.rs | 14 + patch/rkyv-0.8.12/src/primitive.rs | 210 ++ patch/rkyv-0.8.12/src/rc.rs | 398 +++ patch/rkyv-0.8.12/src/rel_ptr.rs | 643 +++++ patch/rkyv-0.8.12/src/result.rs | 200 ++ patch/rkyv-0.8.12/src/seal.rs | 119 + patch/rkyv-0.8.12/src/ser/allocator/alloc.rs | 323 +++ patch/rkyv-0.8.12/src/ser/allocator/core.rs | 100 + patch/rkyv-0.8.12/src/ser/allocator/mod.rs | 245 ++ patch/rkyv-0.8.12/src/ser/mod.rs | 91 + patch/rkyv-0.8.12/src/ser/sharing/alloc.rs | 96 + patch/rkyv-0.8.12/src/ser/sharing/core.rs | 16 + patch/rkyv-0.8.12/src/ser/sharing/mod.rs | 112 + patch/rkyv-0.8.12/src/ser/writer/alloc.rs | 33 + patch/rkyv-0.8.12/src/ser/writer/core.rs | 205 ++ patch/rkyv-0.8.12/src/ser/writer/mod.rs | 204 ++ patch/rkyv-0.8.12/src/ser/writer/std.rs | 87 + patch/rkyv-0.8.12/src/simd/generic.rs | 106 + 
patch/rkyv-0.8.12/src/simd/mod.rs | 69 + patch/rkyv-0.8.12/src/simd/neon.rs | 94 + patch/rkyv-0.8.12/src/simd/sse2.rs | 76 + patch/rkyv-0.8.12/src/string/mod.rs | 304 +++ patch/rkyv-0.8.12/src/string/repr.rs | 301 +++ patch/rkyv-0.8.12/src/time.rs | 190 ++ patch/rkyv-0.8.12/src/traits.rs | 483 ++++ patch/rkyv-0.8.12/src/tuple.rs | 142 + .../rkyv-0.8.12/src/util/alloc/aligned_vec.rs | 1033 ++++++++ patch/rkyv-0.8.12/src/util/alloc/arena.rs | 122 + patch/rkyv-0.8.12/src/util/alloc/mod.rs | 4 + patch/rkyv-0.8.12/src/util/inline_vec.rs | 407 +++ patch/rkyv-0.8.12/src/util/mod.rs | 36 + patch/rkyv-0.8.12/src/util/ser_vec.rs | 400 +++ .../rkyv-0.8.12/src/validation/archive/mod.rs | 148 ++ .../src/validation/archive/validator.rs | 188 ++ patch/rkyv-0.8.12/src/validation/mod.rs | 306 +++ .../rkyv-0.8.12/src/validation/shared/mod.rs | 69 + .../src/validation/shared/validator.rs | 137 + patch/rkyv-0.8.12/src/vec.rs | 315 +++ patch/rkyv-0.8.12/src/with.rs | 697 +++++ scripts/minify.js | 78 +- scripts/package-lock.json | 32 +- src/app.rs | 3 +- src/app/config.rs | 60 +- src/app/constant.rs | 321 +-- src/app/constant/header.rs | 150 +- src/app/constant/header/version.rs | 51 +- src/app/constant/status.rs | 2 +- src/app/frontend.rs | 684 +++++ src/app/lazy.rs | 342 ++- src/app/lazy/log.rs | 313 ++- src/app/lazy/path.rs | 37 + src/app/model.rs | 610 +++-- src/app/model/alias.rs | 20 +- src/app/model/build_key.rs | 23 +- src/app/model/checksum.rs | 46 +- src/app/model/config.rs | 234 +- src/app/model/context_fill_mode.rs | 70 + src/app/model/cpp.rs | 63 +- src/app/model/default_instructions.rs | 96 + src/app/model/exchange_map.rs | 49 + src/app/model/fetch_model.rs | 6 +- src/app/model/hash.rs | 2 +- src/app/model/id_source.rs | 62 + src/app/model/log.rs | 251 +- src/app/model/log/command.rs | 78 + src/app/model/log/limit.rs | 33 + src/app/model/log/manager.rs | 436 ++++ src/app/model/log/storage.rs | 31 + src/app/model/proxy.rs | 4 +- src/app/model/proxy_pool.rs | 297 +-- 
src/app/model/proxy_pool/proxy_url.rs | 51 +- src/app/model/state.rs | 60 +- src/app/model/state/log.rs | 158 +- src/app/model/state/page.rs | 26 +- src/app/model/state/token.rs | 307 ++- src/app/model/state/token/queue.rs | 290 +++ src/app/model/timestamp_header.rs | 159 +- src/app/model/token.rs | 314 +-- src/app/model/token/cache.rs | 155 +- src/app/model/token/provider.rs | 21 +- src/app/model/tz.rs | 49 +- src/app/model/usage_check.rs | 48 +- src/app/model/version.rs | 196 ++ src/app/route.rs | 174 ++ src/common.rs | 6 +- src/common/client.rs | 445 ++-- src/common/model.rs | 68 +- src/common/model/config.rs | 7 +- src/common/model/error.rs | 17 +- src/common/model/health.rs | 7 +- src/common/model/ntp.rs | 435 ++++ src/common/model/stringify.rs | 253 ++ src/common/model/token.rs | 49 +- src/common/model/tri.rs | 24 +- src/common/model/userinfo.rs | 280 +- src/common/model/userinfo/limit_type.rs | 51 + src/common/model/userinfo/membership_type.rs | 73 + src/common/model/userinfo/payment_id.rs | 31 +- src/common/model/userinfo/privacy_mode.rs | 51 + .../model/userinfo/subscription_status.rs | 8 +- src/common/model/userinfo/usage_event.rs | 162 ++ src/common/model/userinfo/usage_info.rs | 69 + src/common/time.rs | 339 +-- src/common/utils.rs | 926 +++---- src/common/utils/base62.rs | 371 ++- src/common/utils/base64.rs | 335 +-- src/common/utils/duration_fmt.rs | 904 +++---- src/common/utils/hex.rs | 92 +- src/common/utils/proto_encode.rs | 232 ++ src/common/utils/string_builder.rs | 12 +- src/core.rs | 2 +- src/core/adapter.rs | 121 +- src/core/adapter/anthropic.rs | 871 ++++--- src/core/adapter/error.rs | 180 ++ src/core/adapter/openai.rs | 559 ++-- src/core/aiserver/v1.rs | 97 +- src/core/aiserver/v1/aiserver.v1.rs | 1230 ++------- src/core/aiserver/v1/lite.proto | 417 +-- src/core/auth.rs | 11 + src/core/auth/error.rs | 112 + src/core/auth/middleware.rs | 152 ++ src/core/auth/model.rs | 5 + src/core/auth/utils.rs | 125 + src/core/config.rs | 87 +- 
src/core/config/key.rs | 214 +- src/core/constant.rs | 315 +-- src/core/constant/display_name.rs | 14 +- src/core/constant/display_name/formatter.rs | 3 +- src/core/constant/display_name/parser.rs | 37 +- src/core/constant/display_name/tokenizer.rs | 13 +- src/core/error.rs | 15 +- src/core/error/canonical.rs | 103 +- src/core/error/cpp.rs | 20 + src/core/error/cursor.rs | 24 +- src/core/middleware/auth.rs | 197 +- src/core/model.rs | 205 +- src/core/model/anthropic.rs | 615 ++--- src/core/model/openai.rs | 269 +- src/core/model/resolver.rs | 21 +- src/core/model/tool_id_parser.rs | 38 + src/core/route.rs | 42 +- src/core/route/gen.rs | 39 +- src/core/route/health.rs | 225 +- src/core/route/logs.rs | 290 +-- src/core/route/page.rs | 236 +- src/core/route/proxies.rs | 5 +- src/core/route/token.rs | 235 +- src/core/route/tokens.rs | 255 +- src/core/route/userinfo.rs | 38 - src/core/route/utils.rs | 72 + src/core/service.rs | 2125 +++++++-------- src/core/service/cpp.rs | 445 ++-- src/core/stream.rs | 1 - src/core/stream/decoder.rs | 507 ++-- src/core/stream/decoder/cpp.rs | 435 ++-- src/core/stream/decoder/direct.rs | 3 +- src/core/stream/decoder/utils.rs | 55 +- src/core/stream/droppable.rs | 20 +- src/leak.rs | 42 +- src/leak/arc.rs | 150 +- src/leak/manually_init.rs | 337 ++- src/lib.rs | 105 +- src/main.rs | 360 +-- src/natural_args.rs | 278 +- static/api.html | 2 +- static/tokens.html | 160 +- tools/get-token/Cargo.toml | 17 - tools/get-token/README.md | 58 - tools/get-token/src/main.rs | 153 -- tools/next_reload/Cargo.toml | 23 - tools/next_reload/src/main.rs | 55 - tools/reset-telemetry/Cargo.toml | 17 - tools/reset-telemetry/src/main.rs | 74 - tools/rkyv_adapter/src/main.rs | 430 --- tools/set-token/Cargo.toml | 16 - tools/set-token/src/main.rs | 538 ---- 428 files changed, 66005 insertions(+), 15324 deletions(-) create mode 100644 .rust-toolchain.toml delete mode 100644 LICENSE delete mode 100644 LICENSE-APACHE delete mode 100644 LICENSE-MIT create mode 
100644 LICENSE.md create mode 100644 crates/grpc-stream/Cargo.toml create mode 100644 crates/grpc-stream/src/buffer.rs create mode 100644 crates/grpc-stream/src/compression.rs create mode 100644 crates/grpc-stream/src/decoder.rs create mode 100644 crates/grpc-stream/src/frame.rs create mode 100644 crates/grpc-stream/src/lib.rs create mode 100644 crates/interned/Cargo.toml create mode 100644 crates/interned/src/arc_str.rs create mode 100644 crates/interned/src/lib.rs create mode 100644 crates/interned/src/str.rs create mode 100644 crates/manually_init/Cargo.toml create mode 100644 crates/manually_init/src/lib.rs create mode 100644 crates/rep_move/Cargo.toml create mode 100644 crates/rep_move/src/lib.rs delete mode 100644 patch/chrono-0.4.41/tests/dateutils.rs delete mode 100644 patch/chrono-0.4.41/tests/wasm.rs delete mode 100644 patch/chrono-0.4.41/tests/win_bindings.rs rename patch/{chrono-0.4.41 => chrono-0.4.42}/CITATION.cff (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/Cargo.toml (88%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/LICENSE.txt (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/README.md (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/date.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/datetime/mod.rs (96%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/datetime/serde.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/datetime/tests.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/formatting.rs (97%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/locales.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/mod.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/parse.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/parsed.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/scan.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/format/strftime.rs (79%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/lib.rs 
(98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/month.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/date/mod.rs (97%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/date/tests.rs (93%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/datetime/mod.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/datetime/serde.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/datetime/tests.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/internals.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/isoweek.rs (95%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/mod.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/time/mod.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/time/serde.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/naive/time/tests.rs (96%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/fixed.rs (97%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/mod.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/tz_data.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/tz_info/mod.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/tz_info/parser.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/tz_info/rule.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/tz_info/timezone.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/unix.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/win_bindings.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/win_bindings.txt (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/local/windows.rs (100%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/mod.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/offset/utc.rs (95%) rename patch/{chrono-0.4.41 => 
chrono-0.4.42}/src/round.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/time_delta.rs (99%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/traits.rs (98%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/weekday.rs (97%) rename patch/{chrono-0.4.41 => chrono-0.4.42}/src/weekday_set.rs (100%) create mode 100644 patch/macros/Cargo.toml create mode 100644 patch/macros/src/lib.rs create mode 100644 patch/prost-0.14.1/Cargo.toml create mode 100644 patch/prost-0.14.1/Cargo.toml.orig create mode 100644 patch/prost-0.14.1/LICENSE create mode 100644 patch/prost-0.14.1/README.md create mode 100644 patch/prost-0.14.1/src/byte_str.rs create mode 100644 patch/prost-0.14.1/src/encoding.rs create mode 100644 patch/prost-0.14.1/src/encoding/fixed_width.rs create mode 100644 patch/prost-0.14.1/src/encoding/length_delimiter.rs create mode 100644 patch/prost-0.14.1/src/encoding/utf8.rs create mode 100644 patch/prost-0.14.1/src/encoding/utf8/ascii.rs create mode 100644 patch/prost-0.14.1/src/encoding/utf8/simd_funcs.rs create mode 100644 patch/prost-0.14.1/src/encoding/varint.rs create mode 100644 patch/prost-0.14.1/src/encoding/wire_type.rs create mode 100644 patch/prost-0.14.1/src/error.rs create mode 100644 patch/prost-0.14.1/src/lib.rs create mode 100644 patch/prost-0.14.1/src/message.rs create mode 100644 patch/prost-0.14.1/src/name.rs create mode 100644 patch/prost-0.14.1/src/types.rs create mode 100644 patch/prost-derive/Cargo.toml create mode 100644 patch/prost-derive/LICENSE create mode 100644 patch/prost-derive/README.md create mode 100644 patch/prost-derive/src/field/group.rs create mode 100644 patch/prost-derive/src/field/map.rs create mode 100644 patch/prost-derive/src/field/message.rs create mode 100644 patch/prost-derive/src/field/mod.rs create mode 100644 patch/prost-derive/src/field/oneof.rs create mode 100644 patch/prost-derive/src/field/scalar.rs create mode 100644 patch/prost-derive/src/lib.rs create mode 100644 patch/prost-types/Cargo.toml create 
mode 100644 patch/prost-types/Cargo.toml.orig create mode 100644 patch/prost-types/LICENSE create mode 100644 patch/prost-types/README.md create mode 100644 patch/prost-types/src/any.rs create mode 100644 patch/prost-types/src/compiler.rs create mode 100644 patch/prost-types/src/conversions.rs create mode 100644 patch/prost-types/src/datetime.rs create mode 100644 patch/prost-types/src/duration.rs create mode 100644 patch/prost-types/src/lib.rs create mode 100644 patch/prost-types/src/protobuf.rs create mode 100644 patch/prost-types/src/timestamp.rs create mode 100644 patch/prost-types/src/type_url.rs create mode 100644 patch/reqwest-0.12.18/src/retry.rs create mode 100644 patch/rkyv-0.8.12/Cargo.toml create mode 100644 patch/rkyv-0.8.12/Cargo.toml.orig create mode 100644 patch/rkyv-0.8.12/LICENSE create mode 100644 patch/rkyv-0.8.12/README.md create mode 100644 patch/rkyv-0.8.12/examples/backwards_compat.rs create mode 100644 patch/rkyv-0.8.12/examples/complex_wrapper_types.rs create mode 100644 patch/rkyv-0.8.12/examples/derive_partial_ord.rs create mode 100644 patch/rkyv-0.8.12/examples/explicit_enum_discriminants.rs create mode 100644 patch/rkyv-0.8.12/examples/json_like_schema.rs create mode 100644 patch/rkyv-0.8.12/examples/niching.rs create mode 100644 patch/rkyv-0.8.12/examples/readme.rs create mode 100644 patch/rkyv-0.8.12/examples/remote_types.rs create mode 100644 patch/rkyv-0.8.12/src/_macros.rs create mode 100644 patch/rkyv-0.8.12/src/alias.rs create mode 100644 patch/rkyv-0.8.12/src/api/checked.rs create mode 100644 patch/rkyv-0.8.12/src/api/high/checked.rs create mode 100644 patch/rkyv-0.8.12/src/api/high/mod.rs create mode 100644 patch/rkyv-0.8.12/src/api/low/checked.rs create mode 100644 patch/rkyv-0.8.12/src/api/low/mod.rs create mode 100644 patch/rkyv-0.8.12/src/api/mod.rs create mode 100644 patch/rkyv-0.8.12/src/api/test/inner_checked.rs create mode 100644 patch/rkyv-0.8.12/src/api/test/inner_unchecked.rs create mode 100644 
patch/rkyv-0.8.12/src/api/test/mod.rs create mode 100644 patch/rkyv-0.8.12/src/api/test/outer_high.rs create mode 100644 patch/rkyv-0.8.12/src/api/test/outer_low.rs create mode 100644 patch/rkyv-0.8.12/src/boxed.rs create mode 100644 patch/rkyv-0.8.12/src/collections/btree/map/iter.rs create mode 100644 patch/rkyv-0.8.12/src/collections/btree/map/mod.rs create mode 100644 patch/rkyv-0.8.12/src/collections/btree/mod.rs create mode 100644 patch/rkyv-0.8.12/src/collections/btree/set.rs create mode 100644 patch/rkyv-0.8.12/src/collections/mod.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/index_map.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/index_set.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/map.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/mod.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/set.rs create mode 100644 patch/rkyv-0.8.12/src/collections/swiss_table/table.rs create mode 100644 patch/rkyv-0.8.12/src/collections/util.rs create mode 100644 patch/rkyv-0.8.12/src/de/mod.rs create mode 100644 patch/rkyv-0.8.12/src/de/pooling/alloc.rs create mode 100644 patch/rkyv-0.8.12/src/de/pooling/core.rs create mode 100644 patch/rkyv-0.8.12/src/de/pooling/mod.rs create mode 100644 patch/rkyv-0.8.12/src/ffi.rs create mode 100644 patch/rkyv-0.8.12/src/hash.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/boxed.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/collections/btree_map.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/collections/btree_set.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/collections/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/collections/vec_deque.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/ffi.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/rc/atomic.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/rc/mod.rs 
create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/string.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/vec.rs create mode 100644 patch/rkyv-0.8.12/src/impls/alloc/with.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/ffi.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/net.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/ops.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/option.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/primitive.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/result.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/time.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/with/atomic/_macros.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/with/atomic/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/with/atomic/multibyte.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/with/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/core/with/niching.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/arrayvec_0_7.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/bytes_1.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_14/hash_map.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_14/hash_set.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_14/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_14/with.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_16/hash_map.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_16/hash_set.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_16/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/hashbrown_0_16/with.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/indexmap_2/index_map.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/indexmap_2/index_set.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/indexmap_2/mod.rs create mode 100644 
patch/rkyv-0.8.12/src/impls/ext/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/smallvec_1.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/smolstr_0_2.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/smolstr_0_3.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/thin_vec_0_2.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/tinyvec_1.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/triomphe_0_1.rs create mode 100644 patch/rkyv-0.8.12/src/impls/ext/uuid_1.rs create mode 100644 patch/rkyv-0.8.12/src/impls/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/rend.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/collections/hash_map.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/collections/hash_set.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/collections/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/mod.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/net.rs create mode 100644 patch/rkyv-0.8.12/src/impls/std/with.rs create mode 100644 patch/rkyv-0.8.12/src/lib.rs create mode 100644 patch/rkyv-0.8.12/src/net.rs create mode 100644 patch/rkyv-0.8.12/src/niche/mod.rs create mode 100644 patch/rkyv-0.8.12/src/niche/niched_option.rs create mode 100644 patch/rkyv-0.8.12/src/niche/niching.rs create mode 100644 patch/rkyv-0.8.12/src/niche/option_box.rs create mode 100644 patch/rkyv-0.8.12/src/niche/option_nonzero.rs create mode 100644 patch/rkyv-0.8.12/src/ops.rs create mode 100644 patch/rkyv-0.8.12/src/option.rs create mode 100644 patch/rkyv-0.8.12/src/place.rs create mode 100644 patch/rkyv-0.8.12/src/polyfill.rs create mode 100644 patch/rkyv-0.8.12/src/primitive.rs create mode 100644 patch/rkyv-0.8.12/src/rc.rs create mode 100644 patch/rkyv-0.8.12/src/rel_ptr.rs create mode 100644 patch/rkyv-0.8.12/src/result.rs create mode 100644 patch/rkyv-0.8.12/src/seal.rs create mode 100644 patch/rkyv-0.8.12/src/ser/allocator/alloc.rs create mode 100644 patch/rkyv-0.8.12/src/ser/allocator/core.rs create mode 100644 
patch/rkyv-0.8.12/src/ser/allocator/mod.rs create mode 100644 patch/rkyv-0.8.12/src/ser/mod.rs create mode 100644 patch/rkyv-0.8.12/src/ser/sharing/alloc.rs create mode 100644 patch/rkyv-0.8.12/src/ser/sharing/core.rs create mode 100644 patch/rkyv-0.8.12/src/ser/sharing/mod.rs create mode 100644 patch/rkyv-0.8.12/src/ser/writer/alloc.rs create mode 100644 patch/rkyv-0.8.12/src/ser/writer/core.rs create mode 100644 patch/rkyv-0.8.12/src/ser/writer/mod.rs create mode 100644 patch/rkyv-0.8.12/src/ser/writer/std.rs create mode 100644 patch/rkyv-0.8.12/src/simd/generic.rs create mode 100644 patch/rkyv-0.8.12/src/simd/mod.rs create mode 100644 patch/rkyv-0.8.12/src/simd/neon.rs create mode 100644 patch/rkyv-0.8.12/src/simd/sse2.rs create mode 100644 patch/rkyv-0.8.12/src/string/mod.rs create mode 100644 patch/rkyv-0.8.12/src/string/repr.rs create mode 100644 patch/rkyv-0.8.12/src/time.rs create mode 100644 patch/rkyv-0.8.12/src/traits.rs create mode 100644 patch/rkyv-0.8.12/src/tuple.rs create mode 100644 patch/rkyv-0.8.12/src/util/alloc/aligned_vec.rs create mode 100644 patch/rkyv-0.8.12/src/util/alloc/arena.rs create mode 100644 patch/rkyv-0.8.12/src/util/alloc/mod.rs create mode 100644 patch/rkyv-0.8.12/src/util/inline_vec.rs create mode 100644 patch/rkyv-0.8.12/src/util/mod.rs create mode 100644 patch/rkyv-0.8.12/src/util/ser_vec.rs create mode 100644 patch/rkyv-0.8.12/src/validation/archive/mod.rs create mode 100644 patch/rkyv-0.8.12/src/validation/archive/validator.rs create mode 100644 patch/rkyv-0.8.12/src/validation/mod.rs create mode 100644 patch/rkyv-0.8.12/src/validation/shared/mod.rs create mode 100644 patch/rkyv-0.8.12/src/validation/shared/validator.rs create mode 100644 patch/rkyv-0.8.12/src/vec.rs create mode 100644 patch/rkyv-0.8.12/src/with.rs create mode 100644 src/app/frontend.rs create mode 100644 src/app/lazy/path.rs create mode 100644 src/app/model/context_fill_mode.rs create mode 100644 src/app/model/default_instructions.rs create mode 100644 
src/app/model/exchange_map.rs create mode 100644 src/app/model/id_source.rs create mode 100644 src/app/model/log/command.rs create mode 100644 src/app/model/log/limit.rs create mode 100644 src/app/model/log/manager.rs create mode 100644 src/app/model/log/storage.rs create mode 100644 src/app/model/state/token/queue.rs create mode 100644 src/app/model/version.rs create mode 100644 src/app/route.rs create mode 100644 src/common/model/ntp.rs create mode 100644 src/common/model/stringify.rs create mode 100644 src/common/model/userinfo/limit_type.rs create mode 100644 src/common/model/userinfo/membership_type.rs create mode 100644 src/common/model/userinfo/privacy_mode.rs create mode 100644 src/common/model/userinfo/usage_event.rs create mode 100644 src/common/model/userinfo/usage_info.rs create mode 100644 src/common/utils/proto_encode.rs create mode 100644 src/core/adapter/error.rs create mode 100644 src/core/auth.rs create mode 100644 src/core/auth/error.rs create mode 100644 src/core/auth/middleware.rs create mode 100644 src/core/auth/model.rs create mode 100644 src/core/auth/utils.rs create mode 100644 src/core/error/cpp.rs create mode 100644 src/core/model/tool_id_parser.rs create mode 100644 src/core/route/utils.rs delete mode 100644 tools/get-token/Cargo.toml delete mode 100644 tools/get-token/README.md delete mode 100644 tools/get-token/src/main.rs delete mode 100644 tools/next_reload/Cargo.toml delete mode 100644 tools/next_reload/src/main.rs delete mode 100644 tools/reset-telemetry/Cargo.toml delete mode 100644 tools/reset-telemetry/src/main.rs delete mode 100644 tools/rkyv_adapter/src/main.rs delete mode 100644 tools/set-token/Cargo.toml delete mode 100644 tools/set-token/src/main.rs diff --git a/.env.example b/.env.example index 69cbc4f..e6ef694 100644 --- a/.env.example +++ b/.env.example @@ -6,8 +6,8 @@ HOST= # 服务器监听端口 PORT=3000 -# 路由前缀,必须以 / 开头(如果不为空) -ROUTE_PREFIX= +# 路由前缀,必须以 / 开头(如果不为空)(已弃用,使用 route_registry.json 定义) +# ROUTE_PREFIX= # 最高权限的认证令牌,必填 
AUTH_TOKEN= @@ -181,3 +181,79 @@ ALLOWED_PROVIDERS=auth0,google-oauth2,github # 绕过模型验证,允许所有模型(会带有一定的性能损失) BYPASS_MODEL_VALIDATION=false + +# 请求模型唯一标识符来源 +# - 可选值 +# - id +# - client_id +# - server_id +MODEL_ID_SOURCE=client_id + +# 上下文填充位 +# - 可选值 +# - context: 1 # 仅使用当前上下文 +# - repo_context: 2 # 仅使用仓库上下文 +# - context + repo_context: 3 # 当前上下文 + 仓库上下文 +# - mode_specific_context: 4 # 模式特定上下文 +# - context + mode_specific_context: 5 # 当前上下文 + 模式特定上下文 +# - repo_context + mode_specific_context: 6 # 仓库上下文 + 模式特定上下文 +# - all_contexts: 7 # 所有上下文组合 +CONTEXT_FILL_MODE=1 + +# 前端资源路径 +# 事实上定义了 route_registry.json +FRONTEND_PATH=frontend.zip + +# NTP 服务器列表(逗号分隔) +# 留空则完全禁用 NTP 时间同步功能 +# 示例:pool.ntp.org,time.cloudflare.com,time.windows.com +NTP_SERVERS= + +# NTP 周期性同步间隔(秒) +# 仅在配置了服务器时生效 +# 0 或不设置表示仅在启动时同步一次(不启动后台任务) +NTP_SYNC_INTERVAL_SECS=3600 + +# 每次同步的采样次数 +# 多次采样可提高精度,但会增加同步耗时 +# 可用最小值为 3 +NTP_SAMPLE_COUNT=8 + +# 采样间隔(毫秒) +# 两次采样之间的等待时间 +# 过小可能导致网络拥塞,过大会延长同步时间 +NTP_SAMPLE_INTERVAL_MS=50 + +# 预计的峰值速率 +# RPS: 每秒请求数 +LOG_PEAK_RPS=25 + +# 期望的缓冲时长 +# 不进行日志丢弃时期望的堵塞延迟时间 +LOG_BUFFER_SECONDS=2 + +# 过载时日志丢弃(未实现) +# LOG_DROP_ON_OVERLOAD=false + +# 运行时间的显示格式 +# 可选值: +# - auto : 自动选择格式 +# - compact : 紧凑格式 (如: 1h30m) +# - standard : 标准格式 (如: 1 hour 30 minutes) +# - detailed : 详细格式 (如: 1 hour, 30 minutes, 5 seconds) +# - iso8601 : ISO 8601 格式 (如: PT1H30M5S) +# - fuzzy : 模糊格式 (如: about an hour) +# - numeric : 纯数字格式 (如: 5405) +# - verbose : 冗长格式 (如: 1 hour and 30 minutes) +# - random : 随机格式 (仅用于测试) +DURATION_FORMAT=random + +# 运行时间的显示语言 +# 可选值: +# - english : 英语 +# - chinese : 中文 +# - japanese : 日语 +# - spanish : 西班牙语 +# - german : 德语 +# - random : 随机语言 (仅用于测试) +DURATION_LANGUAGE=random diff --git a/.gitignore b/.gitignore index 60607f1..c4399f5 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,9 @@ Cargo.lock /*.tar.gz /src/core/model/a.rs .cargo/config.toml +/front +/pmacro +/*.json +/tests/data/stream_data.txt +/frontend.zip +/config.toml diff --git a/.rust-toolchain.toml 
b/.rust-toolchain.toml new file mode 100644 index 0000000..5d56faf --- /dev/null +++ b/.rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" diff --git a/.rustfmt.toml b/.rustfmt.toml index e49d5f3..f83edaf 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,80 +1,16 @@ -max_width = 100 -hard_tabs = false -tab_spaces = 4 -newline_style = "Unix" -indent_style = "Block" -use_small_heuristics = "Max" -fn_call_width = 60 -attr_fn_like_width = 70 -struct_lit_width = 18 -struct_variant_width = 35 -array_width = 60 -chain_width = 60 -single_line_if_else_max_width = 50 -single_line_let_else_max_width = 50 -wrap_comments = false -format_code_in_doc_comments = false -doc_comment_code_block_width = 100 -comment_width = 80 -normalize_comments = false -normalize_doc_attributes = false -format_strings = false -format_macro_matchers = true -format_macro_bodies = true -skip_macro_invocations = [] -hex_literal_case = "Preserve" -empty_item_single_line = true -struct_lit_single_line = true -fn_single_line = true -where_single_line = false -imports_indent = "Block" -imports_layout = "Mixed" -imports_granularity = "Crate" -group_imports = "Preserve" -reorder_imports = true -reorder_modules = true -reorder_impl_items = false -type_punctuation_density = "Wide" -space_before_colon = false -space_after_colon = true -spaces_around_ranges = false -binop_separator = "Front" -remove_nested_parens = true -combine_control_expr = true -short_array_element_width_threshold = 10 -overflow_delimited_expr = true -struct_field_align_threshold = 0 -enum_discrim_align_threshold = 0 -match_arm_blocks = false -match_arm_leading_pipes = "Never" -force_multiline_blocks = false -fn_params_layout = "Tall" -brace_style = "SameLineWhere" -control_brace_style = "AlwaysSameLine" -trailing_semicolon = true -trailing_comma = "Vertical" -match_block_trailing_comma = false -blank_lines_upper_bound = 1 -blank_lines_lower_bound = 0 -edition = "2024" +# reorder_imports = false +# reorder_modules = 
false + style_edition = "2024" -# version = "One" -inline_attribute_width = 0 -format_generated_files = true -generated_marker_line_search_limit = 5 -merge_derives = true -use_try_shorthand = true +use_small_heuristics = "Max" +# merge_derives = false +# group_imports = "Preserve" +imports_granularity = "Crate" +group_imports = "One" +# imports_granularity = "One" use_field_init_shorthand = true -force_explicit_abi = true -condense_wildcard_suffixes = false -color = "Auto" -required_version = "1.8.0" -unstable_features = true -disable_all_formatting = false -skip_children = false -show_parse_errors = true -error_on_line_overflow = false -error_on_unformatted = false -ignore = [] -emit_mode = "Files" -make_backup = false + +#unstable_features = true + +fn_single_line = true +where_single_line = true diff --git a/Cargo.toml b/Cargo.toml index d9c2e50..3ee230c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,93 @@ -[package] -name = "cursor-api" -version = "0.3.6-2" +[workspace] +resolver = "2" +members = [ + ".", + "crates/manually_init", + "crates/rep_move", + "crates/grpc-stream", + "crates/interned", +] +default-members = ["."] + +[workspace.package] +version = "0.4.0-pre.14" edition = "2024" authors = ["wisdgod "] -description = "OpenAI format compatibility layer for the Cursor API" +description = "A format compatibility layer for the Cursor API" license = "MIT OR Apache-2.0" repository = "https://github.com/wisdgod/cursor-api" +[workspace.dependencies] +manually_init = { path = "crates/manually_init", features = ["sync"] } +rep_move = { path = "crates/rep_move" } +grpc-stream = { path = "crates/grpc-stream" } +interned = { path = "crates/interned" } + +# ===== 开发配置(编译速度优先)===== +[profile.dev] +opt-level = 0 +debug = "line-tables-only" +split-debuginfo = "unpacked" +incremental = true +codegen-units = 256 +rustflags = ["-Clink-arg=-fuse-ld=mold"] + +# ===== 快速测试构建(平衡配置)===== +[profile.fast] +inherits = "dev" +opt-level = 1 + +# ===== 性能测试配置(接近 release 
但编译更快)===== +[profile.bench] +inherits = "release" +lto = "thin" +codegen-units = 16 +# rustflags = [ +# "-Ctarget-cpu=native", +# ] + +# ===== 发布配置(性能最大化)===== +[profile.release] +opt-level = 3 +lto = "fat" +codegen-units = 1 +panic = "abort" +strip = true +debug = false +overflow-checks = false +incremental = false +trim-paths = "all" +# rustflags = [ +# "-Clink-arg=-s", +# # "-Clink-arg=-fuse-ld=lld", +# # "-Ctarget-cpu=native", +# ] + +[patch.crates-io] +h2 = { path = "patch/h2-0.4.10" } +reqwest = { path = "patch/reqwest-0.12.18" } +rustls = { path = "patch/rustls-0.23.28" } +chrono = { path = "patch/chrono-0.4.42" } +ulid = { path = "patch/ulid-1.2.1" } +dotenvy = { path = "patch/dotenvy-0.15.7" } +# bs58 = { path = "patch/bs58-0.5.1" } +# base62 = { path = "patch/base62-2.2.1" } +prost = { path = "patch/prost-0.14.1" } +prost-derive = { path = "patch/prost-derive" } +prost-types = { path = "patch/prost-types" } +rkyv = { path = "patch/rkyv-0.8.12" } + +# =========================================== + +[package] +name = "cursor-api" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + [[bin]] name = "cursor-api" path = "src/main.rs" @@ -16,92 +97,147 @@ path = "src/main.rs" # path = "tools/rkyv_adapter/src/main.rs" [build-dependencies] -chrono = { version = "0.4", default-features = false, features = ["alloc"]} +chrono = { version = "0.4", default-features = false, features = ["alloc"] } prost-build = { version = "0.14", optional = true } -sha2 = { version = "0.10", default-features = false } +sha2 = { version = "0", default-features = false } serde_json = "1" [dependencies] -ahash = { version = "0.8", default-features = false, features = ["std", "compile-time-rng", "serde"] } +# owned +manually_init.workspace = true +rep_move.workspace = true +grpc-stream.workspace = true +interned.workspace = true + +ahash = { version = "0.8", default-features = 
false, features = [ + "compile-time-rng", + "serde", +] } arc-swap = "1" -axum = { version = "0.8", default-features = false, features = ["http1", "http2", "json", "tokio", "query", "macros"] } +axum = { version = "0.8", default-features = false, features = [ + "http1", + "http2", + "json", + "tokio", + "query", + "macros", +] } # base62 = "2.2.1" base64 = { version = "0.22", default-features = false, features = ["std"] } # bs58 = { version = "0.5.1", default-features = false, features = ["std"] } # brotli = { version = "7.0", default-features = false, features = ["std"] } -bytes = "1.10" -chrono = { version = "0.4", default-features = false, features = ["alloc", "serde", "rkyv-64"] } +bytes = "1" +chrono = { version = "0.4", default-features = false, features = [ + "alloc", + "serde", + "rkyv-64", +] } chrono-tz = { version = "0.10", features = ["serde"] } dotenvy = "0.15" -flate2 = { version = "1", default-features = false, features = ["rust_backend"] } +flate2 = { version = "1", default-features = false, features = [ + "rust_backend", +] } futures = { version = "0.3", default-features = false, features = ["std"] } -gif = { version = "0.13", default-features = false, features = ["std"] } -hashbrown = { version = "0.15", default-features = false } +gif = { version = "0.14", default-features = false, features = ["std"] } +hashbrown = { version = "0.16", default-features = false, features = [ + "serde", + "raw-entry", + "inline-more", +] } hex = { version = "0.4", default-features = false, features = ["std"] } http = "1" http-body-util = "0.1" -image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } +image = { version = "0.25", default-features = false, features = [ + "jpeg", + "png", + "gif", + "webp", +] } # lasso = { version = "0.7", features = ["multi-threaded", "ahasher"] } memmap2 = "0.9" +minicbor = { version = "2", features = ["derive", "alloc"] } # openssl = { version = "0.10", features = ["vendored"] } -parking_lot 
= "0.12" -paste = "1.0" -phf = { version = "0.12", features = ["macros"] } +parking_lot = { version = "0.12", features = [ + "arc_lock", + "hardware-lock-elision", +] } +paste = "1" +phf = { version = "0.13", features = ["macros"] } # pin-project-lite = "0.2" # pin-project = "1" -prost = "0.14" -prost-types = "0.14" +prost = { version = "0.14", features = ["indexmap"] } +# prost-types = "0.14" rand = { version = "0.9", default-features = false, features = ["thread_rng"] } -reqwest = { version = "0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "charset", "http2", "macos-system-configuration"] } -rkyv = { version = "0.8", default-features = false, features = ["std", "pointer_width_64", "uuid-1"] } +reqwest = { version = "0.12", default-features = false, features = [ + "gzip", + "brotli", + "json", + "stream", + "socks", + "charset", + "http2", + "system-proxy", +] } +rkyv = { version = "0.8", default-features = false, features = [ + "std", + "pointer_width_64", + "hashbrown-0_16", + "uuid-1", +] } # rustls = { version = "0.23.26", default-features = false, features = ["std", "tls12"] } -serde = { version = "1", default-features = false, features = ["std", "derive", "rc"] } +serde = { version = "1", default-features = false, features = [ + "std", + "derive", + "rc", +] } # serde_json = { package = "sonic-rs", version = "0" } -serde_json = "1" +serde_json = { version = "1", features = ["preserve_order"] } sha2 = { version = "0", default-features = false } sysinfo = { version = "0.37", default-features = false, features = ["system"] } -tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "sync", "time", "fs", "signal"] } +tokio = { version = "1", features = [ + "rt-multi-thread", + "macros", + "net", + "sync", + "time", + "fs", + "signal", +] } tokio-util = { version = "0.7", features = ["io"] } # tokio-tungstenite = { version = "0.26.2", features = ["rustls-tls-webpki-roots"] } # tokio-stream = { version = 
"0.1", features = ["time"] } tower-http = { version = "0.6", features = ["cors", "limit"] } -tracing = { version = "*", default-features = false, features = ["max_level_off", "release_max_level_off"] } +tracing = { version = "*", default-features = false, features = [ + "max_level_off", + "release_max_level_off", +] } ulid = { version = "1.2", default-features = false, features = ["std", "rkyv"] } # tracing-subscriber = "0.3" url = { version = "2.5", default-features = false, features = ["serde"] } -uuid = { version = "1.14", default-features = false, features = ["v4", "fast-rng", "serde"] } - -[profile.dev] -debug = "line-tables-only" - -[profile.release] -lto = true -codegen-units = 1 -panic = 'abort' -strip = true -# debug = true -# split-debuginfo = 'packed' -# strip = "none" -# panic = 'unwind' -opt-level = 3 -trim-paths = "all" -rustflags = ["-Cdebuginfo=0", "-Zthreads=8"] +uuid = { version = "1.14", default-features = false, features = [ + "v4", + "fast-rng", + "serde", +] } +zip = { version = "7", default-features = false, features = [ + "deflate", + "bzip2", + "zstd", + "deflate64", + "lzma", + "xz", +] } +indexmap = { version = "2", default-features = false, features = ["serde"] } +itoa = "1.0" [features] -default = ["webpki-roots"] +default = ["webpki-roots", "horizon"] webpki-roots = ["reqwest/rustls-tls-webpki-roots"] native-roots = ["reqwest/rustls-tls-native-roots"] use-minified = [] __preview = [] +__preview_locked = ["__preview"] __protoc = ["prost-build"] __compat = [] - -[patch.crates-io] -h2 = { path = "patch/h2-0.4.10" } -reqwest = { path = "patch/reqwest-0.12.18" } -rustls = { path = "patch/rustls-0.23.28" } -chrono = { path = "patch/chrono-0.4.41" } -ulid = { path = "patch/ulid-1.2.1" } -dotenvy = { path = "patch/dotenvy-0.15.7" } -# bs58 = { path = "patch/bs58-0.5.1" } -# base62 = { path = "patch/base62-2.2.1" } +horizon = ["nightly"] +nightly = ["hashbrown/nightly", "parking_lot/nightly"] diff --git a/Dockerfile b/Dockerfile index 
7d00925..d0f6b17 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ WORKDIR /build RUN apt-get update && apt-get install -y --no-install-recommends gcc nodejs npm musl-tools && rm -rf /var/lib/apt/lists/* && case "$TARGETARCH" in amd64) rustup target add x86_64-unknown-linux-musl ;; arm64) rustup target add aarch64-unknown-linux-musl ;; *) echo "Unsupported architecture for rustup: $TARGETARCH" && exit 1 ;; esac COPY . . -RUN case "$TARGETARCH" in amd64) TARGET_TRIPLE="x86_64-unknown-linux-musl"; TARGET_CPU="x86-64-v3" ;; arm64) TARGET_TRIPLE="aarch64-unknown-linux-musl"; TARGET_CPU="neoverse-n1" ;; *) echo "Unsupported architecture: $TARGETARCH" && exit 1 ;; esac && RUSTFLAGS="-C link-arg=-s -C target-feature=+crt-static -C target-cpu=$TARGET_CPU -A unused" cargo build --bin cursor-api --release --target=$TARGET_TRIPLE && mkdir /app && cp target/$TARGET_TRIPLE/release/cursor-api /app/ +RUN case "$TARGETARCH" in amd64) TARGET_TRIPLE="x86_64-unknown-linux-musl"; TARGET_CPU="x86-64-v2" ;; arm64) TARGET_TRIPLE="aarch64-unknown-linux-musl"; TARGET_CPU="generic" ;; *) echo "Unsupported architecture: $TARGETARCH" && exit 1 ;; esac && RUSTFLAGS="-C link-arg=-s -C target-feature=+crt-static -C target-cpu=$TARGET_CPU -A unused" cargo build --bin cursor-api --release --target=$TARGET_TRIPLE && mkdir /app && cp target/$TARGET_TRIPLE/release/cursor-api /app/ # 运行阶段 FROM scratch diff --git a/LICENSE b/LICENSE deleted file mode 100644 index b4b17d7..0000000 --- a/LICENSE +++ /dev/null @@ -1,6 +0,0 @@ -This project is licensed under either of - - * Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) - * MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT) - -at your option. 
diff --git a/LICENSE-APACHE b/LICENSE-APACHE deleted file mode 100644 index d645695..0000000 --- a/LICENSE-APACHE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. 
Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative 
Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT deleted file mode 100644 index 03f2071..0000000 --- a/LICENSE-MIT +++ /dev/null @@ -1,7 +0,0 @@ -Copyright 2025 wisdgod - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..76d3ee2 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,20 @@ +This project is licensed under either of + + * Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT) + +at your option. + +### Additional Terms and Restrictions + +The following additional terms apply to all forks and derivative works: + +1. **Attribution Restrictions**: You may NOT use the name(s) of the original author(s), maintainer(s), or contributor(s) of this project for any promotional, marketing, or advertising purposes in connection with your fork or derivative work. + +2. **No Implied Endorsement**: Forks and derivative works must not imply or suggest that they are endorsed by, affiliated with, or approved by the original author(s) or maintainer(s) of this project. + +3. **Low-Profile Usage Preferred**: While not legally binding, users are encouraged to use this project and its derivatives in a low-profile manner, respecting the original author's preference for discretion. + +4. **Clear Distinction Required**: Any fork or derivative work must clearly indicate that it is a modified version and not the original project. + +These additional terms supplement and do not replace the chosen license (Apache 2.0 or MIT). By using, modifying, or distributing this project, you agree to comply with both the chosen license and these additional terms. 
diff --git a/README.md b/README.md index b14330a..bfd7268 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ * 属于官方的问题,请不要像作者反馈。 * 本程序拥有堪比客户端原本的速度,甚至可能更快。 * 本程序的性能是非常厉害的。 -* 根据本项目开源协议,Fork的项目不能以作者的名义进行任何形式的宣传、推广或声明。 -* 更新的时间跨度达5月有余了,求赞助,项目不收费,不定制。 +* 根据本项目开源协议,Fork的项目不能以作者的名义进行任何形式的宣传、推广或声明。原则上希望低调使用。 +* 更新的时间跨度达近10月了,求赞助,项目不收费,不定制。 * 推荐自部署,[官方网站](https://cc.wisdgod.com) 仅用于作者测试,不保证稳定性。 ## 获取key @@ -47,94 +47,7 @@ token2,checksum2 写死了,后续也不会会支持自定义模型列表,因为本身就支持动态更新,详见[更新模型列表说明](#更新模型列表说明) -``` -claude-4-sonnet -claude-4-sonnet-thinking -claude-4-opus-thinking -claude-4-opus -default -claude-3.5-sonnet -o3 -gemini-2.5-pro-preview-05-06 -gemini-2.5-flash-preview-04-17 -gpt-4.1 -claude-3.7-sonnet -claude-3.7-sonnet-thinking -cursor-small -claude-3.5-haiku -gemini-2.5-pro-exp-03-25 -gpt-4o -o4-mini -deepseek-r1 -deepseek-v3.1 -grok-3-beta -grok-3-mini -``` - -支持图像(default始终支持): -``` -claude-4-sonnet -claude-4-sonnet-thinking -claude-4-opus-thinking -claude-4-opus -claude-3.5-sonnet -o3 -gemini-2.5-pro-preview-05-06 -gemini-2.5-flash-preview-04-17 -gpt-4.1 -claude-3.7-sonnet -claude-3.7-sonnet-thinking -claude-3.5-haiku -gemini-2.5-pro-exp-03-25 -gpt-4o -o4-mini -``` - -支持思考: -``` -claude-4-sonnet-thinking -claude-4-opus-thinking -o3 -gemini-2.5-pro-preview-05-06 -gemini-2.5-flash-preview-04-17 -claude-3.7-sonnet-thinking -gemini-2.5-pro-exp-03-25 -o4-mini -deepseek-r1 -``` - -支持Max与非Max: -``` -claude-4-sonnet -claude-4-sonnet-thinking -claude-3.5-sonnet -gemini-2.5-pro-preview-05-06 -gpt-4.1 -claude-3.7-sonnet -claude-3.7-sonnet-thinking -gemini-2.5-pro-exp-03-25 -o4-mini -grok-3-beta -``` - -Max Only: -``` -claude-4-opus-thinking -claude-4-opus -o3 -``` - -非Max Only: -``` -default -gemini-2.5-flash-preview-04-17 -cursor-small -claude-3.5-haiku -gpt-4o -deepseek-r1 -deepseek-v3.1 -grok-3-mini -``` +打开程序自己看,以实际为准,这里不再赘述。 ## 接口说明 @@ -144,11 +57,9 @@ grok-3-mini * 请求方法: POST * 认证方式: Bearer Token 1. 使用环境变量 `AUTH_TOKEN` 进行认证 - 2. 
~~使用 `.token` 文件中的令牌列表进行轮询认证~~ 在v0.1.3的rc版本更新中移除`.token`文件 - 3. ~~自v0.1.3-rc.3起支持直接使用 token,checksum 进行认证,但未提供配置关闭~~ v0.3.0起不再支持 - 4. 使用 `/build-key` 构建的动态密钥认证 - 5. 使用 `/config` 设置的共享Token进行认证 (关联:环境变量`SHARED_TOKEN`) - 6. 日志中的缓存 token key 的两种表示方式认证 (`/build-key` 同时也会给出这两种格式作为动态密钥的别名,该数字key本质为一个192位的整数) + 2. 使用 `/build-key` 构建的动态密钥认证 + 3. 使用 `/config` 设置的共享Token进行认证 (关联:环境变量`SHARED_TOKEN`) + 4. 日志中的缓存 token key 的两种表示方式认证 (`/build-key` 同时也会给出这两种格式作为动态密钥的别名,该数字key本质为一个192位的整数) #### 请求格式 @@ -157,7 +68,7 @@ grok-3-mini "model": string, "messages": [ { - "role": "system" | "user" | "assistant", // 也可以是 "developer" | "human" | "ai" + "role": "system" | "user" | "assistant", // "system" 也可以是 "developer" "content": string | [ { "type": "text" | "image_url", @@ -169,9 +80,9 @@ grok-3-mini ] } ], - "stream": boolean, + "stream": bool, "stream_options": { - "include_usage": boolean + "include_usage": bool } } ``` @@ -228,9 +139,9 @@ data: [DONE] ```json { - "is_nightly": boolean, // 是否包含 nightly 版本模型,默认 false - "include_long_context_models": boolean, // 是否包含长上下文模型,默认 false - "exclude_max_named_models": boolean, // 是否排除 max 命名的模型,默认 false + "is_nightly": bool, // 是否包含 nightly 版本模型,默认 false + "include_long_context_models": bool, // 是否包含长上下文模型,默认 false + "exclude_max_named_models": bool, // 是否排除 max 命名的模型,默认 true "additional_model_names": [string] // 额外包含的模型名称列表,默认空数组 } ``` @@ -239,22 +150,22 @@ data: [DONE] #### 响应格式 -```json +```typescript { - "object": "list", - "data": [ + object: "list", + data: [ { - "id": string, - "display_name": string, - "created": number, - "created_at": string, - "object": "model", - "type": "model", - "owned_by": string, - "supports_thinking": boolean, - "supports_images": boolean, - "supports_max_mode": boolean, - "supports_non_max_mode": boolean + id: string, + display_name: string, + created: number, + created_at: string, + object: "model", + type: "model", + owned_by: string, + supports_thinking: bool, + supports_images: bool, + 
supports_max_mode: bool, + supports_non_max_mode: bool } ] } @@ -266,13 +177,6 @@ data: [DONE] ### Token管理接口 -#### 简易Token信息管理页面 - -* 接口地址: `/tokens` -* 请求方法: GET -* 响应格式: HTML页面 -* 功能: 调用下面的各种相关API的示例页面 - #### 获取Token信息 * 接口地址: `/tokens/get` @@ -280,48 +184,117 @@ data: [DONE] * 认证方式: Bearer Token * 响应格式: -```json +```typescript { - "status": "success", - "tokens": [ + status: "success", + tokens: [ [ number, string, { - "bundle": { - "primary_token": string, - "secondary_token": string, // 可选 - "checksum": { - "first": string, - "second": string, - }, - "client_key": string, // 可选,非空时显示 - "config_version": string, // 可选 - "session_id": string, // 可选 - "proxy": string, // 可选 - "timezone": string, // 可选 - "gcpp_host": object, // 可选 - "user": { // 可选 - "email": string, - "name": string, - "updated_at": string, - "picture": string, // 可选 - "is_on_new_pricing": boolean + primary_token: string, + secondary_token?: string, + checksum: { + first: string, + second: string, + }, + client_key?: string, + config_version?: string, + session_id?: string, + proxy?: string, + timezone?: string, + gcpp_host?: "Asia" | "EU" | "US", + user?: { + user_id: int32, + email?: string, + first_name?: string, + last_name?: string, + workos_id?: string, + team_id?: number, + created_at?: string, + is_enterprise_user: bool, + is_on_new_pricing: bool, + privacy_mode_info: { + privacy_mode: "unspecified" | "no_storage" | "no_training" | "usage_data_training_allowed" | "usage_codebase_training_allowed", + is_enforced_by_team?: bool } }, - "status": "enabled" | "disabled", - "stripe": { // 可选 - "membership_type": "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", - "payment_id": string, // 可选 - "days_remaining_on_trial": number, - "subscription_status": "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", // 可选 - "verified_student": boolean, // 可选 - "is_on_student_plan": boolean // 可选 - } + status: { + enabled: bool + }, + 
usage?: { + billing_cycle_start: string, + billing_cycle_end: string, + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + limit_type: "user" | "team", + is_unlimited: bool, + individual_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + team_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + }, + stripe?: { + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + payment_id?: string, + days_remaining_on_trial: int32, + subscription_status?: "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", + verified_student: bool, + trial_eligible: bool, + trial_length_days: int32, + is_on_student_plan: bool, + is_on_billable_auto: bool, + customer_balance?: double, + trial_was_cancelled: bool, + is_team_member: bool, + team_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + individual_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise" + }, + sessions?: [ + { + session_id: string, + type: "unspecified" | "web" | "client" | "bugbot" | "background_agent", + created_at: string, + expires_at: string + } + ] } ] ], - "tokens_count": number + tokens_count: uint64 } ``` @@ -332,39 +305,101 @@ data: [DONE] * 认证方式: Bearer Token * 请求格式: -```json +```typescript [ [ string, { - "bundle": { - "primary_token": string, - "secondary_token": string, // 可选 - "checksum": { - "first": string, - "second": string, - }, - "client_key": string, // 可选 - "config_version": string, // 可选 - "session_id": 
string, // 可选 - "proxy": string, // 可选 - "timezone": string, // 可选 - "gcpp_host": object, // 可选 - "user": { // 可选 - "email": string, - "name": string, - "updated_at": string, - "picture": string, // 可选 - "is_on_new_pricing": boolean + primary_token: string, + secondary_token?: string, + checksum: { + first: string, + second: string, + }, + client_key?: string, + config_version?: string, + session_id?: string, + proxy?: string, + timezone?: string, + gcpp_host?: "Asia" | "EU" | "US", + user?: { + user_id: int32, + email?: string, + first_name?: string, + last_name?: string, + workos_id?: string, + team_id?: number, + created_at?: string, + is_enterprise_user: bool, + is_on_new_pricing: bool, + privacy_mode_info: { + privacy_mode: "unspecified" | "no_storage" | "no_training" | "usage_data_training_allowed" | "usage_codebase_training_allowed", + is_enforced_by_team?: bool } }, - "status": "enabled" | "disabled", - "stripe": { // 可选 - "membership_type": "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", - "payment_id": string, // 可选 - "days_remaining_on_trial": number, - "subscription_status": "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", // 可选 - "verified_student": boolean // 可选 + status: { + enabled: bool + }, + usage?: { + billing_cycle_start: string, + billing_cycle_end: string, + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + limit_type: "user" | "team", + is_unlimited: bool, + individual_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + team_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + 
used: int32, + limit?: int32, + remaining?: int32 + } + }, + }, + stripe?: { + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + payment_id?: string, + days_remaining_on_trial: int32, + subscription_status?: "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", + verified_student: bool, + trial_eligible: bool, + trial_length_days: int32, + is_on_student_plan: bool, + is_on_billable_auto: bool, + customer_balance?: double, + trial_was_cancelled: bool, + is_team_member: bool, + team_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + individual_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise" } } ] @@ -373,11 +408,11 @@ data: [DONE] * 响应格式: -```json +```typescript { - "status": "success", - "tokens_count": number, - "message": "Token files have been updated and reloaded" + status: "success", + tokens_count: uint64, + message: "Token files have been updated and reloaded" } ``` @@ -388,32 +423,32 @@ data: [DONE] * 认证方式: Bearer Token * 请求格式: -```json +```typescript { - "tokens": [ + tokens: [ { - "alias": string, // 可选,无则自动生成 - "token": string, - "checksum": string, // 可选,无则自动生成 - "client_key": string, // 可选,无则自动生成 - "session_id": string, // 可选 - "config_version": string, // 可选 - "proxy": string, // 可选 - "timezone": string, // 可选 - "gcpp_host": string // 可选 + alias?: string, // 可选,无则自动生成 + token: string, + checksum?: string, // 可选,无则自动生成 + client_key?: string, // 可选,无则自动生成 + session_id?: string, // 可选 + config_version?: string, // 可选 + proxy?: string, // 可选 + timezone?: string, // 可选 + gcpp_host?: string // 可选 } ], - "status": "enabled" | "disabled" + enabled: bool } ``` * 响应格式: -```json +```typescript { - "status": "success", - "tokens_count": number, - "message": string // "New tokens have been added and reloaded" 或 "No new tokens were added" + status: "success", + tokens_count: uint64, + message: string 
// "New tokens have been added and reloaded" 或 "No new tokens were added" } ``` @@ -427,7 +462,7 @@ data: [DONE] ```json { "aliases": [string], // 要删除的token列表 - "include_failed_tokens": boolean // 默认为false + "include_failed_tokens": bool // 默认为false } ``` @@ -493,7 +528,7 @@ data: [DONE] } ``` -#### 更新令牌CV +#### 更新令牌Config Version * 接口地址: `/tokens/config-version/update` * 请求方法: POST @@ -544,10 +579,10 @@ data: [DONE] * 认证方式: Bearer Token * 请求格式: -```json +```typescript { "aliases": [string], - "status": "enabled" | "disabled" + "enabled": bool } ``` @@ -628,6 +663,41 @@ data: [DONE] } ``` +#### 合并Tokens附带数据 + +* 接口地址: `/tokens/merge` +* 请求方法: POST +* 认证方式: Bearer Token +* 请求格式: + +```json +{ + "{alias}": { // 以下至少其一存在,否则会失败 + "primary_token": string, // 可选 + "secondary_token": string, // 可选 + "checksum": { // 可选 + "first": string, + "second": string, + }, + "client_key": string, // 可选 + "config_version": string, // 可选 + "session_id": string, // 可选 + "proxy": string, // 可选 + "timezone": string, // 可选 + "gcpp_host": object, // 可选 + } +} +``` + +* 响应格式: + +```json +{ + "status": "success", + "message": "已合并{}个令牌, {}个令牌合并失败" +} +``` + #### 构建API Key * 接口地址: `/build-key` @@ -649,9 +719,9 @@ data: [DONE] "proxy_name": string, // 可选,指定代理 "timezone": string, // 可选,指定时区 "gcpp_host": string, // 可选,代码补全区域 - "disable_vision": boolean, // 可选,禁用图片处理能力 - "enable_slow_pool": boolean, // 可选,启用慢速池 - "include_web_references": boolean, + "disable_vision": bool, // 可选,禁用图片处理能力 + "enable_slow_pool": bool, // 可选,启用慢速池 + "include_web_references": bool, "usage_check_models": { // 可选,使用量检查模型配置 "type": "default" | "disabled" | "all" | "custom", "model_ids": string // 当type为custom时生效,以逗号分隔的模型ID列表 @@ -694,53 +764,8 @@ data: [DONE] 5. 
数字key是一个128位无符号整数与一个64位无符号整数组成的,比通常使用的uuid更难破解。 -#### 获取Config Version - -* 接口地址: `/config-version` -* 请求方法: POST -* 认证方式: Bearer Token (当SHARE_AUTH_TOKEN启用时需要) -* 请求格式: - -```json -{ - "token": string, // 格式: JWT - "checksum": { - "first": string, // 格式: 长度为64的Hex编码字符串 - "second": string, // 格式: 长度为64的Hex编码字符串 - }, - "client_key": string, // 格式: 长度为64的Hex编码字符串 - "session_id": string, // 格式: UUID - "proxy_name": string, // 可选,指定代理 - "timezone": string, // 可选,指定时区 - "gcpp_host": string // 可选,代码补全区域 -} -``` - -* 响应格式: - -```json -{ - "config_version": string // 成功时返回生成的UUID -} -``` - -或出错时: - -```json -{ - "error": string // 错误信息 -} -``` - ### 代理管理接口 -#### 简易代理信息管理页面 - -* 接口地址: `/proxies` -* 请求方法: GET -* 响应格式: HTML页面 -* 功能: 调用下面的各种相关API的示例页面 - #### 获取代理配置信息 * 接口地址: `/proxies/get` @@ -891,13 +916,6 @@ data: [DONE] ### 配置管理接口 -#### 配置页面 - -* 接口地址: `/config` -* 请求方法: GET -* 响应格式: HTML页面 -* 功能: 提供配置管理界面,可以修改页面内容和系统配置 - #### 更新配置 * 接口地址: `/config` @@ -914,16 +932,16 @@ data: [DONE] "value": string // type=redirect时为URL, type=plain_text/html/css/js时为对应内容 }, "vision_ability": "none" | "base64" | "all", // "disabled" | "base64-only" | "base64-http" - "enable_slow_pool": boolean, - "enable_long_context": boolean, + "enable_slow_pool": bool, + "enable_long_context": bool, "usage_check_models": { "type": "none" | "default" | "all" | "list", "content": string }, - "enable_dynamic_key": boolean, + "enable_dynamic_key": bool, "share_token": string, "calibrate_token": string, - "include_web_references": boolean + "include_web_references": bool } ``` @@ -939,16 +957,16 @@ data: [DONE] "value": string }, "vision_ability": "none" | "base64" | "all", - "enable_slow_pool": boolean, - "enable_long_context": boolean, + "enable_slow_pool": bool, + "enable_long_context": bool, "usage_check_models": { "type": "none" | "default" | "all" | "list", "content": string }, - "enable_dynamic_key": boolean, + "enable_dynamic_key": bool, "share_token": string, "calibrate_token": string, - 
"include_web_references": boolean + "include_web_references": bool } } ``` @@ -972,96 +990,152 @@ data: [DONE] * 认证方式: Bearer Token * 请求格式: -```json +```typescript { "query": { // 分页与排序控制 - "limit": number, // 可选,返回记录数量限制 - "offset": number, // 可选,起始位置偏移量 - "reverse": boolean, // 可选,反向排序,默认false(从旧到新),true时从新到旧 + "limit": number, // 可选,返回记录数量限制 + "offset": number, // 可选,起始位置偏移量 + "reverse": bool, // 可选,反向排序,默认false(从旧到新),true时从新到旧 // 时间范围过滤 - "from_date": string, // 可选,开始日期时间,RFC3339格式 - "to_date": string, // 可选,结束日期时间,RFC3339格式 + "from_date": string, // 可选,开始日期时间,RFC3339格式 + "to_date": string, // 可选,结束日期时间,RFC3339格式 // 用户标识过滤 - "user_id": string, // 可选,按用户ID精确匹配 - "email": string, // 可选,按用户邮箱过滤(支持部分匹配) - "membership_type": string, // 可选,按会员类型过滤 ("free"/"free_trial"/"pro"/"enterprise") + "user_id": string, // 可选,按用户ID精确匹配 + "email": string, // 可选,按用户邮箱过滤(支持部分匹配) + "membership_type": string, // 可选,按会员类型过滤 ("free"/"free_trial"/"pro"/"pro_plus"/"ultra"/"enterprise") // 核心业务过滤 - "status": string, // 可选,按状态过滤 ("pending"/"success"/"failure") - "model": string, // 可选,按模型名称过滤(支持部分匹配) - "include_models": [string], // 可选,包含特定模型 - "exclude_models": [string], // 可选,排除特定模型 + "status": string, // 可选,按状态过滤 ("pending"/"success"/"failure") + "model": string, // 可选,按模型名称过滤(支持部分匹配) + "include_models": [string], // 可选,包含特定模型 + "exclude_models": [string], // 可选,排除特定模型 // 请求特征过滤 - "stream": boolean, // 可选,是否为流式请求 - "has_chain": boolean, // 可选,是否包含对话链 - "has_usage": boolean, // 可选,是否有usage信息 + "stream": bool, // 可选,是否为流式请求 + "has_chain": bool, // 可选,是否包含对话链 + "has_usage": bool, // 可选,是否有usage信息 // 错误相关过滤 - "has_error": boolean, // 可选,是否包含错误 - "error": string, // 可选,按错误过滤(支持部分匹配) + "has_error": bool, // 可选,是否包含错误 + "error": string, // 可选,按错误过滤(支持部分匹配) // 性能指标过滤 - "min_total_time": number, // 可选,最小总耗时(秒) - "max_total_time": number, // 可选,最大总耗时(秒) - "min_tokens": number, // 可选,最小token数(input + output) - "max_tokens": number // 可选,最大token数 + "min_total_time": number, // 可选,最小总耗时(秒) + 
"max_total_time": number, // 可选,最大总耗时(秒) + "min_tokens": number, // 可选,最小token数(input + output) + "max_tokens": number // 可选,最大token数 } } ``` * 响应格式: -```json +```typescript { - "total": number, - "logs": [ + status: "success", + total: uint64, + active?: uint64, + error?: uint64, + logs: [ { - "id": number, - "timestamp": string, - "model": string, - "token_info": { - "key": string, - "stripe": { // 可选 - "membership_type": "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", - "payment_id": string, // 可选 - "days_remaining_on_trial": number, - "subscription_status": "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", // 可选 - "verified_student": boolean // 可选 + id: uint64, + timestamp: string, + model: string, + token_info: { + key: string, + usage?: { + billing_cycle_start: string, + billing_cycle_end: string, + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + limit_type: "user" | "team", + is_unlimited: bool, + individual_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + team_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + }, + stripe?: { + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + payment_id?: string, + days_remaining_on_trial: int32, + subscription_status?: "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", + verified_student: bool, + trial_eligible: bool, + trial_length_days: int32, + is_on_student_plan: bool, + is_on_billable_auto: bool, 
+ customer_balance?: double, + trial_was_cancelled: bool, + is_team_member: bool, + team_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + individual_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise" } }, - "chain": { - "prompt": [ // array or string - { - "role": string, - "content": string - } - ], - "delays": [ + chain: { + delays?: [ string, [ number, // chars count number // time ] ], - "usage": { // optional - "input": number, - "output": number, + usage?: { + input: int32, + output: int32, + cache_write: int32, + cache_read: int32, + cents: float } }, - "timing": { - "total": number + timing: { + total: double }, - "stream": boolean, - "status": string, - "error": string + stream: bool, + status: "pending" | "success" | "failure", + error?: string | { + error:string, + details:string + } } ], - "timestamp": string, - "status": "success" + timestamp: string } ``` @@ -1079,7 +1153,7 @@ data: [DONE] * 认证方式: Bearer Token * 请求格式: -```json +```typescript [ string ] @@ -1087,94 +1161,68 @@ data: [DONE] * 响应格式: -```json +```typescript { - "status": "success", - "tokens": { - "{key}": { - "primary_token": string, - "secondary_token": string, // 可选 - "checksum": { - "first": string, - "second": string, + status: "success", + tokens: { + {key}: { + primary_token: string, + secondary_token?: string, + checksum: { + first: string, + second: string, }, - "client_key": string, // 可选,非空时显示 - "config_version": string, // 可选 - "session_id": string, // 可选 - "proxy": string, // 可选 - "timezone": string, // 可选 - "gcpp_host": object, // 可选 - "user": { // 可选 - "email": string, - "name": string, - "updated_at": string, - "picture": string, // 可选 - "is_on_new_pricing": boolean + client_key?: string, + config_version?: string, + session_id?: string, + proxy?: string, + timezone?: string, + gcpp_host?: "Asia" | "EU" | "US", + user?: { + user_id: int32, + email?: string, + first_name?: string, + last_name?: 
string, + workos_id?: string, + team_id?: number, + created_at?: string, + is_enterprise_user: bool, + is_on_new_pricing: bool, + privacy_mode_info: { + privacy_mode: "unspecified" | "no_storage" | "no_training" | "usage_data_training_allowed" | "usage_codebase_training_allowed", + is_enforced_by_team?: bool + } } } }, - "total": number, - "timestamp": string + total: uint64, + timestamp: string } ``` ### 静态资源接口 -#### 获取共享样式 - -* 接口地址: `/static/shared-styles.css` -* 请求方法: GET -* 响应格式: CSS文件 -* 功能: 获取共享样式表 - -#### 获取共享脚本 - -* 接口地址: `/static/shared.js` -* 请求方法: GET -* 响应格式: JavaScript文件 -* 功能: 获取共享JavaScript代码 - -#### 获取其他资源 - -* 接口地址: `/static/{path}` -* 请求方法: GET -* 请求参数: - - `path`: 静态文件的相对路径 - -* 响应格式: - - **成功响应 (200 OK)**: - - Headers: - - `Content-Type`: 根据文件扩展名自动设置(见下方MIME类型映射表) - - `Content-Length`: 文件大小 - - Body: 文件的二进制内容 - - - **文件不存在或大小超过4GiB (404 Not Found)**: - - Headers: - - `Content-Type`: `text/plain; charset=utf-8` - - Body: 错误信息 - -* 支持的MIME类型映射: - - 文本类型: html, htm, txt, css, js, mjs, csv, xml, md, markdown - - 图像类型: jpg, jpeg, png, gif, webp, svg, bmp, ico, tiff, tif, avif - - 音频类型: mp3, mp4a, wav, ogg, oga, weba, aac, flac, m4a - - 视频类型: mp4, mpeg, mpg, webm, ogv, avi, mov, qt, flv - - 文档类型: pdf, doc, docx, xls, xlsx, ppt, pptx - - 压缩文件: zip, rar, 7z, gz, gzip, tar - - 字体类型: ttf, otf, woff, woff2 - - 其他类型: 默认为 `application/octet-stream` - -* 功能: 获取从环境变量DATA_DIR指定的目录下的子目录static下的文件。 - #### 环境变量示例 * 接口地址: `/env-example` * 请求方法: GET -* 响应格式: 文本文件 -* 功能: 获取环境变量配置示例 +* 响应格式: 文本 + +#### 文档 + +* 接口地址: `/readme` +* 请求方法: GET +* 响应格式: HTML + +#### 许可 + +* 接口地址: `/license` +* 请求方法: GET +* 响应格式: HTML ### 健康检查接口 -* **接口地址**: `/health` 或 `/`(重定向) +* **接口地址**: `/health` * **请求方法**: GET * **认证方式**: 无需 * **响应格式**: 根据配置返回不同的内容类型(默认JSON、文本或HTML) @@ -1217,8 +1265,8 @@ data: [DONE] }, "capabilities": { "models": ["gpt-4", "claude-3"], - "endpoints": ["/chat", "/completions", "/embeddings"], - "features": ["streaming", "function_calling", "vision"] + "endpoints": 
["/v1/chat/completions", "/v1/messages"], + "features": [".."] } } ``` @@ -1230,10 +1278,10 @@ data: [DONE] | `status` | string | 服务状态: "success", "warning", "error" | | `service.name` | string | 服务名称 | | `service.version` | string | 服务版本 | -| `service.is_debug` | boolean | 是否为调试模式 | +| `service.is_debug` | bool | 是否为调试模式 | | `service.build.version` | number | 构建版本号(仅preview功能启用时) | | `service.build.timestamp` | string | 构建时间戳 | -| `service.build.is_prerelease` | boolean | 是否为预发布版本 | +| `service.build.is_prerelease` | bool | 是否为预发布版本 | | `runtime.started_at` | string | 服务启动时间 | | `runtime.uptime_seconds` | number | 运行时长(秒) | | `runtime.requests.total` | number | 总请求数 | @@ -1290,9 +1338,9 @@ string string ``` -#### 获取当前的timestampheader +#### 获取当前的checksum header -* 接口地址: `/get-timestamp-header` +* 接口地址: `/get-checksum-header` * 请求方法: GET * 响应格式: @@ -1300,51 +1348,112 @@ string string ``` -#### 获取用户信息(已弃用) +#### 获取账号信息 -* 接口地址: `/userinfo` +* 接口地址: `/token-profile/get` * 请求方法: POST -* 认证方式: 请求体中包含token +* 认证方式: Bearer Token * 请求格式: -```json +```typescript { - "token": string + session_token: string, + web_token: string, + proxy_name?: string, + include_sessions: bool } ``` * 响应格式: -```json +```typescript { - "usage": { - "premium": { - "num_requests": number, - "total_requests": number, - "num_tokens": number, - "max_requests": number, - "max_tokens": number + token_profile: [ + null | { + billing_cycle_start: string, + billing_cycle_end: string, + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + limit_type: "user" | "team", + is_unlimited: bool, + individual_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, + team_usage: { + plan?: { + enabled: bool, + used: int32, + limit: int32, + remaining: int32, + breakdown: { + 
included: int32, + bonus: int32, + total: int32 + } + }, + on_demand?: { + enabled: bool, + used: int32, + limit?: int32, + remaining?: int32 + } + }, }, - "standard": { - "num_requests": number, - "total_requests": number, - "num_tokens": number, - "max_requests": number, - "max_tokens": number + null | { + membership_type: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + payment_id?: string, + days_remaining_on_trial: int32, + subscription_status?: "trialing" | "active" | "incomplete" | "incomplete_expired" | "past_due" | "canceled" | "unpaid" | "paused", + verified_student: bool, + trial_eligible: bool, + trial_length_days: int32, + is_on_student_plan: bool, + is_on_billable_auto: bool, + customer_balance?: double, + trial_was_cancelled: bool, + is_team_member: bool, + team_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise", + individual_membership_type?: "free" | "free_trial" | "pro" | "pro_plus" | "ultra" | "enterprise" }, - "start_of_month": string - }, - "user": { - "email": string, - "name": string, - "id": string, - "updated_at": string - }, - "stripe": { - "membership_type": "free" | "free_trial" | "pro" | "enterprise", - "payment_id": string, - "days_remaining_on_trial": number - } + null | { + user_id: int32, + email?: string, + first_name?: string, + last_name?: string, + workos_id?: string, + team_id?: number, + created_at?: string, + is_enterprise_user: bool, + is_on_new_pricing: bool, + privacy_mode_info: { + privacy_mode: "unspecified" | "no_storage" | "no_training" | "usage_data_training_allowed" | "usage_codebase_training_allowed", + is_enforced_by_team?: bool + } + }, + null | [ + { + session_id: string, + type: "unspecified" | "web" | "client" | "bugbot" | "background_agent", + created_at: string, + expires_at: string + } + ] + ] } ``` @@ -1356,6 +1465,44 @@ string } ``` +#### 获取Config Version + +* 接口地址: `/config-version/get` +* 请求方法: POST +* 认证方式: Bearer Token (当SHARE_AUTH_TOKEN启用时需要) +* 
请求格式: + +```json +{ + "token": string, // 格式: JWT + "checksum": { + "first": string, // 格式: 长度为64的Hex编码字符串 + "second": string, // 格式: 长度为64的Hex编码字符串 + }, + "client_key": string, // 格式: 长度为64的Hex编码字符串 + "session_id": string, // 格式: UUID + "proxy_name": string, // 可选,指定代理 + "timezone": string, // 可选,指定时区 + "gcpp_host": string // 可选,代码补全区域 +} +``` + +* 响应格式: + +```json +{ + "config_version": string // 成功时返回生成的UUID +} +``` + +或出错时: + +```json +{ + "error": string // 错误信息 +} +``` + #### 获取更新令牌(已弃用) * 接口地址: `/token-upgrade` @@ -1379,32 +1526,6 @@ string } ``` -#### 基础校准(已弃用) - -* 接口地址: `/basic-calibration` -* 请求方法: POST -* 认证方式: 请求体中包含token -* 请求格式: - -```json -{ - "token": string -} -``` - -* 响应格式: - -```json -{ - "status": "success" | "error", - "message": string, - "user_id": string, - "create_at": string -} -``` - -注意: `user_id` 和 `create_at` 字段在校验失败时可能不存在。 - ## Copilot++ 接口文档 1. 相关接口都需要 `x-client-key`, 格式请见 `/gen-hash`(32字节)。 @@ -1421,9 +1542,9 @@ string ```json { - "is_nightly": boolean, // 可选,是否使用nightly版本 + "is_nightly": bool, // 可选,是否使用nightly版本 "model": string, // 模型名称 - "supports_cpt": boolean // 可选,是否支持CPT + "supports_cpt": bool // 可选,是否支持CPT } ``` @@ -1438,9 +1559,9 @@ string "limit": number, // 可选,限制 "radius": number // 可选,半径 }, - "is_on": boolean, // 可选,是否开启 - "is_ghost_text": boolean, // 可选,是否使用幽灵文本 - "should_let_user_enable_cpp_even_if_not_pro": boolean, // 可选,非专业用户是否可以启用 + "is_on": bool, // 可选,是否开启 + "is_ghost_text": bool, // 可选,是否使用幽灵文本 + "should_let_user_enable_cpp_even_if_not_pro": bool, // 可选,非专业用户是否可以启用 "heuristics": [ // 启用的启发式规则列表 "lots_of_added_text", "duplicating_line_after_suggestion", @@ -1450,26 +1571,26 @@ string "suggesting_recently_rejected_edit" ], "exclude_recently_viewed_files_patterns": [string], // 最近查看文件排除模式 - "enable_rvf_tracking": boolean, // 是否启用RVF跟踪 + "enable_rvf_tracking": bool, // 是否启用RVF跟踪 "global_debounce_duration_millis": number, // 全局去抖动时间(毫秒) "client_debounce_duration_millis": number, // 客户端去抖动时间(毫秒) "cpp_url": string, 
// CPP服务URL - "use_whitespace_diff_history": boolean, // 是否使用空白差异历史 + "use_whitespace_diff_history": bool, // 是否使用空白差异历史 "import_prediction_config": { // 导入预测配置 - "is_disabled_by_backend": boolean, // 是否被后端禁用 - "should_turn_on_automatically": boolean, // 是否自动开启 - "python_enabled": boolean // Python是否启用 + "is_disabled_by_backend": bool, // 是否被后端禁用 + "should_turn_on_automatically": bool, // 是否自动开启 + "python_enabled": bool // Python是否启用 }, - "enable_filesync_debounce_skipping": boolean, // 是否启用文件同步去抖动跳过 + "enable_filesync_debounce_skipping": bool, // 是否启用文件同步去抖动跳过 "check_filesync_hash_percent": number, // 文件同步哈希检查百分比 "geo_cpp_backend_url": string, // 地理位置CPP后端URL "recently_rejected_edit_thresholds": { // 可选,最近拒绝编辑阈值 "hard_reject_threshold": number, // 硬拒绝阈值 "soft_reject_threshold": number // 软拒绝阈值 }, - "is_fused_cursor_prediction_model": boolean, // 是否使用融合光标预测模型 - "include_unchanged_lines": boolean, // 是否包含未更改行 - "should_fetch_rvf_text": boolean, // 是否获取RVF文本 + "is_fused_cursor_prediction_model": bool, // 是否使用融合光标预测模型 + "include_unchanged_lines": bool, // 是否包含未更改行 + "should_fetch_rvf_text": bool, // 是否获取RVF文本 "max_number_of_cleared_suggestions_since_last_accept": number, // 可选,上次接受后清除建议的最大数量 "suggestion_hint_config": { // 可选,建议提示配置 "important_lsp_extensions": [string], // 重要的LSP扩展 @@ -1569,364 +1690,516 @@ string #### 请求格式 -```json +```typescript { - "current_file": { // 当前文件信息 - "relative_workspace_path": string, // 文件相对于工作区的路径 - "contents": string, // 文件内容 - "rely_on_filesync": boolean, // 是否依赖文件同步 - "sha256_hash": string, // 可选,SHA256哈希值 - "top_chunks": [ // 顶级代码块 + current_file: { // 当前文件信息 + relative_workspace_path: string, // 文件相对于工作区的路径 + contents: string, // 文件内容 + rely_on_filesync: bool, // 是否依赖文件同步 + sha_256_hash?: string, // 可选,文件内容SHA256哈希值 + top_chunks: [ // BM25检索的顶级代码块 { - "content": string, // 内容 - "range": { // 最简单范围 - "start_line": number, // 开始行 - "end_line_inclusive": number // 结束行(包含) + content: string, // 代码块内容 + range: { // SimplestRange 最简单范围 + 
start_line: int32, // 开始行号 + end_line_inclusive: int32 // 结束行号(包含) }, - "score": number, // 分数 - "relative_path": string // 相对路径 + score: int32, // BM25分数 + relative_path: string // 代码块所在文件相对路径 } ], - "contents_start_at_line": number, // 内容开始行 - "cursor_position": { // 光标位置 - "line": number, // 行号 - "column": number // 列号 + contents_start_at_line: int32, // 内容开始行号(一般为0) + cursor_position: { // CursorPosition 光标位置 + line: int32, // 行号(0-based) + column: int32 // 列号(0-based) }, - "dataframes": [ // 数据框信息 + dataframes: [ // DataframeInfo 数据框信息(用于数据分析场景) { - "name": string, // 名称 - "shape": string, // 形状 - "data_dimensionality": number, // 数据维度 - "columns": [ // 列 + name: string, // 数据框变量名 + shape: string, // 形状描述,如"(100, 5)" + data_dimensionality: int32, // 数据维度 + columns: [ // 列定义 { - "key": string, // 键 - "type": string // 类型 + key: string, // 列名 + type: string // 列数据类型 } ], - "row_count": number, // 行数 - "index_column": string // 索引列 + row_count: int32, // 行数 + index_column: string // 索引列名称 } ], - "total_number_of_lines": number, // 总行数 - "language_id": string, // 语言ID - "selection": { // 选择范围 - "start_position": { // 开始位置 - "line": number, // 行号 - "column": number // 列号 + total_number_of_lines: int32, // 文件总行数 + language_id: string, // 语言标识符(如"python", "rust") + selection?: { // 可选,CursorRange 当前选中范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: int32 // 列号 }, - "end_position": { // 结束位置 - "line": number, // 行号 - "column": number // 列号 + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 } }, - "alternative_version_id": number, // 可选,替代版本ID - "diagnostics": [ // 诊断信息 + alternative_version_id?: int32, // 可选,备选版本ID + diagnostics: [ // Diagnostic 诊断信息数组 { - "message": string, // 消息 - "range": { // 范围 - "start_position": { // 开始位置 - "line": number, // 行号 - "column": number // 列号 + message: string, // 诊断消息内容 + range: { // CursorRange 诊断范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: 
int32 // 列号 }, - "end_position": { // 结束位置 - "line": number, // 行号 - "column": number // 列号 + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 } }, - "severity": "error" | "warning" | "information" | "hint", // 严重程度 - "related_information": [ // 相关信息 + severity: "error" | "warning" | "information" | "hint", // DiagnosticSeverity 严重程度 + related_information: [ // RelatedInformation 相关信息 { - "message": string, // 消息 - "range": { // 范围 - "start_position": { // 开始位置 - "line": number, // 行号 - "column": number // 列号 + message: string, // 相关信息消息 + range: { // CursorRange 相关信息范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: int32 // 列号 }, - "end_position": { // 结束位置 - "line": number, // 行号 - "column": number // 列号 + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 } } } ] } ], - "file_version": number, // 可选,文件版本 - "cell_start_lines": [number], // 单元格开始行 - "workspace_root_path": string // 工作区根路径 + file_version?: int32, // 可选,文件版本号(用于增量更新) + workspace_root_path: string, // 工作区根路径(绝对路径) + line_ending?: string, // 可选,行结束符("\n" 或 "\r\n") + file_git_context: { // FileGit Git上下文信息 + commits: [ // GitCommit 相关提交数组 + { + commit: string, // 提交哈希 + author: string, // 作者 + date: string, // 提交日期 + message: string // 提交消息 + } + ] + } }, - "diff_history": [string], // 差异历史 - "model_name": string, // 可选,模型名称 - "linter_errors": { // 可选,Linter错误 - "relative_workspace_path": string, // 文件相对于工作区的路径 - "errors": [ // 错误数组 + diff_history: [string], // 差异历史(已弃用,使用file_diff_histories代替) + model_name?: string, // 可选,指定使用的模型名称 + linter_errors?: { // 可选,LinterErrors Linter错误信息 + relative_workspace_path: string, // 错误所在文件相对路径 + errors: [ // LinterError 错误数组 { - "message": string, // 错误消息 - "range": { // 范围 - "start_position": { // 开始位置 - "line": number, // 行号 - "column": number // 列号 + message: string, // 错误消息 + range: { // CursorRange 错误范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + 
column: int32 // 列号 }, - "end_position": { // 结束位置 - "line": number, // 行号 - "column": number // 列号 + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 } }, - "source": string, // 可选,来源 - "related_information": [ // 相关信息数组 + source?: string, // 可选,错误来源(如"eslint", "pyright") + related_information: [ // Diagnostic.RelatedInformation 相关信息 { - "message": string, // 相关信息消息 - "range": { // 相关信息范围 - "start_position": { // 开始位置 - "line": number, // 行号 - "column": number // 列号 + message: string, // 相关信息消息 + range: { // CursorRange 相关信息范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: int32 // 列号 }, - "end_position": { // 结束位置 - "line": number, // 行号 - "column": number // 列号 + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 } } } ], - "severity": "error" | "warning" | "information" | "hint" // 可选,严重程度 + severity?: "error" | "warning" | "information" | "hint" // 可选,DiagnosticSeverity 严重程度 } ], - "file_contents": string // 文件内容 + file_contents: string // 文件内容(用于错误上下文) }, - "context_items": [ // 上下文项 + context_items: [ // CppContextItem 上下文项数组 { - "contents": string, // 内容 - "symbol": string, // 可选,符号 - "relative_workspace_path": string, // 相对工作区路径 - "score": number // 分数 + contents: string, // 上下文内容 + symbol?: string, // 可选,符号名称 + relative_workspace_path: string, // 上下文所在文件相对路径 + score: float // 相关性分数 } ], - "diff_history_keys": [string], // 差异历史键 - "give_debug_output": boolean, // 可选,提供调试输出 - "file_diff_histories": [ // 文件差异历史 + diff_history_keys: [string], // 差异历史键(已弃用) + give_debug_output?: bool, // 可选,是否输出调试信息 + file_diff_histories: [ // CppFileDiffHistory 文件差异历史数组 { - "file_name": string, // 文件名 - "diff_history": [string], // 差异历史 - "diff_history_timestamps": [number] // 差异历史时间戳 + file_name: string, // 文件名 + diff_history: [string], // 差异历史数组,格式:"行号-|旧内容\n行号+|新内容\n" + diff_history_timestamps: [double] // 差异时间戳数组(Unix毫秒时间戳) } ], - "merged_diff_histories": [ // 合并差异历史 + 
merged_diff_histories: [ // CppFileDiffHistory 合并后的差异历史 { - "file_name": string, // 文件名 - "diff_history": [string], // 差异历史 - "diff_history_timestamps": [number] // 差异历史时间戳 + file_name: string, // 文件名 + diff_history: [string], // 合并后的差异历史 + diff_history_timestamps: [double] // 时间戳数组 } ], - "block_diff_patches": [ // 块差异补丁 + block_diff_patches: [ // BlockDiffPatch 块级差异补丁 { - "start_model_window": { // 开始模型窗口 - "lines": [string], // 行 - "start_line_number": number, // 开始行号 - "end_line_number": number // 结束行号 + start_model_window: { // ModelWindow 模型窗口起始状态 + lines: [string], // 窗口内的代码行 + start_line_number: int32, // 窗口起始行号 + end_line_number: int32 // 窗口结束行号 }, - "changes": [ // 变更 + changes: [ // Change 变更数组 { - "text": string, // 文本 - "range": { // 范围 - "start_line_number": number, // 开始行号 - "start_column": number, // 开始列 - "end_line_number": number, // 结束行号 - "end_column": number // 结束列 + text: string, // 变更后的文本 + range: { // IRange 变更范围 + start_line_number: int32, // 起始行号 + start_column: int32, // 起始列号 + end_line_number: int32, // 结束行号 + end_column: int32 // 结束列号 } } ], - "relative_path": string, // 相对路径 - "model_uuid": string, // 模型UUID - "start_from_change_index": number // 开始变更索引 + relative_path: string, // 文件相对路径 + model_uuid: string, // 模型UUID(用于追踪补全来源) + start_from_change_index: int32 // 从第几个change开始应用 } ], - "is_nightly": boolean, // 可选,是否为nightly版本 - "is_debug": boolean, // 可选,是否为调试模式 - "immediately_ack": boolean, // 可选,立即确认 - "enable_more_context": boolean, // 可选,启用更多上下文 - "parameter_hints": [ // 参数提示 + is_nightly?: bool, // 可选,是否为nightly构建版本 + is_debug?: bool, // 可选,是否为调试模式 + immediately_ack?: bool, // 可选,是否立即确认请求 + enable_more_context?: bool, // 可选,是否启用更多上下文检索 + parameter_hints: [ // CppParameterHint 参数提示数组 { - "label": string, // 标签 - "documentation": string // 可选,文档 + label: string, // 参数标签(如"x: int") + documentation?: string // 可选,参数文档说明 } ], - "lsp_contexts": [ // LSP上下文 + lsp_contexts: [ // LspSubgraphFullContext LSP子图上下文 { - "uri": string, // URI - 
"symbol_name": string, // 符号名称 - "positions": [ // 位置 + uri?: string, // 可选,文件URI + symbol_name: string, // 符号名称 + positions: [ // LspSubgraphPosition 位置数组 { - "line": number, // 行 - "character": number // 字符 + line: int32, // 行号 + character: int32 // 字符位置 } ], - "context_items": [ // 上下文项 + context_items: [ // LspSubgraphContextItem 上下文项 { - "uri": string, // 可选,URI - "type": string, // 类型 - "content": string, // 内容 - "range": { // 可选,范围 - "start_line": number, // 开始行 - "start_character": number, // 开始字符 - "end_line": number, // 结束行 - "end_character": number // 结束字符 + uri?: string, // 可选,URI + type: string, // 类型(如"definition", "reference") + content: string, // 内容 + range?: { // 可选,LspSubgraphRange 范围 + start_line: int32, // 起始行 + start_character: int32, // 起始字符 + end_line: int32, // 结束行 + end_character: int32 // 结束字符 } } ], - "score": number // 分数 + score: float // 相关性分数 } ], - "cpp_intent_info": { // 可选,代码补全意图信息 - "source": string // 来源 + cpp_intent_info?: { // 可选,CppIntentInfo 代码补全意图信息 + source: "line_change" | "typing" | "option_hold" | // 触发来源 + "linter_errors" | "parameter_hints" | + "cursor_prediction" | "manual_trigger" | + "editor_change" | "lsp_suggestions" }, - "workspace_id": string, // 可选,工作区ID - "additional_files": [ // 附加文件 + workspace_id?: string, // 可选,工作区唯一标识符 + additional_files: [ // AdditionalFile 附加文件数组 { - "relative_workspace_path": string, // 相对工作区路径 - "is_open": boolean, // 是否打开 - "visible_range_content": [string], // 可见范围内容 - "last_viewed_at": number, // 可选,最后查看时间 - "start_line_number_one_indexed": [number], // 从1开始索引的起始行号 - "visible_ranges": [ // 可见范围 + relative_workspace_path: string, // 文件相对路径 + is_open: bool, // 是否在编辑器中打开 + visible_range_content: [string], // 可见范围的内容(按行) + last_viewed_at?: double, // 可选,最后查看时间(Unix毫秒时间戳) + start_line_number_one_indexed: [int32], // 可见范围起始行号(1-based索引) + visible_ranges: [ // LineRange 可见范围数组 { - "start_line_number": number, // 开始行号 - "end_line_number_inclusive": number // 结束行号(包含) + start_line_number: 
int32, // 起始行号 + end_line_number_inclusive: int32 // 结束行号(包含) } ] } ], - "control_token": "quiet" | "loud" | "op", // 可选,控制标记 - "client_time": number, // 可选,客户端时间 - "filesync_updates": [ // 文件同步更新 + control_token?: "quiet" | "loud" | "op", // 可选,ControlToken 控制标记 + client_time?: double, // 可选,客户端时间(Unix毫秒时间戳) + filesync_updates: [ // FilesyncUpdateWithModelVersion 文件同步增量更新 { - "model_version": number, // 模型版本 - "relative_workspace_path": string, // 相对工作区路径 - "updates": [ // 更新数组 + model_version: int32, // 模型版本号 + relative_workspace_path: string, // 文件相对路径 + updates: [ // SingleUpdateRequest 更新操作数组 { - "start_position": number, // 开始位置(字符偏移量) - "end_position": number, // 结束位置(字符偏移量) - "change_length": number, // 变更长度 - "replaced_string": string, // 替换的字符串 - "range": { // 范围 - "start_line_number": number, // 开始行号 - "start_column": number, // 开始列 - "end_line_number_inclusive": number, // 结束行号(包含) - "end_column": number // 结束列 + start_position: int32, // 起始位置(字符偏移量,0-based) + end_position: int32, // 结束位置(字符偏移量,0-based) + change_length: int32, // 变更后的长度 + replaced_string: string, // 替换的字符串内容 + range: { // SimpleRange 变更范围 + start_line_number: int32, // 起始行号 + start_column: int32, // 起始列号 + end_line_number_inclusive: int32, // 结束行号(包含) + end_column: int32 // 结束列号 } } ], - "expected_file_length": number // 预期文件长度 + expected_file_length: int32 // 应用更新后预期的文件长度 } ], - "time_since_request_start": number, // 请求开始后的时间 - "time_at_request_send": number, // 请求发送时的时间 - "client_timezone_offset": number, // 可选,客户端时区偏移 - "lsp_suggested_items": { // 可选,LSP建议项 - "suggestions": [ // 建议 + time_since_request_start: double, // 从请求开始到当前的时间(毫秒) + time_at_request_send: double, // 请求发送时的时间戳(Unix毫秒时间戳) + client_timezone_offset?: double, // 可选,客户端时区偏移(分钟,如-480表示UTC+8) + lsp_suggested_items?: { // 可选,LspSuggestedItems LSP建议项 + suggestions: [ // LspSuggestion 建议数组 { - "label": string // 标签 + label: string // 建议标签 } ] }, - "supports_cpt": boolean // 可选,是否支持CPT + supports_cpt?: bool, // 
可选,是否支持CPT(Code Patch Token)格式 + supports_crlf_cpt?: bool, // 可选,是否支持CRLF换行的CPT格式 + code_results: [ // CodeResult 代码检索结果 + { + code_block: { // CodeBlock 代码块 + relative_workspace_path: string, // 文件相对路径 + file_contents?: string, // 可选,完整文件内容 + file_contents_length?: int32, // 可选,文件内容长度 + range: { // CursorRange 代码块范围 + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: int32 // 列号 + }, + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 + } + }, + contents: string, // 代码块内容 + signatures: { // Signatures 签名信息 + ranges: [ // CursorRange 签名范围数组 + { + start_position: { // CursorPosition 开始位置 + line: int32, // 行号 + column: int32 // 列号 + }, + end_position: { // CursorPosition 结束位置 + line: int32, // 行号 + column: int32 // 列号 + } + } + ] + }, + override_contents?: string, // 可选,覆盖内容 + original_contents?: string, // 可选,原始内容 + detailed_lines: [ // DetailedLine 详细行信息 + { + text: string, // 行文本 + line_number: float, // 行号(浮点数用于支持虚拟行) + is_signature: bool // 是否为签名行 + } + ], + file_git_context: { // FileGit Git上下文 + commits: [ // GitCommit 提交数组 + { + commit: string, // 提交哈希 + author: string, // 作者 + date: string, // 提交日期 + message: string // 提交消息 + } + ] + } + }, + score: float // 检索相关性分数 + } + ] } ``` -### 响应格式 (SSE流格式) +### 响应格式 (SSE 流) -事件类型及对应数据格式: +服务器通过 Server-Sent Events (SSE) 返回流式响应。每个事件包含 `type` 字段区分消息类型。 -1. **model_info** -```json +--- + +#### 事件类型 + +**1. model_info** - 模型信息 +```typescript { - "type": "model_info", - "is_fused_cursor_prediction_model": boolean, - "is_multidiff_model": boolean + type: "model_info", + is_fused_cursor_prediction_model: bool, + is_multidiff_model: bool } ``` -2. **range_replace** -```json +--- + +**2. 
range_replace** - 范围替换 +```typescript { - "type": "range_replace", - "start_line_number": number, - "end_line_number_inclusive": number, - "text": string + type: "range_replace", + start_line_number: int32, // 起始行(1-based) + end_line_number_inclusive: int32, // 结束行(1-based,包含) + binding_id?: string, + should_remove_leading_eol?: bool } ``` +> **注意**:替换的文本内容通过后续的 `text` 事件发送 -3. **cursor_prediction** -```json +--- + +**3. text** - 文本内容 +```typescript { - "type": "cursor_prediction", - "relative_path": string, - "line_number_one_indexed": number, - "expected_content": string, - "should_retrigger_cpp": boolean + type: "text", + text: string } ``` +> **说明**:流式输出的主要内容,客户端应累积 -4. **text** -```json +--- + +**4. cursor_prediction** - 光标预测 +```typescript { - "type": "text", - "text": string + type: "cursor_prediction", + relative_path: string, + line_number_one_indexed: int32, + expected_content: string, + should_retrigger_cpp: bool, + binding_id?: string } ``` -5. **done_edit** -```json +--- + +**5. done_edit** - 编辑完成 +```typescript { - "type": "done_edit" + type: "done_edit" } ``` -6. **done_stream** -```json +--- + +**6. begin_edit** - 编辑开始 +```typescript { - "type": "done_stream" + type: "begin_edit" } ``` -7. **debug** -```json +--- + +**7. done_stream** - 内容阶段结束 +```typescript { - "type": "debug", - "model_input": string, - "model_output": string, - "total_time": string, // 可选 - "stream_time": string, - "ttft_time": string, - "server_timing": string // 可选 + type: "done_stream" } ``` +> **说明**:之后可能会有 `debug` 消息 -8. **error** -```json +--- + +**8. debug** - 调试信息 +```typescript { - "type": "error", - "message": string + type: "debug", + model_input?: string, + model_output?: string, + stream_time?: string, + total_time?: string, + ttft_time?: string, + server_timing?: string } ``` +> **说明**:可能出现多次,前端可累积用于统计 -9. **stream_end** -```json +--- + +**9. 
error** - 错误 +```typescript { - "type": "stream_end" + type: "error", + error: { + code: uint16, // 非零错误码 + type: string, // 错误类型 + details?: { // 可选的详细信息 + title: string, + detail: string, + additional_info?: Record + } + } } ``` -#### 来源可选值 +--- -- line_change -- typing -- option_hold -- linter_errors -- parameter_hints -- cursor_prediction -- manual_trigger -- editor_change -- lsp_suggestions +**10. stream_end** - 流结束 +```typescript +{ + type: "stream_end" +} +``` + +--- + +#### 典型消息序列 + +**基础场景:** +``` +model_info +range_replace // 指定范围 +text (×N) // 流式文本 +done_edit +done_stream +debug (×N) // 可选的多个调试消息 +stream_end +``` + +**多次编辑:** +``` +model_info +range_replace +text (×N) +done_edit +begin_edit // 下一次编辑 +range_replace +text (×N) +cursor_prediction // 可选 +done_edit +done_stream +stream_end +``` + +--- + +#### 客户端处理要点 + +1. **累积文本** + - `range_replace` 指定范围 + - 累积后续所有 `text` 内容 + - `done_edit` 时应用变更 + +2. **换行符处理** + - `should_remove_leading_eol=true` 时移除首个换行符 + +3. **多编辑会话** + - `begin_edit` 标记新会话开始 + - `binding_id` 用于关联同一补全的多个编辑 + +4. **错误处理** + - 流中出现 `error` 时,客户端应中止当前操作 + +5. **调试信息** + - `done_stream` 后可能有多个 `debug` 消息 + - 前端可累积用于性能分析 ## 鸣谢 @@ -1934,4 +2207,18 @@ string - [cursor-api](https://github.com/wisdgod/cursor-api) - 本项目本身 - [zhx47/cursor-api](https://github.com/zhx47/cursor-api) - 提供了本项目起步阶段的主要参考 -- [luolazyandlazy/cursorToApi](https://github.com/luolazyandlazy/cursorToApi) +- [luolazyandlazy/cursorToApi](https://github.com/luolazyandlazy/cursorToApi) - zhx47/cursor-api基于此项目优化 + +## 关于赞助 + +非常感谢我自己持续8个多月的更新和大家的支持!你想赞助的话,清直接联系我,我一般不会拒绝。 + +有人说少个二维码来着,还是算了。如果觉得好用,给点支持。没啥大不了的,有空尽量做一点,只是心力确实消耗很大。 + +~~要不给我邮箱发口令红包?~~ + +**赞助一定要是你真心想给,也不强求。** + +就算你给我赞助,我可能也不会区别对待你。我不想说你赞助多少就有什么,不想赞助失去本来的意味。 + +纯粹! 
diff --git a/VERSION b/VERSION index 7730ef7..25bf17f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -89 \ No newline at end of file +18 \ No newline at end of file diff --git a/build.rs b/build.rs index f55f352..afbc445 100644 --- a/build.rs +++ b/build.rs @@ -1,8 +1,8 @@ -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use sha2::{Digest, Sha256}; -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use std::collections::HashMap; -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use std::fs; #[cfg(not(debug_assertions))] #[cfg(feature = "__preview")] @@ -11,28 +11,29 @@ use std::io::Result; #[cfg(not(debug_assertions))] #[cfg(feature = "__preview")] use std::io::{Read, Write}; -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use std::path::Path; -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use std::path::PathBuf; -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] use std::process::Command; // 支持的文件类型 -#[cfg(not(any(feature = "use-minified")))] -const SUPPORTED_EXTENSIONS: [&str; 4] = ["html", "js", "css", "md"]; +// #[cfg(not(feature = "use-minified"))] +// const SUPPORTED_EXTENSIONS: [&str; 4] = ["html", "js", "css", "md"]; -#[cfg(not(any(feature = "use-minified")))] +// 需要处理的 Markdown 文件列表 +#[cfg(not(feature = "use-minified"))] +const MARKDOWN_FILES: [&str; 2] = ["README.md", "LICENSE.md"]; + +#[cfg(not(feature = "use-minified"))] fn check_and_install_deps() -> Result<()> { let scripts_dir = Path::new("scripts"); let node_modules = scripts_dir.join("node_modules"); if !node_modules.exists() { println!("cargo:warning=Installing minifier dependencies..."); - let status = Command::new("npm") - .current_dir(scripts_dir) - .arg("install") - .status()?; + let status = Command::new("npm").current_dir(scripts_dir).arg("install").status()?; if !status.success() { 
panic!("Failed to install npm dependencies"); @@ -42,67 +43,100 @@ fn check_and_install_deps() -> Result<()> { Ok(()) } -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] fn get_files_hash() -> Result> { let mut file_hashes = HashMap::new(); - let static_dir = Path::new("static"); + // let static_dir = Path::new("static"); - // 首先处理 README.md - let readme_path = Path::new("README.md"); - if readme_path.exists() { - let content = fs::read(readme_path)?; - let hash = format!("{:x}", Sha256::new().chain_update(&content).finalize()); - file_hashes.insert(readme_path.to_path_buf(), hash); + pub const HEX_CHARS: &[u8; 16] = b"0123456789abcdef"; + + #[inline] + pub fn to_str<'buf>(bytes: &[u8], buf: &'buf mut [u8]) -> &'buf mut str { + for (i, &byte) in bytes.iter().enumerate() { + buf[i * 2] = HEX_CHARS[(byte >> 4) as usize]; + buf[i * 2 + 1] = HEX_CHARS[(byte & 0x0f) as usize]; + } + + // SAFETY: 输出都是有效的 ASCII 字符 + unsafe { core::str::from_utf8_unchecked_mut(buf) } } - if static_dir.exists() { - for entry in fs::read_dir(static_dir)? { - let entry = entry?; - let path = entry.path(); - - // 检查是否是支持的文件类型,且不是已经压缩的文件 - if let Some(ext) = path.extension().and_then(|e| e.to_str()) - && SUPPORTED_EXTENSIONS.contains(&ext) - && !path.to_string_lossy().contains(".min.") - { - let content = fs::read(&path)?; - let hash = format!("{:x}", Sha256::new().chain_update(&content).finalize()); - file_hashes.insert(path, hash); - } + // 处理根目录的 Markdown 文件 + for md_file in MARKDOWN_FILES { + let md_path = Path::new(md_file); + if md_path.exists() { + let content = fs::read(md_path)?; + #[allow(deprecated)] + let hash = + to_str(Sha256::new().chain_update(&content).finalize().as_slice(), &mut [0; 64]) + .to_string(); + file_hashes.insert(md_path.to_path_buf(), hash); } } + // 处理 static 目录中的文件 + // if static_dir.exists() { + // for entry in fs::read_dir(static_dir)? 
{ + // let entry = entry?; + // let path = entry.path(); + + // // 检查是否是支持的文件类型,且不是已经压缩的文件 + // if let Some(ext) = path.extension().and_then(|e| e.to_str()) + // && SUPPORTED_EXTENSIONS.contains(&ext) + // && !path.to_string_lossy().contains(".min.") + // { + // let content = fs::read(&path)?; + // #[allow(deprecated)] + // let hash = + // to_str(Sha256::new().chain_update(&content).finalize().as_slice(), &mut [0; 64]) + // .to_string(); + // file_hashes.insert(path, hash); + // } + // } + // } + Ok(file_hashes) } -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] fn load_saved_hashes() -> Result> { let hash_file = Path::new("scripts/.asset-hashes.json"); if hash_file.exists() { let content = fs::read_to_string(hash_file)?; let hash_map: HashMap = serde_json::from_str(&content)?; - Ok(hash_map - .into_iter() - .map(|(k, v)| (PathBuf::from(k), v)) - .collect()) + Ok(hash_map.into_iter().map(|(k, v)| (PathBuf::from(k), v)).collect()) } else { Ok(HashMap::new()) } } -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] fn save_hashes(hashes: &HashMap) -> Result<()> { let hash_file = Path::new("scripts/.asset-hashes.json"); - let string_map: HashMap = hashes - .iter() - .map(|(k, v)| (k.to_string_lossy().into_owned(), v.clone())) - .collect(); + let string_map: HashMap = + hashes.iter().map(|(k, v)| (k.to_string_lossy().into_owned(), v.clone())).collect(); let content = serde_json::to_string_pretty(&string_map)?; fs::write(hash_file, content)?; Ok(()) } -#[cfg(not(any(feature = "use-minified")))] +#[cfg(not(feature = "use-minified"))] +fn get_minified_output_path(path: &Path) -> PathBuf { + let file_name = path.file_name().and_then(|f| f.to_str()).unwrap_or(""); + + // 检查是否是根目录的 Markdown 文件 + if MARKDOWN_FILES.contains(&file_name) { + // 将文件名转换为小写并生成对应的 .min.html 文件 + let base_name = path.file_stem().unwrap().to_string_lossy().to_lowercase(); + PathBuf::from(format!("static/{}.min.html", base_name)) + } else 
{ + // 其他文件保持原有逻辑 + let ext = path.extension().and_then(|e| e.to_str()).unwrap_or(""); + path.with_file_name(format!("{}.min.{}", path.file_stem().unwrap().to_string_lossy(), ext)) + } +} + +#[cfg(not(feature = "use-minified"))] fn minify_assets() -> Result<()> { // 获取现有文件的哈希 let current_hashes = get_files_hash()?; @@ -119,19 +153,8 @@ fn minify_assets() -> Result<()> { let files_to_update: Vec<_> = current_hashes .iter() .filter(|(path, current_hash)| { - let is_readme = path.file_name().is_some_and(|f| f == "README.md"); - let ext = path.extension().and_then(|e| e.to_str()).unwrap_or(""); - - // 为 README.md 和其他文件使用不同的输出路径检查 - let min_path = if is_readme { - PathBuf::from("static/readme.min.html") - } else { - path.with_file_name(format!( - "{}.min.{}", - path.file_stem().unwrap().to_string_lossy(), - ext - )) - }; + // 获取压缩后的输出路径 + let min_path = get_minified_output_path(path); // 检查压缩/转换后的文件是否存在 if !min_path.exists() { @@ -150,13 +173,10 @@ fn minify_assets() -> Result<()> { } println!("cargo:warning=Minifying {} files...", files_to_update.len()); - println!("cargo:warning={}", files_to_update.join(" ")); + println!("cargo:warning=Files: {}", files_to_update.join(" ")); // 运行压缩脚本 - let status = Command::new("node") - .arg("scripts/minify.js") - .args(&files_to_update) - .status()?; + let status = Command::new("node").arg("scripts/minify.js").args(&files_to_update).status()?; if !status.success() { panic!("Asset minification failed"); @@ -291,35 +311,31 @@ fn main() -> Result<()> { } config - .compile_protos( - &["src/core/aiserver/v1/lite.proto"], - &["src/core/aiserver/v1/"], - ) - .unwrap(); - config - .compile_protos(&["src/core/config/key.proto"], &["src/core/config/"]) + .compile_protos(&["src/core/aiserver/v1/lite.proto"], &["src/core/aiserver/v1/"]) .unwrap(); + // config.compile_protos(&["src/core/config/key.proto"], &["src/core/config/"]).unwrap(); } // 静态资源文件处理 println!("cargo:rerun-if-changed=scripts/minify.js"); - 
println!("cargo:rerun-if-changed=scripts/package.json"); - println!("cargo:rerun-if-changed=static/api.html"); - println!("cargo:rerun-if-changed=static/build_key.html"); - println!("cargo:rerun-if-changed=static/config.html"); - println!("cargo:rerun-if-changed=static/logs.html"); - println!("cargo:rerun-if-changed=static/proxies.html"); - println!("cargo:rerun-if-changed=static/shared-styles.css"); - println!("cargo:rerun-if-changed=static/shared.js"); - println!("cargo:rerun-if-changed=static/tokens.html"); + // println!("cargo:rerun-if-changed=scripts/package.json"); + // println!("cargo:rerun-if-changed=static/api.html"); + // println!("cargo:rerun-if-changed=static/build_key.html"); + // println!("cargo:rerun-if-changed=static/config.html"); + // println!("cargo:rerun-if-changed=static/logs.html"); + // println!("cargo:rerun-if-changed=static/proxies.html"); + // println!("cargo:rerun-if-changed=static/shared-styles.css"); + // println!("cargo:rerun-if-changed=static/shared.js"); + // println!("cargo:rerun-if-changed=static/tokens.html"); println!("cargo:rerun-if-changed=README.md"); + println!("cargo:rerun-if-changed=LICENSE.md"); // 只在release模式下监控VERSION文件变化 #[cfg(not(debug_assertions))] #[cfg(feature = "__preview")] println!("cargo:rerun-if-changed=VERSION"); - #[cfg(not(any(feature = "use-minified")))] + #[cfg(not(feature = "use-minified"))] { // 检查并安装依赖 check_and_install_deps()?; diff --git a/build_info.rs b/build_info.rs index 5a744c0..3d3a27d 100644 --- a/build_info.rs +++ b/build_info.rs @@ -1,3 +1,109 @@ +include!("src/app/model/version.rs"); + +/// 版本字符串解析错误 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParseError { + /// 整体格式错误(如缺少必需部分) + InvalidFormat, + /// 数字解析失败 + InvalidNumber, + /// pre 部分格式错误 + InvalidPreRelease, + /// build 部分格式错误 + InvalidBuild, + // /// 正式版不能带 build 标识 + // BuildWithoutPreview, +} + +impl core::fmt::Display for ParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + 
ParseError::InvalidFormat => write!(f, "invalid version format"), + ParseError::InvalidNumber => write!(f, "invalid number in version"), + ParseError::InvalidPreRelease => write!(f, "invalid pre-release format"), + ParseError::InvalidBuild => write!(f, "invalid build format"), + // ParseError::BuildWithoutPreview => { + // write!(f, "build metadata cannot exist without pre-release version") + // } + } + } +} + +impl std::error::Error for ParseError {} + +impl core::str::FromStr for Version { + type Err = ParseError; + + fn from_str(s: &str) -> core::result::Result { + // 按 '-' 分割基础版本号和扩展部分 + let (base, extension) = match s.split_once('-') { + Some((base, ext)) => (base, Some(ext)), + None => (s, None), + }; + + // 解析基础版本号 major.minor.patch + let mut parts: [u16; 3] = [0, 0, 0]; + let mut parsed_count = 0; + for (i, s) in base.split('.').enumerate() { + if i >= parts.len() { + return Err(ParseError::InvalidFormat); + } + parts[i] = s.parse().map_err(|_| ParseError::InvalidNumber)?; + parsed_count += 1; + } + if parsed_count != 3 { + return Err(ParseError::InvalidFormat); + } + + let major = parts[0]; + let minor = parts[1]; + let patch = parts[2]; + + // 解析扩展部分(如果存在) + let stage = + if let Some(ext) = extension { parse_extension(ext)? } else { ReleaseStage::Release }; + + Ok(Version { major, minor, patch, stage }) + } +} + +/// 解析扩展部分:pre.X 或 pre.X+build.Y +fn parse_extension(s: &str) -> core::result::Result { + // 检查是否以 "pre." 开头 + if !s.starts_with("pre.") { + return Err(ParseError::InvalidPreRelease); + } + + // 移除 "pre." 
前缀 + let after_pre = &s[4..]; + + // 按 '+' 分割 version 和 build 部分 + let (version_str, build_str) = match after_pre.split_once('+') { + Some((ver, build_part)) => (ver, Some(build_part)), + None => (after_pre, None), + }; + + // 解析 pre 版本号 + let version = version_str.parse().map_err(|_| ParseError::InvalidPreRelease)?; + + // 解析 build 号(如果存在) + let build = if let Some(build_part) = build_str { + // 检查格式是否为 "build.X" + if !build_part.starts_with("build.") { + return Err(ParseError::InvalidBuild); + } + + let build_num_str = &build_part[6..]; + let build_num = build_num_str.parse().map_err(|_| ParseError::InvalidBuild)?; + + Some(build_num) + } else { + None + }; + + Ok(ReleaseStage::Preview { version, build }) +} + /** * 更新版本号函数 * 此函数会读取 VERSION 文件中的数字,将其加1,然后保存回文件 @@ -26,6 +132,7 @@ fn update_version() -> Result<()> { file.read_to_string(&mut version)?; // 确保版本号是有效数字 + #[allow(unused_variables)] let version_num = match version.trim().parse::() { Ok(num) => num, Err(_) => { @@ -36,20 +143,23 @@ fn update_version() -> Result<()> { } }; - // 版本号加1 - let new_version = version_num + 1; - println!( - "cargo:warning=Release build - bumping version from {version_num} to {new_version}", - ); + #[cfg(not(feature = "__preview_locked"))] + { + // 版本号加1 + let new_version = version_num + 1; + println!( + "cargo:warning=Release build - bumping version from {version_num} to {new_version}", + ); - // 写回文件 - let mut file = File::create(version_path)?; - file.write_all(new_version.to_string().as_bytes())?; + // 写回文件 + let mut file = File::create(version_path)?; + write!(file, "{new_version}")?; + } Ok(()) } -#[cfg(feature = "__preview")] +#[allow(unused)] fn read_version_number() -> Result { let mut version = String::with_capacity(4); match std::fs::File::open("VERSION") { @@ -63,21 +173,19 @@ fn read_version_number() -> Result { } fn generate_build_info() -> Result<()> { - // let out_dir = std::env::var("OUT_DIR").unwrap(); - // let dest_path = 
Path::new(out_dir).join("build_info.rs"); - #[cfg(debug_assertions)] - let out_dir = "target/debug/build/build_info.rs"; - #[cfg(not(debug_assertions))] - let out_dir = "target/release/build/build_info.rs"; - let dest_path = Path::new(out_dir); + let out_dir = std::env::var("OUT_DIR").unwrap(); + let dest_path = Path::new(&out_dir).join("build_info.rs"); + // #[cfg(debug_assertions)] + // let out_dir = "../target/debug/build/build_info.rs"; + // #[cfg(not(debug_assertions))] + // let out_dir = "../target/release/build/build_info.rs"; + // let dest_path = Path::new(out_dir); // if dest_path.is_file() { // return Ok(()); // } - let build_timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(); + let build_timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); let build_timestamp_str = chrono::DateTime::from_timestamp(build_timestamp as i64, 0) .unwrap() @@ -85,56 +193,43 @@ fn generate_build_info() -> Result<()> { let pkg_version = env!("CARGO_PKG_VERSION"); - #[cfg(feature = "__preview")] - let (version_str, build_version_str) = { - let build_num = read_version_number()?; - ( - format!("{pkg_version}+build.{build_num}"), - format!("pub const BUILD_VERSION: u32 = {build_num};\n"), - ) - }; + let (version_str, build_version_str) = + if cfg!(feature = "__preview") && pkg_version.contains("-pre") { + let build_num = read_version_number()?; + ( + format!("{pkg_version}+build.{build_num}"), + format!("pub const BUILD_VERSION: u32 = {build_num};\n"), + ) + } else { + (pkg_version.to_string(), String::new()) + }; - #[cfg(not(feature = "__preview"))] - let (version_str, build_version_str) = (pkg_version, ""); + let version: Version = version_str.parse().unwrap(); let build_info_content = format!( r#"// 此文件由 build.rs 自动生成,请勿手动修改 +use crate::app::model::version::{{Version, ReleaseStage::Preview}}; {build_version_str}pub const BUILD_TIMESTAMP: &'static str = {build_timestamp_str:?}; 
-pub const VERSION: &'static str = {version_str:?}; +/// pub const VERSION_STR: &'static str = {version_str:?}; +pub const VERSION: Version = {version:?}; pub const IS_PRERELEASE: bool = {is_prerelease}; pub const IS_DEBUG: bool = {is_debug}; #[cfg(unix)] -pub const BUILD_EPOCH: std::time::SystemTime = unsafe {{ - #[allow(dead_code)] - struct UnixSystemTime {{ - tv_sec: i64, - tv_nsec: u32, - }} - - ::core::mem::transmute(UnixSystemTime {{ - tv_sec: {build_timestamp}, - tv_nsec: 0, - }}) -}}; +pub const BUILD_EPOCH: std::time::SystemTime = + unsafe {{ ::core::intrinsics::transmute(({build_timestamp}i64, 0u32)) }}; #[cfg(windows)] pub const BUILD_EPOCH: std::time::SystemTime = unsafe {{ - #[allow(dead_code)] - struct WindowsFileTime {{ - dw_low_date_time: u32, - dw_high_date_time: u32, - }} - const INTERVALS_PER_SEC: u64 = 10_000_000; const INTERVALS_TO_UNIX_EPOCH: u64 = 11_644_473_600 * INTERVALS_PER_SEC; const TARGET_INTERVALS: u64 = INTERVALS_TO_UNIX_EPOCH + {build_timestamp} * INTERVALS_PER_SEC; - ::core::mem::transmute(WindowsFileTime {{ - dw_low_date_time: TARGET_INTERVALS as u32, - dw_high_date_time: (TARGET_INTERVALS >> 32) as u32, - }}) + ::core::intrinsics::transmute(( + TARGET_INTERVALS as u32, + (TARGET_INTERVALS >> 32) as u32, + )) }}; "#, is_prerelease = cfg!(feature = "__preview"), diff --git a/crates/grpc-stream/Cargo.toml b/crates/grpc-stream/Cargo.toml new file mode 100644 index 0000000..1a3d2e5 --- /dev/null +++ b/crates/grpc-stream/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "grpc-stream" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +bytes = "1" +flate2 = "1" +prost = "0.14" diff --git a/crates/grpc-stream/src/buffer.rs b/crates/grpc-stream/src/buffer.rs new file mode 100644 index 0000000..5144526 --- /dev/null +++ b/crates/grpc-stream/src/buffer.rs @@ -0,0 +1,195 @@ +//! 
内部缓冲区管理 + +use core::iter::FusedIterator; + +use bytes::{Buf as _, BytesMut}; + +use crate::frame::RawMessage; + +/// 消息缓冲区(内部使用) +pub struct Buffer { + inner: BytesMut, +} + +impl Buffer { + #[inline] + pub fn new() -> Self { Self { inner: BytesMut::new() } } + + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { inner: BytesMut::with_capacity(capacity) } + } + + #[inline] + pub fn len(&self) -> usize { self.inner.len() } + + #[inline] + pub fn is_empty(&self) -> bool { self.inner.is_empty() } + + #[inline] + pub fn extend_from_slice(&mut self, data: &[u8]) { self.inner.extend_from_slice(data) } + + #[inline] + pub fn advance(&mut self, cnt: usize) { self.inner.advance(cnt) } +} + +impl Default for Buffer { + #[inline] + fn default() -> Self { Self::new() } +} + +impl AsRef<[u8]> for Buffer { + #[inline] + fn as_ref(&self) -> &[u8] { self.inner.as_ref() } +} + +/// 消息迭代器(内部使用) +#[derive(Debug, Clone)] +pub struct MessageIter<'b> { + buffer: &'b [u8], + offset: usize, +} + +impl<'b> MessageIter<'b> { + /// 返回当前已消耗的字节数 + #[inline] + pub fn offset(&self) -> usize { self.offset } +} + +impl<'b> Iterator for MessageIter<'b> { + type Item = RawMessage<'b>; + + #[inline] + fn next(&mut self) -> Option { + // 至少需要 5 字节(1 字节 type + 4 字节 length) + if self.offset + 5 > self.buffer.len() { + return None; + } + + let r#type = unsafe { + let ptr: *const u8 = + ::core::intrinsics::slice_get_unchecked(self.buffer as *const [u8], self.offset); + *ptr + }; + let msg_len = u32::from_be_bytes(unsafe { + *get_offset_len_noubcheck(self.buffer, self.offset + 1, 4).cast() + }) as usize; + + // 检查消息是否完整 + if self.offset + 5 + msg_len > self.buffer.len() { + return None; + } + + self.offset += 5; + + let data = unsafe { &*get_offset_len_noubcheck(self.buffer, self.offset, msg_len) }; + + self.offset += msg_len; + + Some(RawMessage { r#type, data }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // 精确计算剩余完整消息数量 + let mut count = 0; + let mut offset = 
self.offset; + + while offset + 5 <= self.buffer.len() { + let msg_len = u32::from_be_bytes(unsafe { + *get_offset_len_noubcheck(self.buffer, offset + 1, 4).cast() + }) as usize; + + if offset + 5 + msg_len > self.buffer.len() { + break; + } + + count += 1; + offset += 5 + msg_len; + } + + (count, Some(count)) // 精确值 + } +} + +// 实现 ExactSizeIterator +impl<'b> ExactSizeIterator for MessageIter<'b> { + #[inline] + fn len(&self) -> usize { + // size_hint() 已经返回精确值,直接使用 + self.size_hint().0 + } +} + +// 实现 FusedIterator +impl<'b> FusedIterator for MessageIter<'b> {} + +impl<'b> IntoIterator for &'b Buffer { + type Item = RawMessage<'b>; + type IntoIter = MessageIter<'b>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { MessageIter { buffer: self.inner.as_ref(), offset: 0 } } +} + +#[inline(always)] +const unsafe fn get_offset_len_noubcheck( + ptr: *const [T], + offset: usize, + len: usize, +) -> *const [T] { + let ptr = ptr as *const T; + // SAFETY: The caller already checked these preconditions + let ptr = unsafe { ::core::intrinsics::offset(ptr, offset) }; + ::core::intrinsics::aggregate_raw_ptr(ptr, len) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_size_iterator() { + let mut buffer = Buffer::new(); + + // 构造两个消息:type=0, len=3, data="abc" + buffer.extend_from_slice(&[0, 0, 0, 0, 3, b'a', b'b', b'c']); + buffer.extend_from_slice(&[0, 0, 0, 0, 2, b'x', b'y']); + + let iter = (&buffer).into_iter(); + + // 验证 ExactSizeIterator + assert_eq!(iter.len(), 2); + assert_eq!(iter.size_hint(), (2, Some(2))); + + let messages: Vec<_> = iter.collect(); + assert_eq!(messages.len(), 2); + } + + #[test] + fn test_fused_iterator() { + let buffer = Buffer::new(); // 空缓冲区 + + let mut iter = (&buffer).into_iter(); + + // 验证 FusedIterator + assert_eq!(iter.next(), None); + assert_eq!(iter.next(), None); // 仍然是 None + assert_eq!(iter.next(), None); // 永远是 None + } + + #[test] + fn test_clone_iterator() { + let mut buffer = Buffer::new(); + 
buffer.extend_from_slice(&[0, 0, 0, 0, 3, b'a', b'b', b'c']); + + let iter = (&buffer).into_iter(); + let iter_clone = iter.clone(); + + // 消耗原迭代器 + assert_eq!(iter.count(), 1); + + // 副本仍然可用 + assert_eq!(iter_clone.count(), 1); + } +} diff --git a/crates/grpc-stream/src/compression.rs b/crates/grpc-stream/src/compression.rs new file mode 100644 index 0000000..9562038 --- /dev/null +++ b/crates/grpc-stream/src/compression.rs @@ -0,0 +1,154 @@ +//! 压缩数据处理 + +use std::io::Read as _; + +use flate2::read::GzDecoder; + +use crate::MAX_DECOMPRESSED_SIZE_BYTES; + +/// 解压 gzip 数据 +/// +/// # 参数 +/// - `data`: gzip 压缩的数据 +/// +/// # 返回 +/// - `Some(Vec)`: 解压成功 +/// - `None`: 不是有效的 gzip 数据或解压失败 +/// +/// # 最小 GZIP 文件结构 +/// +/// ```text +/// +----------+-------------+----------+ +/// | Header | DEFLATE | Footer | +/// | 10 bytes | 2+ bytes | 8 bytes | +/// +----------+-------------+----------+ +/// 最小: 10 + 2 + 8 = 20 字节 +/// ``` +/// +/// # 安全性 +/// - 限制解压后大小不超过 `MAX_DECOMPRESSED_SIZE_BYTES` +/// - 防止 gzip 炸弹攻击 +pub fn decompress_gzip(data: &[u8]) -> Option> { + // 快速路径:拒绝明显无效的数据 + // 最小有效 gzip 文件为 20 字节(头10 + 数据2 + 尾8) + if data.len() < 20 { + return None; + } + + // SAFETY: 上面已验证 data.len() >= 20,保证索引 0, 1, 2 有效 + // 检查 gzip 魔数(0x1f 0x8b)和压缩方法(0x08 = DEFLATE) + if unsafe { + *data.get_unchecked(0) != 0x1f + || *data.get_unchecked(1) != 0x8b + || *data.get_unchecked(2) != 0x08 + } { + return None; + } + + // 读取 gzip footer 中的 ISIZE(原始大小,最后 4 字节,小端序) + // SAFETY: 已验证 data.len() >= 20,末尾 4 字节必然有效 + let capacity = unsafe { + let ptr = data.as_ptr().add(data.len() - 4) as *const [u8; 4]; + u32::from_le_bytes(ptr.read()) as usize + }; + + // 防止解压炸弹攻击 + if capacity > MAX_DECOMPRESSED_SIZE_BYTES { + return None; + } + + // 执行实际解压 + let mut decoder = GzDecoder::new(data); + let mut decompressed = Vec::with_capacity(capacity); + + decoder.read_to_end(&mut decompressed).ok()?; + + Some(decompressed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_too_short() { 
+ // 小于 20 字节的数据应该直接拒绝 + assert!(decompress_gzip(&[]).is_none()); + assert!(decompress_gzip(&[0x1f, 0x8b, 0x08]).is_none()); + assert!(decompress_gzip(&[0u8; 19]).is_none()); + } + + #[test] + fn test_invalid_magic() { + // 长度足够但魔数错误 + let mut data = vec![0u8; 20]; + data[0] = 0x00; // 错误的魔数 + data[1] = 0x8b; + data[2] = 0x08; + assert!(decompress_gzip(&data).is_none()); + + // 正确的第一字节,错误的第二字节 + data[0] = 0x1f; + data[1] = 0x00; + assert!(decompress_gzip(&data).is_none()); + + // 前两字节正确,压缩方法错误 + data[1] = 0x8b; + data[2] = 0x09; // 非 DEFLATE + assert!(decompress_gzip(&data).is_none()); + } + + #[test] + fn test_gzip_bomb_protection() { + // 构造声称解压后为 2MB 的假 gzip 数据 + let mut fake_gzip = vec![0x1f, 0x8b, 0x08]; // 正确的魔数 + fake_gzip.extend_from_slice(&[0u8; 14]); // 填充到 17 字节 + + // ISIZE 字段(最后 4 字节):2MB + let size_2mb = 2 * 1024 * 1024u32; + fake_gzip.extend_from_slice(&size_2mb.to_le_bytes()); + + assert_eq!(fake_gzip.len(), 21); // 17 + 4 + assert!(decompress_gzip(&fake_gzip).is_none()); + } + + #[test] + fn test_valid_gzip() { + // 使用标准库压缩一些数据 + use std::io::Write; + + use flate2::write::GzEncoder; + use flate2::Compression; + + let original = b"Hello, GZIP!"; + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(original).unwrap(); + let compressed = encoder.finish().unwrap(); + + // 验证:压缩数据 >= 20 字节 + assert!(compressed.len() >= 20); + + // 解压并验证 + let decompressed = decompress_gzip(&compressed).unwrap(); + assert_eq!(&decompressed, original); + } + + #[test] + fn test_empty_gzip() { + // 压缩空数据(最小有效 gzip) + use std::io::Write; + + use flate2::write::GzEncoder; + use flate2::Compression; + + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&[]).unwrap(); + let compressed = encoder.finish().unwrap(); + + // 验证:最小 gzip 文件 ~20 字节 + assert!(compressed.len() >= 20); + + let decompressed = decompress_gzip(&compressed).unwrap(); + assert_eq!(decompressed.len(), 0); + } +} diff --git 
a/crates/grpc-stream/src/decoder.rs b/crates/grpc-stream/src/decoder.rs new file mode 100644 index 0000000..31d53c9 --- /dev/null +++ b/crates/grpc-stream/src/decoder.rs @@ -0,0 +1,135 @@ +//! 流式消息解码器 + +use prost::Message; + +use crate::buffer::Buffer; +use crate::compression::decompress_gzip; +use crate::frame::RawMessage; + +/// gRPC 流式消息解码器 +/// +/// 处理增量数据块,解析完整的 Protobuf 消息。 +/// +/// # 示例 +/// +/// ```no_run +/// use grpc_stream_decoder::StreamDecoder; +/// use prost::Message; +/// +/// #[derive(Message, Default)] +/// struct MyMessage { +/// #[prost(string, tag = "1")] +/// content: String, +/// } +/// +/// let mut decoder = StreamDecoder::new(); +/// +/// // 使用默认处理器 +/// loop { +/// let chunk = receive_network_data(); +/// let messages: Vec = decoder.decode_default(&chunk); +/// +/// for msg in messages { +/// process(msg); +/// } +/// } +/// +/// // 使用自定义处理器 +/// let messages = decoder.decode(&chunk, |raw_msg| { +/// // 自定义解码逻辑 +/// match raw_msg.r#type { +/// 0 => MyMessage::decode(raw_msg.data).ok(), +/// _ => None, +/// } +/// }); +/// ``` +pub struct StreamDecoder { + buffer: Buffer, +} + +impl StreamDecoder { + /// 创建新的解码器 + #[inline] + pub fn new() -> Self { Self { buffer: Buffer::new() } } + + /// 使用自定义处理器解码数据块 + /// + /// # 类型参数 + /// - `T`: 目标消息类型 + /// - `F`: 处理函数,签名为 `Fn(RawMessage<'_>) -> Option` + /// + /// # 参数 + /// - `data`: 接收到的数据块 + /// - `processor`: 自定义处理函数,接收原始消息并返回解码结果 + /// + /// # 返回 + /// 解码成功的消息列表 + /// + /// # 示例 + /// + /// ```no_run + /// // 自定义处理:只接受未压缩消息 + /// let messages = decoder.decode(&data, |raw_msg| { + /// if raw_msg.r#type == 0 { + /// MyMessage::decode(raw_msg.data).ok() + /// } else { + /// None + /// } + /// }); + /// ``` + pub fn decode(&mut self, data: &[u8], processor: F) -> Vec + where F: Fn(RawMessage<'_>) -> Option { + self.buffer.extend_from_slice(data); + + let mut iter = (&self.buffer).into_iter(); + let exact_count = iter.len(); + let mut messages = Vec::with_capacity(exact_count); + + for raw_msg in 
&mut iter { + if let Some(msg) = processor(raw_msg) { + messages.push(msg); + } + } + + self.buffer.advance(iter.offset()); + messages + } + + /// 使用默认处理器解码数据块 + /// + /// 默认行为: + /// - 类型 0:直接解码 Protobuf 消息 + /// - 类型 1:先 gzip 解压,再解码 + /// - 其他类型:忽略 + /// + /// # 类型参数 + /// - `T`: 实现 `prost::Message + Default` 的消息类型 + /// + /// # 参数 + /// - `data`: 接收到的数据块 + /// + /// # 返回 + /// 解码成功的消息列表 + pub fn decode_default(&mut self, data: &[u8]) -> Vec { + self.decode(data, |raw_msg| match raw_msg.r#type { + 0 => Self::decode_message(raw_msg.data), + 1 => Self::decode_compressed_message(raw_msg.data), + _ => None, + }) + } + + /// 解码未压缩消息 + #[inline] + fn decode_message(data: &[u8]) -> Option { T::decode(data).ok() } + + /// 解码 gzip 压缩消息 + #[inline] + fn decode_compressed_message(data: &[u8]) -> Option { + let decompressed = decompress_gzip(data)?; + Self::decode_message(&decompressed) + } +} + +impl Default for StreamDecoder { + fn default() -> Self { Self::new() } +} diff --git a/crates/grpc-stream/src/frame.rs b/crates/grpc-stream/src/frame.rs new file mode 100644 index 0000000..2454c55 --- /dev/null +++ b/crates/grpc-stream/src/frame.rs @@ -0,0 +1,54 @@ +//! 
原始消息帧定义 + +/// gRPC 流式消息的原始帧 +/// +/// 包含帧头信息和消息数据的引用。 +/// +/// # 帧格式 +/// +/// ```text +/// +------+----------+----------------+ +/// | type | length | data | +/// | 1B | 4B (BE) | length bytes | +/// +------+----------+----------------+ +/// ``` +/// +/// - `type`: 消息类型 +/// - `0`: 未压缩 +/// - `1`: gzip 压缩 +/// - `length`: 消息体长度(大端序) +/// - `data`: 消息体数据 +/// +/// # 字段说明 +/// +/// - `r#type`: 帧类型标志(0=未压缩, 1=gzip) +/// - `data`: 消息体数据切片,其长度可通过 `data.len()` 获取 +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct RawMessage<'b> { + /// 消息类型(0=未压缩, 1=gzip) + pub r#type: u8, + + /// 消息体数据 + pub data: &'b [u8], +} + +impl RawMessage<'_> { + /// 计算该消息在缓冲区中占用的总字节数 + /// + /// 包含 5 字节帧头 + 消息体长度 + /// + /// # 示例 + /// + /// ``` + /// # use grpc_stream_decoder::RawMessage; + /// let msg = RawMessage { + /// r#type: 0, + /// data: &[1, 2, 3], + /// }; + /// assert_eq!(msg.total_size(), 8); // 5 + 3 + /// ``` + #[inline] + pub const fn total_size(&self) -> usize { + 5 + self.data.len() + } +} diff --git a/crates/grpc-stream/src/lib.rs b/crates/grpc-stream/src/lib.rs new file mode 100644 index 0000000..20ddfe5 --- /dev/null +++ b/crates/grpc-stream/src/lib.rs @@ -0,0 +1,46 @@ +//! gRPC 流式消息解码器 +//! +//! 提供高性能的 gRPC streaming 消息解析,支持 gzip 压缩。 +//! +//! # 示例 +//! +//! ```no_run +//! use grpc_stream_decoder::StreamDecoder; +//! use prost::Message; +//! +//! #[derive(Message, Default)] +//! struct MyMessage { +//! #[prost(string, tag = "1")] +//! content: String, +//! } +//! +//! let mut decoder = StreamDecoder::::new(); +//! +//! // 接收到的数据块 +//! let chunk = receive_data(); +//! let messages = decoder.decode(&chunk); +//! +//! for msg in messages { +//! println!("{}", msg.content); +//! } +//! 
``` + +#![allow(internal_features)] +#![feature(core_intrinsics)] + +mod frame; +mod buffer; +mod compression; +mod decoder; + +// 公开 API +pub use frame::RawMessage; +pub use buffer::Buffer; +pub use compression::decompress_gzip; +pub use decoder::StreamDecoder; + +// 常量 +/// 最大解压缩消息大小限制(4 MiB) +/// +/// 对齐gRPC标准的默认最大消息大小,防止内存滥用攻击 +pub const MAX_DECOMPRESSED_SIZE_BYTES: usize = 0x400000; // 4 * 1024 * 1024 diff --git a/crates/interned/Cargo.toml b/crates/interned/Cargo.toml new file mode 100644 index 0000000..3f2d534 --- /dev/null +++ b/crates/interned/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "interned" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +# HashMap 实现 - 启用 nightly 优化 +hashbrown = { version = "0.16", default-features = false, features = [ + "nightly", # 🚀 SIMD 优化、unstable APIs + "raw-entry", # 用于高级 HashMap 操作 + "inline-more", # 更激进的内联优化 + #"allocator-api2", # 自定义分配器支持 +] } + +# RwLock 实现 - 启用性能优化 +parking_lot = { version = "0.12", features = [ + "nightly", # 🚀 unstable 优化 + "hardware-lock-elision", # Intel TSX 硬件锁优化(如果可用) + #"send_guard", # 允许跨线程传递 MutexGuard +] } + +# 哈希算法 - 启用硬件加速 +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", # 运行时随机种子(安全) + #"nightly-arm-aes", # 🚀 ARM AES 指令优化 +] } + +manually_init.workspace = true + +serde = { version = "1.0", optional = true } + +[features] +default = ["serde"] +nightly = [] +serde = ["dep:serde", "hashbrown/serde", "ahash/serde"] + +#[profile.release] +#opt-level = 3 +#lto = "fat" # 全局 LTO +#codegen-units = 1 # 单编译单元,最大优化 +#panic = "abort" # 减小二进制体积 +#strip = true # 移除符号信息 + +#[profile.bench] +#inherits = "release" + +# Nightly 特性门控 +[package.metadata.docs.rs] +rustc-args = ["--cfg", "docsrs"] +all-features = true diff --git a/crates/interned/src/arc_str.rs b/crates/interned/src/arc_str.rs new file mode 100644 index 0000000..cea3a0a --- /dev/null 
+++ b/crates/interned/src/arc_str.rs @@ -0,0 +1,1119 @@ +//! 引用计数的不可变字符串,支持全局字符串池复用 +//! +//! # 核心设计理念 +//! +//! `ArcStr` 通过全局字符串池实现内存去重,相同内容的字符串共享同一份内存。 +//! 这在大量重复字符串的场景下能显著降低内存使用,同时保持字符串操作的高性能。 +//! +//! # 架构概览 +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ 用户 API 层 │ +//! │ ArcStr::new() │ as_str() │ clone() │ Drop │ PartialEq... │ +//! ├─────────────────────────────────────────────────────────────────┤ +//! │ 全局字符串池 │ +//! │ RwLock> │ +//! │ 双重检查锁定 + 原子引用计数 │ +//! ├─────────────────────────────────────────────────────────────────┤ +//! │ 底层内存布局 │ +//! │ [hash:u64][count:AtomicUsize][len:usize][string_data...] │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # 性能特征 +//! +//! | 操作 | 时间复杂度 | 说明 | +//! |------|-----------|------| +//! | new() - 首次 | O(1) + 池插入 | 堆分配 + HashMap 插入 | +//! | new() - 命中 | O(1) | HashMap 查找 + 原子递增 | +//! | clone() | O(1) | 仅原子递增 | +//! | drop() | O(1) | 使用预存哈希快速删除 | +//! 
| as_str() | O(1) | 直接内存访问 | + +use core::{ + alloc::Layout, + borrow::Borrow, + cmp::Ordering, + fmt, + hash::{BuildHasherDefault, Hash, Hasher}, + hint, + marker::PhantomData, + ptr::NonNull, + str, + sync::atomic::{ + AtomicUsize, + Ordering::{Relaxed, Release}, + }, +}; +use hashbrown::{Equivalent, HashMap}; +use manually_init::ManuallyInit; +use parking_lot::RwLock; + +// ═══════════════════════════════════════════════════════════════════════════ +// 第一层:公共API与核心接口 +// ═══════════════════════════════════════════════════════════════════════════ + +/// 引用计数的不可变字符串,支持全局字符串池复用 +/// +/// # 设计目标 +/// +/// - **内存去重**:相同内容的字符串共享同一内存地址 +/// - **零拷贝克隆**:clone() 只涉及原子递增操作 +/// - **线程安全**:支持多线程环境下的安全使用 +/// - **高性能查找**:使用预计算哈希值优化池查找 +/// +/// # 使用示例 +/// +/// ```rust +/// use interned::ArcStr; +/// +/// let s1 = ArcStr::new("hello"); +/// let s2 = ArcStr::new("hello"); +/// +/// // 相同内容的字符串共享同一内存 +/// assert_eq!(s1.as_ptr(), s2.as_ptr()); +/// assert_eq!(s1.ref_count(), 2); +/// +/// // 零成本的字符串访问 +/// println!("{}", s1.as_str()); // "hello" +/// ``` +/// +/// # 内存安全 +/// +/// `ArcStr` 内部使用原子引用计数确保内存安全,无需担心悬挂指针或数据竞争。 +/// 当最后一个引用被释放时,字符串将自动从全局池中移除并释放内存。 +#[repr(transparent)] +pub struct ArcStr { + /// 指向 `ArcStrInner` 的非空指针 + /// + /// # 不变量 + /// - 指针始终有效,指向正确初始化的 `ArcStrInner` + /// - 引用计数至少为 1(在 drop 开始前) + /// - 字符串数据始终是有效的 UTF-8 + ptr: NonNull, + + /// 零大小标记,确保 `ArcStr` 拥有数据的所有权语义 + _marker: PhantomData, +} + +// SAFETY: ArcStr 使用原子引用计数,可以安全地跨线程传递和访问 +unsafe impl Send for ArcStr {} +unsafe impl Sync for ArcStr {} + +impl ArcStr { + /// 创建或复用字符串实例 + /// + /// 如果全局池中已存在相同内容的字符串,则复用现有实例并增加引用计数; + /// 否则创建新实例并加入池中。 + /// + /// # 并发策略 + /// + /// 使用双重检查锁定模式来平衡性能和正确性: + /// 1. **读锁快速路径**:大多数情况下只需要读锁即可找到现有字符串 + /// 2. **写锁创建路径**:仅在确实需要创建新字符串时获取写锁 + /// 3. 
**双重验证**:获取写锁后再次检查,防止并发创建重复实例 + /// + /// # 性能特征 + /// + /// - **池命中**:O(1) HashMap 查找 + 原子递增 + /// - **池缺失**:O(1) 内存分配 + O(1) HashMap 插入 + /// - **哈希计算**:使用 ahash 的高性能哈希算法 + /// + /// # Examples + /// + /// ```rust + /// let s1 = ArcStr::new("shared_content"); + /// let s2 = ArcStr::new("shared_content"); // 复用 s1 的内存 + /// assert_eq!(s1.as_ptr(), s2.as_ptr()); + /// ``` + pub fn new>(s: S) -> Self { + let string = s.as_ref(); + + // 阶段 0:预计算内容哈希 + // + // 这个哈希值在整个生命周期中会被多次使用: + // - 池查找时作为 HashMap 的键 + // - 存储在 ArcStrInner 中用于后续 drop 优化 + let hash = CONTENT_HASHER.hash_one(string); + + // ===== 阶段 1:读锁快速路径 ===== + // 大部分情况下字符串已经在池中,这个路径是最常见的 + { + let pool = ARC_STR_POOL.read(); + if let Some(existing) = Self::try_find_existing(&pool, hash, string) { + return existing; + } + // 读锁自动释放 + } + + // ===== 阶段 2:写锁创建路径 ===== + // 进入这里说明需要创建新的字符串实例 + let mut pool = ARC_STR_POOL.write(); + + // 双重检查:在获取写锁的过程中,其他线程可能已经创建了相同的字符串 + if let Some(existing) = Self::try_find_existing(&pool, hash, string) { + return existing; + } + + // 确认需要创建新实例:分配内存并初始化 + let layout = ArcStrInner::layout_for_string(string.len()); + + // SAFETY: layout_for_string 确保布局有效且大小合理 + let ptr = unsafe { + let alloc = alloc::alloc::alloc(layout) as *mut ArcStrInner; + + if alloc.is_null() { + hint::cold_path(); + alloc::alloc::handle_alloc_error(layout); + } + + let ptr = NonNull::new_unchecked(alloc); + ArcStrInner::write_with_string(ptr, string, hash); + ptr + }; + + // 将新创建的字符串加入全局池 + // 使用 from_key_hashed_nocheck 避免重复计算哈希 + pool.raw_entry_mut().from_key_hashed_nocheck(hash, string).insert(ThreadSafePtr(ptr), ()); + + Self { ptr, _marker: PhantomData } + } + + /// 获取字符串切片(零成本操作) + /// + /// 直接访问底层字符串数据,无任何额外开销。 + /// + /// # 性能 + /// + /// 这是一个 `const fn`,在编译时就能确定偏移量, + /// 运行时仅需要一次内存解引用。 + #[inline(always)] + pub const fn as_str(&self) -> &str { + // SAFETY: ptr 在 ArcStr 生命周期内始终指向有效的 ArcStrInner, + // 且字符串数据保证是有效的 UTF-8 + unsafe { self.ptr.as_ref().as_str() } + } + + /// 获取字符串的字节切片 + /// + /// 
提供对底层字节数据的直接访问。 + #[inline(always)] + pub const fn as_bytes(&self) -> &[u8] { + // SAFETY: ptr 始终指向有效的 ArcStrInner + unsafe { self.ptr.as_ref().as_bytes() } + } + + /// 获取字符串长度(字节数) + #[inline(always)] + pub const fn len(&self) -> usize { + // SAFETY: ptr 始终指向有效的 ArcStrInner + unsafe { self.ptr.as_ref().string_len } + } + + /// 检查字符串是否为空 + #[inline(always)] + pub const fn is_empty(&self) -> bool { self.len() == 0 } + + /// 获取当前引用计数 + /// + /// 注意:由于并发访问,返回的值可能在返回后立即发生变化。 + /// 此方法主要用于调试和测试。 + #[inline(always)] + pub fn ref_count(&self) -> usize { + // SAFETY: ptr 始终指向有效的 ArcStrInner + unsafe { self.ptr.as_ref().strong_count() } + } + + /// 获取字符串数据的内存地址(用于调试和测试) + /// + /// 返回字符串内容的起始地址,可用于验证字符串是否共享内存。 + #[inline(always)] + pub const fn as_ptr(&self) -> *const u8 { + // SAFETY: ptr 始终指向有效的 ArcStrInner + unsafe { self.ptr.as_ref().string_ptr() } + } + + /// 内部辅助函数:在池中查找已存在的字符串 + /// + /// 这个函数被提取出来以消除读锁路径和写锁路径中的重复代码。 + /// 使用 hashbrown 的优化API来避免重复哈希计算。 + /// + /// # 参数 + /// + /// - `pool`: 字符串池的引用 + /// - `hash`: 预计算的字符串哈希值 + /// - `string`: 要查找的字符串内容 + /// + /// # 返回值 + /// + /// 如果找到匹配的字符串,返回增加引用计数后的 `ArcStr`;否则返回 `None`。 + #[inline(always)] + fn try_find_existing(pool: &PtrMap, hash: u64, string: &str) -> Option { + // 使用 hashbrown 的 from_key_hashed_nocheck API + // 这利用了 Equivalent trait 来进行高效比较 + let (ptr_ref, _) = pool.raw_entry().from_key_hashed_nocheck(hash, string)?; + let ptr = ptr_ref.0; + + // 找到匹配的字符串,增加其引用计数 + // SAFETY: 池中的指针始终有效,且引用计数操作是原子的 + unsafe { ptr.as_ref().inc_strong() }; + + Some(Self { ptr, _marker: PhantomData }) + } +} + +impl Clone for ArcStr { + /// 克隆字符串引用(仅增加引用计数) + /// + /// 这是一个极其轻量的操作,只涉及一次原子递增。 + /// 不会复制字符串内容,新的 `ArcStr` 与原实例共享相同的底层内存。 + /// + /// # 性能 + /// + /// 时间复杂度:O(1) - 单次原子操作 + /// 空间复杂度:O(1) - 无额外内存分配 + #[inline] + fn clone(&self) -> Self { + // SAFETY: ptr 在当前 ArcStr 生命周期内有效 + unsafe { self.ptr.as_ref().inc_strong() } + Self { ptr: self.ptr, _marker: PhantomData } + } +} + +impl Drop for ArcStr { + /// 释放字符串引用 + /// + 
/// 递减引用计数,如果这是最后一个引用,则从全局池中移除并释放内存。 + /// + /// # 并发处理 + /// + /// 由于多个线程可能同时释放同一字符串的引用,这里使用了谨慎的双重检查: + /// 1. 原子递减引用计数 + /// 2. 如果计数变为0,获取池的写锁 + /// 3. 再次检查引用计数(防止并发的clone操作) + /// 4. 确认后从池中移除并释放内存 + /// + /// # 性能优化 + /// + /// 使用预存储的哈希值进行 O(1) 的池查找和删除,避免重新计算哈希。 + fn drop(&mut self) { + // SAFETY: ptr 在 drop 开始时仍然有效 + unsafe { + let inner = self.ptr.as_ref(); + + // 原子递减引用计数 + if !inner.dec_strong() { + // 不是最后一个引用,直接返回 + return; + } + + // 这是最后一个引用,需要清理资源 + let mut pool = ARC_STR_POOL.write(); + + // 双重检查引用计数 + // 在获取写锁期间,其他线程可能clone了这个字符串 + if inner.strong_count() != 0 { + return; + } + + // 确认是最后一个引用,执行清理 + let hash = inner.hash; + let entry = pool.raw_entry_mut().from_hash(hash, |k| { + // 使用指针相等比较,这是绝对的 O(1) 操作 + k.0 == self.ptr + }); + + if let hashbrown::hash_map::RawEntryMut::Occupied(e) = entry { + e.remove(); + } + + // 释放底层内存 + let layout = ArcStrInner::layout_for_string_unchecked(inner.string_len); + alloc::alloc::dealloc(self.ptr.cast().as_ptr(), layout); + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 第二层:标准库集成 +// ═══════════════════════════════════════════════════════════════════════════ + +/// # 基础 Trait 实现 +/// +/// 这些实现确保 `ArcStr` 能够与 Rust 的标准库类型无缝集成, +/// 提供符合直觉的比较、格式化和访问接口。 + +impl PartialEq for ArcStr { + /// 基于指针的快速相等比较 + /// + /// # 优化原理 + /// + /// 由于字符串池保证相同内容的字符串具有相同的内存地址, + /// 我们可以通过比较指针来快速判断字符串是否相等, + /// 避免逐字节的内容比较。 + /// + /// 这使得相等比较成为 O(1) 操作,而不是 O(n)。 + #[inline] + fn eq(&self, other: &Self) -> bool { self.ptr == other.ptr } +} + +impl Eq for ArcStr {} + +impl PartialOrd for ArcStr { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for ArcStr { + /// 基于字符串内容的字典序比较 + /// + /// 注意:这里必须比较内容而不是指针,因为指针地址与字典序无关。 + #[inline] + fn cmp(&self, other: &Self) -> Ordering { self.as_str().cmp(other.as_str()) } +} + +impl Hash for ArcStr { + /// 基于字符串内容的哈希 + /// + /// 虽然内部存储了预计算的哈希值,但这里重新计算以确保 + /// 与 `&str` 和 `String` 的哈希值保持一致。 + #[inline] + 
fn hash(&self, state: &mut H) { state.write_str(self.as_str()) } +} + +impl fmt::Display for ArcStr { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(self.as_str(), f) } +} + +impl fmt::Debug for ArcStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(self.as_str(), f) } +} + +impl const AsRef for ArcStr { + #[inline] + fn as_ref(&self) -> &str { self.as_str() } +} + +impl const AsRef<[u8]> for ArcStr { + #[inline] + fn as_ref(&self) -> &[u8] { self.as_bytes() } +} + +impl const Borrow for ArcStr { + #[inline] + fn borrow(&self) -> &str { self.as_str() } +} + +impl const core::ops::Deref for ArcStr { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { self.as_str() } +} + +/// # 与其他字符串类型的互操作性 +/// +/// 这些实现使得 `ArcStr` 可以与 Rust 生态系统中的各种字符串类型 +/// 进行直接比较,提供良好的开发体验。 + +impl const PartialEq for ArcStr { + #[inline] + fn eq(&self, other: &str) -> bool { self.as_str() == other } +} + +impl const PartialEq<&str> for ArcStr { + #[inline] + fn eq(&self, other: &&str) -> bool { self.as_str() == *other } +} + +impl const PartialEq for str { + #[inline] + fn eq(&self, other: &ArcStr) -> bool { self == other.as_str() } +} + +impl const PartialEq for &str { + #[inline] + fn eq(&self, other: &ArcStr) -> bool { *self == other.as_str() } +} + +impl const PartialEq for ArcStr { + #[inline] + fn eq(&self, other: &String) -> bool { self.as_str() == other.as_str() } +} + +impl const PartialEq for String { + #[inline] + fn eq(&self, other: &ArcStr) -> bool { self.as_str() == other.as_str() } +} + +impl PartialOrd for ArcStr { + #[inline] + fn partial_cmp(&self, other: &str) -> Option { Some(self.as_str().cmp(other)) } +} + +impl PartialOrd for ArcStr { + #[inline] + fn partial_cmp(&self, other: &String) -> Option { + Some(self.as_str().cmp(other.as_str())) + } +} + +/// # 类型转换实现 +/// +/// 提供从各种字符串类型到 `ArcStr` 的便捷转换, +/// 以及从 `ArcStr` 到其他类型的转换。 + +impl<'a> From<&'a str> for ArcStr { + #[inline] + fn 
from(s: &'a str) -> Self { Self::new(s) } +} + +impl<'a> From<&'a String> for ArcStr { + #[inline] + fn from(s: &'a String) -> Self { Self::new(s) } +} + +impl From for ArcStr { + #[inline] + fn from(s: String) -> Self { Self::new(s) } +} + +impl<'a> From> for ArcStr { + #[inline] + fn from(cow: alloc::borrow::Cow<'a, str>) -> Self { Self::new(cow) } +} + +impl From> for ArcStr { + #[inline] + fn from(s: alloc::boxed::Box) -> Self { Self::new(s) } +} + +impl From for String { + #[inline] + fn from(s: ArcStr) -> Self { s.as_str().to_owned() } +} + +impl From for alloc::boxed::Box { + #[inline] + fn from(s: ArcStr) -> Self { s.as_str().into() } +} + +impl str::FromStr for ArcStr { + type Err = core::convert::Infallible; + + #[inline] + fn from_str(s: &str) -> Result { Ok(Self::new(s)) } +} + +/// # Serde 序列化支持 +/// +/// 条件编译的 Serde 支持,使 `ArcStr` 可以参与序列化/反序列化流程。 +/// 序列化时输出字符串内容,反序列化时重新建立池化引用。 +#[cfg(feature = "serde")] +mod serde_impls { + use super::*; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + impl Serialize for ArcStr { + #[inline] + fn serialize(&self, serializer: S) -> Result + where S: Serializer { + self.as_str().serialize(serializer) + } + } + + impl<'de> Deserialize<'de> for ArcStr { + #[inline] + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + String::deserialize(deserializer).map(ArcStr::new) + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 第三层:核心实现机制 +// ═══════════════════════════════════════════════════════════════════════════ + +/// # 内存布局与数据结构设计 +/// +/// 这个模块包含了 `ArcStr` 的底层数据结构定义和内存布局管理。 +/// 理解这部分有助于深入了解性能优化的原理。 + +/// 字符串内容的内部表示(DST 头部) +/// +/// # 内存布局设计 +/// +/// 使用 `#[repr(C)]` 确保内存布局稳定,字符串数据紧跟在结构体后面: +/// +/// ```text +/// 64位系统内存布局: +/// ┌────────────────────┬──────────────────────────────────────────┐ +/// │ 字段 │ 大小与对齐 │ +/// ├────────────────────┼──────────────────────────────────────────┤ +/// │ hash: u64 │ 8字节, 8字节对齐 (offset: 0) │ +/// │ 
count: AtomicUsize │ 8字节, 8字节对齐 (offset: 8) │ +/// │ string_len: usize │ 8字节, 8字节对齐 (offset: 16) │ +/// ├────────────────────┼──────────────────────────────────────────┤ +/// │ [字符串数据] │ string_len字节, 1字节对齐 (offset: 24) │ +/// └────────────────────┴──────────────────────────────────────────┘ +/// 总头部大小:24字节 +/// +/// 32位系统内存布局: +/// ┌────────────────────┬──────────────────────────────────────────┐ +/// │ hash: u64 │ 8字节, 8字节对齐 (offset: 0) │ +/// │ count: AtomicUsize │ 4字节, 4字节对齐 (offset: 8) │ +/// │ string_len: usize │ 4字节, 4字节对齐 (offset: 12) │ +/// ├────────────────────┼──────────────────────────────────────────┤ +/// │ [字符串数据] │ string_len字节, 1字节对齐 (offset: 16) │ +/// └────────────────────┴──────────────────────────────────────────┘ +/// 总头部大小:16字节 +/// ``` +/// +/// # 设计考量 +/// +/// 1. **哈希值前置**:将 `hash` 放在首位确保在32位系统上的正确对齐 +/// 2. **原子计数器**:使用 `AtomicUsize` 保证并发安全的引用计数 +/// 3. **长度缓存**:预存字符串长度避免重复计算 +/// 4. **DST布局**:字符串数据直接跟随结构体,减少间接访问 +#[repr(C)] +struct ArcStrInner { + /// 预计算的内容哈希值 + /// + /// 这个哈希值在多个场景中被复用: + /// - 全局池的HashMap键 + /// - Drop时的快速查找 + /// - 避免重复哈希计算的性能优化 + hash: u64, + + /// 原子引用计数 + /// + /// 使用原生原子类型确保最佳性能。 + /// 计数范围:[1, isize::MAX],超出时触发abort。 + count: AtomicUsize, + + /// 字符串的字节长度(UTF-8编码) + /// + /// 预存长度避免在每次访问时扫描字符串。 + /// 不包含NUL终止符。 + string_len: usize, + // 注意:字符串数据紧跟在这个结构体后面, + // 通过 layout_for_string() 计算的布局来确保正确的内存分配 +} + +impl ArcStrInner { + /// 字符串长度的上限 + /// + /// 计算公式:`isize::MAX - sizeof(ArcStrInner)` + /// 这确保总分配大小不会溢出有符号整数范围。 + const MAX_LEN: usize = isize::MAX as usize - core::mem::size_of::(); + + /// 获取字符串数据的起始地址 + /// + /// # Safety + /// + /// - `self` 必须是指向有效 `ArcStrInner` 的指针 + /// - 必须确保字符串数据已经被正确初始化 + /// - 调用者负责确保返回的指针在使用期间保持有效 + #[inline(always)] + const unsafe fn string_ptr(&self) -> *const u8 { + // SAFETY: repr(C) 保证字符串数据位于结构体末尾的固定偏移处 + (self as *const Self).add(1).cast() + } + + /// 获取字符串的字节切片 + /// + /// # Safety + /// + /// - `self` 必须是指向有效 `ArcStrInner` 的指针 + /// - 字符串数据必须已经被正确初始化 + /// - `string_len` 
必须准确反映实际字符串长度 + /// - 字符串数据必须在返回的切片生命周期内保持有效 + #[inline(always)] + const unsafe fn as_bytes(&self) -> &[u8] { + let ptr = self.string_ptr(); + // SAFETY: 调用者保证 ptr 指向有效的 string_len 字节数据 + core::slice::from_raw_parts(ptr, self.string_len) + } + + /// 获取字符串切片引用 + /// + /// # Safety + /// + /// - `self` 必须是指向有效 `ArcStrInner` 的指针 + /// - 字符串数据必须是有效的 UTF-8 编码 + /// - `string_len` 必须准确反映实际字符串长度 + /// - 字符串数据必须在返回的切片生命周期内保持有效 + #[inline(always)] + const unsafe fn as_str(&self) -> &str { + // SAFETY: 调用者保证字符串数据是有效的 UTF-8 + core::str::from_utf8_unchecked(self.as_bytes()) + } + + /// 计算存储指定长度字符串所需的内存布局 + /// + /// 这个函数计算出正确的内存大小和对齐要求, + /// 确保结构体和字符串数据都能正确对齐。 + /// + /// # Panics + /// + /// 如果 `string_len > Self::MAX_LEN`,函数会panic。 + /// 这是为了防止整数溢出和无效的内存布局。 + /// + /// # Examples + /// + /// ```rust + /// let layout = ArcStrInner::layout_for_string(5); // "hello" + /// assert!(layout.size() >= 24 + 5); // 64位系统 + /// ``` + fn layout_for_string(string_len: usize) -> Layout { + if string_len > Self::MAX_LEN { + hint::cold_path(); + panic!("字符串过长: {} 字节 (最大支持: {})", string_len, Self::MAX_LEN); + } + + // SAFETY: 长度检查通过,布局计算是安全的 + unsafe { Self::layout_for_string_unchecked(string_len) } + } + + /// 计算存储指定长度字符串所需的内存布局(不检查长度) + /// + /// # Safety + /// + /// 调用者必须保证 `string_len <= Self::MAX_LEN` + const unsafe fn layout_for_string_unchecked(string_len: usize) -> Layout { + let header = Layout::new::(); + let string_data = Layout::from_size_align_unchecked(string_len, 1); + // SAFETY: 长度已经过检查,布局计算不会溢出 + let (combined, _offset) = header.extend(string_data).unwrap_unchecked(); + combined.pad_to_align() + } + + /// 在指定内存位置初始化 `ArcStrInner` 并写入字符串数据 + /// + /// 这是一个低级函数,负责设置完整的DST结构: + /// 1. 初始化头部字段 + /// 2. 
复制字符串数据到紧邻的内存 + /// + /// # Safety + /// + /// - `ptr` 必须指向通过 `layout_for_string(string.len())` 分配的有效内存 + /// - 内存必须正确对齐且大小足够 + /// - `string` 必须是有效的 UTF-8 字符串 + /// - 调用者负责最终释放这块内存 + /// - 在调用此函数后,调用者必须确保引用计数正确管理 + const unsafe fn write_with_string(ptr: NonNull, string: &str, hash: u64) { + let inner = ptr.as_ptr(); + + // 第一步:初始化头部结构体 + // SAFETY: ptr 指向有效的已分配内存,大小足够容纳 Self + core::ptr::write( + inner, + Self { hash, count: AtomicUsize::new(1), string_len: string.len() }, + ); + + // 第二步:复制字符串数据到紧邻头部后的内存 + // SAFETY: + // - string_ptr() 计算出的地址位于已分配内存范围内 + // - string.len() 与分配时的长度一致 + // - string.as_ptr() 指向有效的 UTF-8 数据 + let string_ptr = (*inner).string_ptr() as *mut u8; + core::ptr::copy_nonoverlapping(string.as_ptr(), string_ptr, string.len()); + } + + /// 原子递增引用计数 + /// + /// # 溢出处理 + /// + /// 如果引用计数超过 `isize::MAX`,函数会立即abort程序。 + /// 这是一个极端情况,在正常使用中几乎不可能发生。 + /// + /// # Safety + /// + /// - `self` 必须指向有效的 `ArcStrInner` + /// - 当前引用计数必须至少为 1(即存在有效引用) + #[inline] + unsafe fn inc_strong(&self) { + let old_count = self.count.fetch_add(1, Relaxed); + + // 防止引用计数溢出 - 这是一个安全检查 + if old_count > isize::MAX as usize { + hint::cold_path(); + // 溢出是内存安全问题,必须立即终止程序 + core::intrinsics::abort(); + } + } + + /// 原子递减引用计数 + /// + /// 使用 Release 内存序确保所有之前的修改对后续的操作可见。 + /// 这对于安全的内存回收至关重要。 + /// + /// # Safety + /// + /// - `self` 必须指向有效的 `ArcStrInner` + /// - 当前引用计数必须至少为 1 + /// + /// # 返回值 + /// + /// 如果这是最后一个引用(计数变为 0),返回 `true`;否则返回 `false`。 + #[inline] + unsafe fn dec_strong(&self) -> bool { + // Release ordering: 确保之前的所有修改对后续的内存释放操作可见 + self.count.fetch_sub(1, Release) == 1 + } + + /// 获取当前引用计数的快照 + /// + /// 注意:由于并发性,返回值可能在返回后立即过时。 + /// 此方法主要用于调试和测试目的。 + #[inline] + fn strong_count(&self) -> usize { self.count.load(Relaxed) } +} + +/// # 全局字符串池的设计与实现 +/// +/// 全局池是整个系统的核心,负责去重和生命周期管理。 + +/// 线程安全的内部指针包装 +/// +/// 这个类型解决了在 `HashMap` 中存储 `NonNull` 的问题: +/// - 提供必要的 trait 实现(Hash, PartialEq, Send, Sync) +/// - 封装指针的线程安全语义 +/// - 支持基于内容的查找(通过 Equivalent trait) +/// +/// 
# 线程安全性 +/// +/// 虽然包装了裸指针,但 `ThreadSafePtr` 是线程安全的,因为: +/// - 指向的 `ArcStrInner` 是不可变的(除了原子引用计数) +/// - 引用计数使用原子操作 +/// - 生命周期由全局池管理,确保指针有效性 +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +struct ThreadSafePtr(NonNull); + +// SAFETY: ArcStrInner 内容不可变且使用原子引用计数,可以安全地跨线程访问 +unsafe impl Send for ThreadSafePtr {} +unsafe impl Sync for ThreadSafePtr {} + +impl const core::ops::Deref for ThreadSafePtr { + type Target = NonNull; + + #[inline(always)] + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl Hash for ThreadSafePtr { + /// 使用预存储的哈希值 + /// + /// 这是一个关键优化:我们不重新计算字符串内容的哈希, + /// 而是直接使用存储在 `ArcStrInner` 中的预计算值。 + /// 配合 `IdentityHasher` 使用,避免任何额外的哈希计算。 + #[inline] + fn hash(&self, state: &mut H) { + // SAFETY: ThreadSafePtr 保证指针在池生命周期内始终有效 + unsafe { + let inner = self.0.as_ref(); + state.write_u64(inner.hash) + } + } +} + +impl PartialEq for ThreadSafePtr { + /// 基于指针相等的比较 + /// + /// 这是池去重机制的核心:只有指向同一内存地址的指针 + /// 才被认为是"相同"的池条目。内容相同但地址不同的字符串 + /// 在池中是不应该同时存在的。 + #[inline] + fn eq(&self, other: &Self) -> bool { self.0 == other.0 } +} + +impl Eq for ThreadSafePtr {} + +impl Equivalent for str { + /// 支持用 `&str` 在 `HashSet` 中查找 + /// + /// 这个实现使得我们可以用字符串内容来查找池中的条目, + /// 而不需要先构造一个 `ThreadSafePtr`。 + /// + /// # 性能优化 + /// + /// 先比较字符串长度(单个 usize 比较),只有长度相等时 + /// 才进行内容比较(潜在的 memcmp)。这避免了在长度不等时 + /// 构造 fat pointer 的开销。 + #[inline] + fn equivalent(&self, key: &ThreadSafePtr) -> bool { + // SAFETY: 池中的 ThreadSafePtr 保证指向有效的 ArcStrInner + unsafe { + let inner = key.0.as_ref(); + + // 优化:先比较长度(O(1)),避免不必要的内容比较 + if inner.string_len != self.len() { + return false; + } + + // 长度相等时进行内容比较 + inner.as_str() == self + } + } +} + +/// # 哈希算法选择与池类型定义 + +/// 透传哈希器,用于全局池内部 +/// +/// 由于我们在 `ArcStrInner` 中预存了哈希值,池内部的 HashMap +/// 不需要重新计算哈希。`IdentityHasher` 直接透传 u64 值。 +/// +/// # 工作原理 +/// +/// 1. `ThreadSafePtr::hash()` 调用 `hasher.write_u64(stored_hash)` +/// 2. `IdentityHasher::write_u64()` 直接存储这个值 +/// 3. `IdentityHasher::finish()` 返回存储的值 +/// 4. 
HashMap 使用这个哈希值进行桶分配和查找 +/// +/// 这避免了重复的哈希计算,将池操作的哈希开销降到最低。 +#[derive(Default, Clone, Copy)] +struct IdentityHasher(u64); + +impl Hasher for IdentityHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("IdentityHasher 只应该用于 write_u64"); + } + + #[inline(always)] + fn write_u64(&mut self, id: u64) { self.0 = id; } + + #[inline(always)] + fn finish(&self) -> u64 { self.0 } +} + +/// 池的类型别名,简化代码 +type PoolHasher = BuildHasherDefault; +type PtrMap = HashMap; + +/// 内容哈希计算器 +/// +/// 使用 ahash 的高性能随机哈希算法来计算字符串内容的哈希值。 +/// 这个哈希值会被存储在 `ArcStrInner` 中,用于整个生命周期。 +/// +/// # 为什么使用 ahash? +/// +/// - 高性能:比标准库的 DefaultHasher 更快 +/// - 安全性:抗哈希洪水攻击 +/// - 质量:分布均匀,减少哈希冲突 +static CONTENT_HASHER: ManuallyInit = ManuallyInit::new(); + +/// 全局字符串池 +/// +/// 使用 `RwLock` 实现高并发的字符串池: +/// - **读锁**:多个线程可以同时查找现有字符串 +/// - **写锁**:创建新字符串时需要独占访问 +/// - **容量预分配**:避免初期的频繁扩容 +/// +/// # 并发模式 +/// +/// ```text +/// 并发读取(常见情况): +/// Thread A: read_lock() -> 查找 "hello" -> 找到 -> 返回 +/// Thread B: read_lock() -> 查找 "world" -> 找到 -> 返回 +/// Thread C: read_lock() -> 查找 "hello" -> 找到 -> 返回 +/// +/// 并发写入(偶尔发生): +/// Thread D: write_lock() -> 查找 "new" -> 未找到 -> 创建 -> 插入 -> 返回 +/// ``` +static ARC_STR_POOL: ManuallyInit> = ManuallyInit::new(); + +/// 初始化全局字符串池 +/// +/// 这个函数必须在使用 `ArcStr` 之前调用,通常在程序启动时完成。 +/// 初始化过程包括: +/// 1. 创建内容哈希计算器 +/// 2. 
创建空的字符串池(预分配128个条目的容量) +/// +/// # 线程安全性 +/// +/// 虽然这个函数本身不是线程安全的,但它应该在单线程环境下 +/// (如 main 函数开始或静态初始化时)被调用一次。 +#[inline(always)] +pub(crate) fn __init() { + CONTENT_HASHER.init(ahash::RandomState::new()); + ARC_STR_POOL.init(RwLock::new(PtrMap::with_capacity_and_hasher(128, PoolHasher::default()))); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 第四层:性能优化实现 +// ═══════════════════════════════════════════════════════════════════════════ + +/// # 内存管理优化策略 +/// +/// 这个模块包含了各种底层的性能优化实现, +/// 包括内存布局计算、分配策略和并发优化。 + +// (这里是性能关键的内部函数实现,已经在上面的代码中体现了) + +/// # 并发控制优化 +/// +/// 双重检查锁定模式的详细实现分析: +/// +/// ```text +/// 时间线示例: +/// T1: Thread A 调用 ArcStr::new("test") +/// T2: Thread A 获取读锁,查找池,未找到 +/// T3: Thread A 释放读锁 +/// T4: Thread B 调用 ArcStr::new("test") +/// T5: Thread B 获取读锁,查找池,未找到 +/// T6: Thread B 释放读锁 +/// T7: Thread A 获取写锁 +/// T8: Thread A 再次查找(双重检查),确认未找到 +/// T9: Thread A 创建新实例,插入池 +/// T10: Thread A 释放写锁 +/// T11: Thread B 等待写锁... +/// T12: Thread B 获取写锁 +/// T13: Thread B 再次查找(双重检查),找到! 
+/// T14: Thread B 增加引用计数,释放写锁 +/// ``` + +// ═══════════════════════════════════════════════════════════════════════════ +// 第五层:测试与工具 +// ═══════════════════════════════════════════════════════════════════════════ + +/// # 测试辅助工具 +/// +/// 这些函数仅在测试环境中可用,用于检查池的内部状态 +/// 和进行隔离测试。 + +#[cfg(test)] +pub(crate) fn pool_stats() -> (usize, usize) { + let pool = ARC_STR_POOL.read(); + (pool.len(), pool.capacity()) +} + +#[cfg(test)] +pub(crate) fn clear_pool_for_test() { + use std::{thread, time::Duration}; + // 短暂等待确保其他线程完成操作 + thread::sleep(Duration::from_millis(10)); + ARC_STR_POOL.write().clear(); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{thread, time::Duration}; + + /// 运行隔离的测试,确保测试间不会相互影响 + fn run_isolated_test(f: F) { + clear_pool_for_test(); + f(); + clear_pool_for_test(); + } + + #[test] + fn test_basic_functionality() { + run_isolated_test(|| { + let s1 = ArcStr::new("hello"); + let s2 = ArcStr::new("hello"); + let s3 = ArcStr::new("world"); + + // 验证相等性和指针共享 + assert_eq!(s1, s2); + assert_ne!(s1, s3); + assert_eq!(s1.ptr, s2.ptr); // 相同内容共享内存 + assert_ne!(s1.ptr, s3.ptr); // 不同内容不同内存 + + // 验证基础操作 + assert_eq!(s1.as_str(), "hello"); + assert_eq!(s1.len(), 5); + assert!(!s1.is_empty()); + + // 验证池状态 + let (count, _) = pool_stats(); + assert_eq!(count, 2); // "hello" 和 "world" + }); + } + + #[test] + fn test_reference_counting() { + run_isolated_test(|| { + let s1 = ArcStr::new("test"); + assert_eq!(s1.ref_count(), 1); + + let s2 = s1.clone(); + assert_eq!(s1.ref_count(), 2); + assert_eq!(s2.ref_count(), 2); + assert_eq!(s1.ptr, s2.ptr); + + drop(s2); + assert_eq!(s1.ref_count(), 1); + + drop(s1); + // 等待 drop 完成 + thread::sleep(Duration::from_millis(5)); + assert_eq!(pool_stats().0, 0); + }); + } + + #[test] + fn test_pool_reuse() { + run_isolated_test(|| { + let s1 = ArcStr::new("reuse_test"); + let s2 = ArcStr::new("reuse_test"); + + assert_eq!(s1.ptr, s2.ptr); + assert_eq!(s1.ref_count(), 2); + assert_eq!(pool_stats().0, 1); // 只有一个池条目 + }); 
+ } + + #[test] + fn test_thread_safety() { + run_isolated_test(|| { + use alloc::sync::Arc; + + let s = Arc::new(ArcStr::new("shared")); + let handles: Vec<_> = (0..10) + .map(|_| { + let s_clone = Arc::clone(&s); + thread::spawn(move || { + let local = ArcStr::new("shared"); + assert_eq!(*s_clone, local); + assert_eq!(s_clone.ptr, local.ptr); + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + }); + } + + #[test] + fn test_empty_string() { + run_isolated_test(|| { + let empty = ArcStr::new(""); + assert!(empty.is_empty()); + assert_eq!(empty.len(), 0); + assert_eq!(empty.as_str(), ""); + }); + } + + #[test] + fn test_from_implementations() { + run_isolated_test(|| { + use alloc::borrow::Cow; + + let s1 = ArcStr::from("from_str"); + let s2 = ArcStr::from(String::from("from_string")); + let s3 = ArcStr::from(Cow::Borrowed("from_cow")); + + assert_eq!(s1.as_str(), "from_str"); + assert_eq!(s2.as_str(), "from_string"); + assert_eq!(s3.as_str(), "from_cow"); + }); + } +} diff --git a/crates/interned/src/lib.rs b/crates/interned/src/lib.rs new file mode 100644 index 0000000..adc2032 --- /dev/null +++ b/crates/interned/src/lib.rs @@ -0,0 +1,29 @@ +#![feature(cold_path)] +#![feature(const_trait_impl)] +#![feature(const_convert)] +#![feature(const_cmp)] +#![feature(const_default)] +#![feature(hasher_prefixfree_extras)] +#![feature(const_result_unwrap_unchecked)] +#![feature(core_intrinsics)] +#![allow(internal_features)] +#![allow(unsafe_op_in_unsafe_fn)] +#![allow(non_camel_case_types)] +#![warn(clippy::all)] +#![warn(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +extern crate alloc; + +mod arc_str; +mod str; + +pub use arc_str::ArcStr; +pub use str::Str; + +pub type InternedStr = ArcStr; +pub type StaticStr = &'static str; +pub type string = Str; + +#[inline] +pub fn init() { arc_str::__init() } diff --git a/crates/interned/src/str.rs b/crates/interned/src/str.rs new file mode 100644 index 0000000..94a044d --- /dev/null 
+++ b/crates/interned/src/str.rs @@ -0,0 +1,1248 @@ +//! 组合字符串类型,统一编译期和运行时字符串 +//! +//! # 设计理念 +//! +//! Rust 中常见两类字符串: +//! - **字面量** (`&'static str`): 编译期确定,零成本,永不释放 +//! - **动态字符串** (`String`, `ArcStr`): 运行时构造,需要内存管理 +//! +//! `Str` 通过枚举将两者统一,提供一致的 API,同时保留各自的性能优势。 +//! +//! # 内存布局 +//! +//! ```text +//! enum Str { +//! Static(&'static str) // 16 bytes (fat pointer) +//! Counted(ArcStr) // 8 bytes (NonNull) +//! } +//! +//! 总大小: 17-24 bytes (取决于编译器优化) +//! - Discriminant: 1 byte +//! - Padding: 0-7 bytes +//! - Data: 16 bytes (最大变体) +//! ``` +//! +//! # 性能对比 +//! +//! | 操作 | Static | Counted | +//! |------|--------|---------| +//! | 创建 | 0 ns | ~100 ns (首次) / ~20 ns (池命中) | +//! | Clone | ~1 ns | ~5 ns (atomic inc) | +//! | Drop | 0 ns | ~5 ns (atomic dec) + 可能的清理 | +//! | as_str() | 0 ns | 0 ns (直接访问) | +//! | len() | 0 ns | 0 ns (直接读字段) | +//! +//! # 使用场景 +//! +//! ## ✅ 使用 Static 变体 +//! +//! ```rust +//! use interned::Str; +//! +//! // 常量表 +//! static KEYWORDS: &[Str] = &[ +//! Str::from_static("fn"), +//! Str::from_static("let"), +//! Str::from_static("match"), +//! ]; +//! +//! // 编译期字符串 +//! const ERROR_MSG: Str = Str::from_static("error occurred"); +//! ``` +//! +//! ## ✅ 使用 Counted 变体 +//! +//! ```rust +//! use interned::Str; +//! +//! // 运行时字符串(去重) +//! let user_input = Str::new(get_user_input()); +//! +//! // 跨线程共享 +//! let shared = Str::new("config"); +//! std::thread::spawn(move || { +//! process(shared); +//! }); +//! ``` +//! +//! ## ⚠️ 常见陷阱 +//! +//! ```rust +//! use interned::Str; +//! +//! // ❌ 字面量不要用 new() +//! let bad = Str::new("literal"); // 创建 Counted,进入池 +//! +//! // ✅ 应该用 from_static +//! const GOOD: Str = Str::from_static("literal"); // Static 变体,零成本 +//! 
``` + +use super::arc_str::ArcStr; +use alloc::borrow::Cow; +use core::{ + cmp::Ordering, + hash::{Hash, Hasher}, +}; + +// ============================================================================ +// Core Type Definition +// ============================================================================ + +/// 组合字符串类型,支持编译期字面量和运行时引用计数字符串 +/// +/// # Variants +/// +/// ## Static +/// +/// - 包装 `&'static str` +/// - 零分配成本 +/// - 零运行时开销 +/// - Clone 是简单的指针复制 +/// - 永不释放 +/// +/// ## Counted +/// +/// - 包装 `ArcStr` +/// - 堆分配,通过全局字符串池去重 +/// - 原子引用计数管理 +/// - 线程安全共享 +/// - 最后一个引用释放时回收 +/// +/// # Method Shadowing +/// +/// `Str` 提供了与 `str` 同名的方法(如 `len()`, `is_empty()`), +/// 这些方法会覆盖(shadow)`Deref` 提供的版本,以便: +/// +/// - 对 `Static` 变体:直接访问 `&'static str` +/// - 对 `Counted` 变体:使用 `ArcStr` 的优化实现(直接读取内部字段) +/// +/// ```rust +/// use interned::Str; +/// +/// let s = Str::new("hello"); +/// // 调用 Str::len(),而不是 ::deref().len() +/// // 对于 Counted 变体,这避免了构造 &str 的开销 +/// assert_eq!(s.len(), 5); +/// ``` +/// +/// # Examples +/// +/// ```rust +/// use interned::Str; +/// +/// // 编译期字符串 +/// let s1 = Str::from_static("hello"); +/// assert!(s1.is_static()); +/// assert_eq!(s1.ref_count(), None); +/// +/// // 运行时字符串 +/// let s2 = Str::new("world"); +/// assert!(!s2.is_static()); +/// assert_eq!(s2.ref_count(), Some(1)); +/// +/// // 统一接口 +/// assert_eq!(s1.len(), 5); +/// assert_eq!(s2.len(), 5); +/// ``` +/// +/// # Thread Safety +/// +/// `Str` 是 `Send + Sync`,可以安全地在线程间传递: +/// +/// ```rust +/// use interned::Str; +/// use std::thread; +/// +/// let s = Str::new("shared"); +/// thread::spawn(move || { +/// println!("{}", s); +/// }); +/// ``` +#[derive(Clone)] +pub enum Str { + /// 编译期字符串字面量 + /// + /// - 零成本创建和访问 + /// - Clone 是指针复制(~1ns) + /// - 永不释放内存 + /// - 适合常量表和配置 + Static(&'static str), + + /// 运行时引用计数字符串 + /// + /// - 通过字符串池自动去重 + /// - 原子引用计数(线程安全) + /// - Clone 增加引用计数(~5ns) + /// - 最后一个引用释放时回收 + Counted(ArcStr), +} + +// SAFETY: 两个变体都是 Send + Sync +unsafe impl Send 
for Str {} +unsafe impl Sync for Str {} + +// ============================================================================ +// Construction +// ============================================================================ + +impl Str { + /// 创建静态字符串变体(编译期字面量) + /// + /// 这是创建零成本字符串的**推荐方式**。 + /// + /// # Const Context + /// + /// 此函数是 `const fn`,可在编译期求值: + /// + /// ```rust + /// use interned::Str; + /// + /// const GREETING: Str = Str::from_static("Hello"); + /// + /// static KEYWORDS: &[Str] = &[ + /// Str::from_static("fn"), + /// Str::from_static("let"), + /// ]; + /// ``` + /// + /// # Performance + /// + /// - 编译期:零成本(字符串嵌入二进制) + /// - 运行期:零成本(只是指针) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::from_static("constant"); + /// assert!(s.is_static()); + /// assert_eq!(s.as_static(), Some("constant")); + /// assert_eq!(s.ref_count(), None); + /// ``` + #[inline] + pub const fn from_static(s: &'static str) -> Self { Self::Static(s) } + + /// 创建或复用运行时字符串 + /// + /// 字符串会进入全局字符串池,相同内容的字符串会复用同一内存。 + /// + /// # Performance + /// + /// - **首次创建**:堆分配 + HashMap 插入 ≈ 100-200ns + /// - **池命中**:HashMap 查找 + 引用计数递增 ≈ 10-20ns + /// + /// # Thread Safety + /// + /// 字符串池使用 `RwLock` 保护,支持并发访问: + /// - 多个线程可以同时读取(查找已有字符串) + /// - 创建新字符串时需要独占写锁 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::new("dynamic"); + /// let s2 = Str::new("dynamic"); + /// + /// // 两个字符串共享同一内存 + /// assert_eq!(s1.ref_count(), s2.ref_count()); + /// assert!(s1.ref_count().unwrap() >= 2); + /// ``` + /// + /// # Use Cases + /// + /// ```rust + /// use interned::Str; + /// + /// // ✅ 编译器:标识符去重 + /// let ident = Str::new(token.text); + /// + /// // ✅ 配置系统:键名复用 + /// let key = Str::new("database.host"); + /// + /// // ✅ 跨线程共享 + /// let shared = Str::new("data"); + /// std::thread::spawn(move || { + /// process(shared); + /// }); + /// # fn token() -> Token { Token { text: "x" } } + /// # struct Token { text: 
&'static str } + /// # fn process(_: Str) {} + /// ``` + #[inline] + pub fn new>(s: S) -> Self { Self::Counted(ArcStr::new(s)) } + + /// 检查是否为 Static 变体 + /// + /// 用于判断字符串是否为编译期字面量。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("literal"); + /// let s2 = Str::new("dynamic"); + /// + /// assert!(s1.is_static()); + /// assert!(!s2.is_static()); + /// ``` + /// + /// # Use Cases + /// + /// ```rust + /// use interned::Str; + /// + /// fn optimize_for_static(s: &Str) { + /// if s.is_static() { + /// // 可以安全地转换为 &'static str + /// let static_str = s.as_static().unwrap(); + /// register_constant(static_str); + /// } + /// } + /// # fn register_constant(_: &'static str) {} + /// ``` + #[inline] + pub const fn is_static(&self) -> bool { matches!(self, Self::Static(_)) } + + /// 获取引用计数 + /// + /// - **Static 变体**:返回 `None`(无引用计数概念) + /// - **Counted 变体**:返回 `Some(count)` + /// + /// # Note + /// + /// 由于并发访问,返回的值可能在读取后立即过时。 + /// 主要用于调试和测试。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("static"); + /// let s2 = Str::new("counted"); + /// let s3 = s2.clone(); + /// + /// assert_eq!(s1.ref_count(), None); + /// assert_eq!(s2.ref_count(), Some(2)); + /// assert_eq!(s3.ref_count(), Some(2)); + /// ``` + #[inline] + pub fn ref_count(&self) -> Option { + match self { + Self::Static(_) => None, + Self::Counted(arc) => Some(arc.ref_count()), + } + } + + /// 尝试获取静态字符串引用 + /// + /// 只有 Static 变体会返回 `Some`。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("literal"); + /// let s2 = Str::new("dynamic"); + /// + /// assert_eq!(s1.as_static(), Some("literal")); + /// assert_eq!(s2.as_static(), None); + /// ``` + /// + /// # Use Cases + /// + /// 某些 API 需要 `&'static str`: + /// + /// ```rust + /// use interned::Str; + /// + /// fn register_global(name: &'static str) { + /// // 注册需要静态生命周期的字符串 + /// # drop(name); + 
/// } + /// + /// let s = Str::from_static("name"); + /// if let Some(static_str) = s.as_static() { + /// register_global(static_str); + /// } else { + /// // Counted 变体无法转换为 'static + /// eprintln!("warning: not a static string"); + /// } + /// ``` + #[inline] + pub const fn as_static(&self) -> Option<&'static str> { + match self { + Self::Static(s) => Some(*s), + Self::Counted(_) => None, + } + } + + /// 尝试获取内部 `ArcStr` 的引用 + /// + /// 只有 Counted 变体会返回 `Some`。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("literal"); + /// let s2 = Str::new("dynamic"); + /// + /// assert!(s1.as_arc_str().is_none()); + /// assert!(s2.as_arc_str().is_some()); + /// ``` + #[inline] + pub const fn as_arc_str(&self) -> Option<&ArcStr> { + match self { + Self::Static(_) => None, + Self::Counted(arc) => Some(arc), + } + } + + /// 尝试将 Counted 变体转换为 `ArcStr` + /// + /// - **Counted**:返回 `Some(ArcStr)`,零成本转换 + /// - **Static**:返回 `None` + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::new("counted"); + /// let s2 = Str::from_static("static"); + /// + /// assert!(s1.into_arc_str().is_some()); + /// assert!(s2.into_arc_str().is_none()); + /// ``` + #[inline] + pub fn into_arc_str(self) -> Option { + match self { + Self::Static(_) => None, + Self::Counted(arc) => Some(arc), + } + } +} + +// ============================================================================ +// Optimized str Methods (Method Shadowing) +// ============================================================================ + +impl Str { + /// 获取字符串切片 + /// + /// 这个方法覆盖了 `Deref` 提供的 `as_str()`,以便: + /// - 对 `Static` 变体:直接返回 `&'static str` + /// - 对 `Counted` 变体:使用 `ArcStr::as_str()` 的优化实现 + /// + /// # Performance + /// + /// - **Static**:零成本(只是返回指针) + /// - **Counted**:零成本(直接访问内部字段) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("hello"); + /// assert_eq!(s.as_str(), 
"hello"); + /// ``` + #[inline(always)] + pub const fn as_str(&self) -> &str { + match self { + Self::Static(s) => s, + Self::Counted(arc) => arc.as_str(), + } + } + + /// 获取字符串的字节切片 + /// + /// 覆盖 `Deref` 版本以传播 `ArcStr::as_bytes()` 的优化。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("hello"); + /// assert_eq!(s.as_bytes(), b"hello"); + /// ``` + #[inline(always)] + pub const fn as_bytes(&self) -> &[u8] { + match self { + Self::Static(s) => s.as_bytes(), + Self::Counted(arc) => arc.as_bytes(), + } + } + + /// 获取字符串长度(字节数) + /// + /// 覆盖 `Deref` 版本以传播 `ArcStr::len()` 的优化(直接读取字段)。 + /// + /// # Performance + /// + /// - **Static**:读取 fat pointer 的 len 字段 + /// - **Counted**:读取 `ArcStrInner::string_len` 字段(无需构造 `&str`) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("hello"); + /// assert_eq!(s.len(), 5); + /// ``` + #[inline(always)] + pub const fn len(&self) -> usize { + match self { + Self::Static(s) => s.len(), + Self::Counted(arc) => arc.len(), + } + } + + /// 检查字符串是否为空 + /// + /// 覆盖 `Deref` 版本以传播 `ArcStr::is_empty()` 的优化。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::new(""); + /// let s2 = Str::new("not empty"); + /// + /// assert!(s1.is_empty()); + /// assert!(!s2.is_empty()); + /// ``` + #[inline(always)] + pub const fn is_empty(&self) -> bool { + match self { + Self::Static(s) => s.is_empty(), + Self::Counted(arc) => arc.is_empty(), + } + } + + /// 获取内部指针(用于调试和测试) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("ptr"); + /// let ptr = s.as_ptr(); + /// assert!(!ptr.is_null()); + /// ``` + #[inline(always)] + pub const fn as_ptr(&self) -> *const u8 { + match self { + Self::Static(s) => s.as_ptr(), + Self::Counted(arc) => arc.as_ptr(), + } + } +} + +// ============================================================================ +// From Conversions +// 
============================================================================ + +impl const From<&'static str> for Str { + /// 从字面量创建 Static 变体 + /// + /// ⚠️ **注意**:只有真正的 `&'static str` 才会自动推断为 Static。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// // ✅ 字面量自动推断为 Static + /// let s: Str = "literal".into(); + /// assert!(s.is_static()); + /// + /// // ❌ 但这不会工作(编译错误): + /// // let owned = String::from("not static"); + /// // let s: Str = owned.as_str().into(); // 生命周期不是 'static + /// ``` + #[inline] + fn from(s: &'static str) -> Self { Self::Static(s) } +} + +impl From for Str { + /// 从 `String` 创建 Counted 变体 + /// + /// 字符串会进入字符串池,如果已存在相同内容则复用。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s: Str = String::from("owned").into(); + /// assert!(!s.is_static()); + /// assert_eq!(s.as_str(), "owned"); + /// ``` + #[inline] + fn from(s: String) -> Self { Self::Counted(ArcStr::from(s)) } +} + +impl From<&String> for Str { + /// 从 `&String` 创建 Counted 变体 + #[inline] + fn from(s: &String) -> Self { Self::Counted(ArcStr::from(s)) } +} + +impl From for Str { + /// 从 `ArcStr` 创建 Counted 变体 + /// + /// 直接包装,不会额外增加引用计数。 + /// + /// # Examples + /// + /// ```rust + /// use interned::{Str, ArcStr}; + /// + /// let arc = ArcStr::new("shared"); + /// let count_before = arc.ref_count(); + /// + /// let s: Str = arc.into(); + /// assert_eq!(s.ref_count(), Some(count_before)); + /// ``` + #[inline] + fn from(arc: ArcStr) -> Self { Self::Counted(arc) } +} + +impl<'a> From> for Str { + /// 从 `Cow` 创建 Counted 变体 + /// + /// 无论 Cow 是 Borrowed 还是 Owned,都会进入字符串池。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::borrow::Cow; + /// + /// let borrowed: Cow = Cow::Borrowed("borrowed"); + /// let owned: Cow = Cow::Owned(String::from("owned")); + /// + /// let s1: Str = borrowed.into(); + /// let s2: Str = owned.into(); + /// + /// assert!(!s1.is_static()); + /// assert!(!s2.is_static()); + /// 
``` + #[inline] + fn from(cow: Cow<'a, str>) -> Self { Self::Counted(ArcStr::from(cow)) } +} + +impl From> for Str { + /// 从 `Box` 创建 Counted 变体 + #[inline] + fn from(s: alloc::boxed::Box) -> Self { Self::Counted(ArcStr::from(s)) } +} + +impl From for String { + /// 转换为 `String`(总是需要分配) + /// + /// # Performance + /// + /// 无论哪个变体,都需要分配并复制字符串内容。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("to_string"); + /// let string: String = s.into(); + /// assert_eq!(string, "to_string"); + /// ``` + #[inline] + fn from(s: Str) -> Self { s.as_str().to_owned() } +} + +impl From for alloc::boxed::Box { + /// 转换为 `Box`(需要分配) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::new("boxed"); + /// let boxed: Box = s.into(); + /// assert_eq!(&*boxed, "boxed"); + /// ``` + #[inline] + fn from(s: Str) -> Self { s.as_str().into() } +} + +impl<'a> From for Cow<'a, str> { + /// 转换为 `Cow` + /// + /// - **Static 变体**:转换为 `Cow::Borrowed`(零成本) + /// - **Counted 变体**:转换为 `Cow::Owned`(需要分配) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::borrow::Cow; + /// + /// let s1 = Str::from_static("static"); + /// let cow1: Cow = s1.into(); + /// assert!(matches!(cow1, Cow::Borrowed(_))); + /// + /// let s2 = Str::new("counted"); + /// let cow2: Cow = s2.into(); + /// assert!(matches!(cow2, Cow::Owned(_))); + /// ``` + #[inline] + fn from(s: Str) -> Self { + match s { + Str::Static(s) => Cow::Borrowed(s), + Str::Counted(arc) => Cow::Owned(arc.into()), + } + } +} + +impl<'a> const From<&'a Str> for Cow<'a, str> { + /// 转换为 `Cow::Borrowed`(零成本) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::borrow::Cow; + /// + /// let s = Str::from_static("cow"); + /// let cow: Cow = (&s).into(); + /// + /// assert!(matches!(cow, Cow::Borrowed(_))); + /// assert_eq!(cow, "cow"); + /// ``` + #[inline] + fn from(s: &'a Str) -> Self { Cow::Borrowed(s.as_str()) } 
+} + +impl core::str::FromStr for Str { + type Err = core::convert::Infallible; + + /// 从字符串解析(总是成功,创建 Counted 变体) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::str::FromStr; + /// + /// let s = Str::from_str("parsed").unwrap(); + /// assert!(!s.is_static()); + /// assert_eq!(s.as_str(), "parsed"); + /// ``` + #[inline] + fn from_str(s: &str) -> Result { Ok(Self::new(s)) } +} + +// ============================================================================ +// Comparison & Hashing +// ============================================================================ + +impl PartialEq for Str { + /// 比较字符串内容 + /// + /// # Optimization + /// + /// - **Counted vs Counted**:首先比较指针(O(1)),然后比较内容 + /// - **Static vs Static**:直接比较内容(编译器可能优化为指针比较) + /// - **Static vs Counted**:必须比较内容 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("test"); + /// let s2 = Str::new("test"); + /// + /// assert_eq!(s1, s2); // 内容相同即相等 + /// ``` + #[inline] + fn eq(&self, other: &Self) -> bool { + match (self, other) { + // Counted vs Counted: 利用 ArcStr 的指针比较优化 + (Self::Counted(a), Self::Counted(b)) => a == b, + // 其他情况:比较字符串内容 + _ => self.as_str() == other.as_str(), + } + } +} + +impl Eq for Str {} + +impl const PartialEq for Str { + #[inline] + fn eq(&self, other: &str) -> bool { self.as_str() == other } +} + +impl const PartialEq<&str> for Str { + #[inline] + fn eq(&self, other: &&str) -> bool { self.as_str() == *other } +} + +impl const PartialEq for Str { + #[inline] + fn eq(&self, other: &String) -> bool { self.as_str() == other.as_str() } +} + +impl const PartialEq for str { + #[inline] + fn eq(&self, other: &Str) -> bool { self == other.as_str() } +} + +impl const PartialEq for &str { + #[inline] + fn eq(&self, other: &Str) -> bool { *self == other.as_str() } +} + +impl const PartialEq for String { + #[inline] + fn eq(&self, other: &Str) -> bool { self.as_str() == other.as_str() } +} + +impl 
PartialEq for Str { + /// 优化的 `Str` 与 `ArcStr` 比较 + /// + /// 如果 `Str` 是 Counted 变体,使用指针比较(快速路径)。 + /// + /// # Examples + /// + /// ```rust + /// use interned::{Str, ArcStr}; + /// + /// let arc = ArcStr::new("test"); + /// let s1 = Str::from(arc.clone()); + /// let s2 = Str::from_static("test"); + /// + /// assert_eq!(s1, arc); // 指针比较 + /// assert_eq!(s2, arc); // 内容比较 + /// ``` + #[inline] + fn eq(&self, other: &ArcStr) -> bool { + match self { + Self::Counted(arc) => arc == other, + Self::Static(s) => *s == other.as_str(), + } + } +} + +impl PartialEq for ArcStr { + #[inline] + fn eq(&self, other: &Str) -> bool { other == self } +} + +impl Hash for Str { + /// 基于字符串内容的哈希,与变体类型无关 + /// + /// 这确保了 `Static("a")` 和 `Counted(ArcStr::new("a"))` + /// 有相同的哈希值,可以在 `HashMap` 中作为相同的 key。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::collections::HashMap; + /// + /// let mut map = HashMap::new(); + /// let s1 = Str::from_static("key"); + /// let s2 = Str::new("key"); + /// + /// map.insert(s1, "value"); + /// assert_eq!(map.get(&s2), Some(&"value")); // s2 可以找到 s1 插入的值 + /// ``` + #[inline] + fn hash(&self, state: &mut H) { state.write_str(self.as_str()) } +} + +// ============================================================================ +// Ordering +// ============================================================================ + +impl PartialOrd for Str { + /// 字典序比较 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let a = Str::from_static("apple"); + /// let b = Str::new("banana"); + /// + /// assert!(a < b); + /// assert!(b > a); + /// ``` + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for Str { + /// 字典序比较(总序) + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let mut strs = vec![ + /// Str::new("cherry"), + /// Str::from_static("apple"), + /// Str::new("banana"), + /// ]; + /// + /// strs.sort(); + /// + /// 
assert_eq!(strs[0].as_str(), "apple"); + /// assert_eq!(strs[1].as_str(), "banana"); + /// assert_eq!(strs[2].as_str(), "cherry"); + /// ``` + #[inline] + fn cmp(&self, other: &Self) -> Ordering { self.as_str().cmp(other.as_str()) } +} + +// ============================================================================ +// Deref & AsRef +// ============================================================================ + +impl core::ops::Deref for Str { + type Target = str; + + /// 支持自动解引用为 `&str` + /// + /// 这允许直接调用 `str` 的所有方法(如 `starts_with()`, `contains()` 等)。 + /// + /// ⚠️ **Note**: 常用方法(如 `len()`, `is_empty()`)已被 `Str` 的同名方法覆盖, + /// 以便传播 `ArcStr` 的优化。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::from_static("deref"); + /// + /// // 可以直接调用 str 的方法 + /// assert!(s.starts_with("de")); + /// assert!(s.contains("ref")); + /// assert_eq!(s.to_uppercase(), "DEREF"); + /// ``` + #[inline] + fn deref(&self) -> &Self::Target { self.as_str() } +} + +impl const AsRef for Str { + #[inline] + fn as_ref(&self) -> &str { self.as_str() } +} + +impl const AsRef<[u8]> for Str { + #[inline] + fn as_ref(&self) -> &[u8] { self.as_bytes() } +} + +impl const core::borrow::Borrow for Str { + /// 支持在 `HashMap` 中使用 `&str` 查找 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// use std::collections::HashMap; + /// + /// let mut map = HashMap::new(); + /// map.insert(Str::new("key"), "value"); + /// + /// // 可以使用 &str 查找 + /// assert_eq!(map.get("key"), Some(&"value")); + /// ``` + #[inline] + fn borrow(&self) -> &str { self.as_str() } +} + +// ============================================================================ +// Display & Debug +// ============================================================================ + +impl core::fmt::Display for Str { + /// 输出字符串内容 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::from_static("display"); + /// assert_eq!(format!("{}", s), 
"display"); + /// ``` + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { f.write_str(self.as_str()) } +} + +impl core::fmt::Debug for Str { + /// 调试输出,显示变体类型和内容 + /// + /// # Output Format + /// + /// - **Static**: `Str::Static("content")` + /// - **Counted**: `Str::Counted("content", refcount=N)` + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s1 = Str::from_static("debug"); + /// let s2 = Str::new("counted"); + /// + /// println!("{:?}", s1); // Str::Static("debug") + /// println!("{:?}", s2); // Str::Counted("counted", refcount=1) + /// ``` + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Static(s) => f.debug_tuple("Str::Static").field(s).finish(), + Self::Counted(arc) => f + .debug_tuple("Str::Counted") + .field(&arc.as_str()) + .field(&format_args!("refcount={}", arc.ref_count())) + .finish(), + } + } +} + +// ============================================================================ +// Default +// ============================================================================ + +impl const Default for Str { + /// 返回空字符串的 Static 变体 + /// + /// 这是零成本的,不会分配任何内存。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = Str::default(); + /// assert!(s.is_empty()); + /// assert!(s.is_static()); + /// assert_eq!(s.as_str(), ""); + /// ``` + #[inline] + fn default() -> Self { Self::Static(Default::default()) } +} + +// ============================================================================ +// Serde Support +// ============================================================================ + +#[cfg(feature = "serde")] +mod serde_impls { + use super::*; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + impl Serialize for Str { + /// 序列化为普通字符串,丢失变体信息 + /// + /// ⚠️ **注意**:反序列化后总是 Counted 变体。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let s = 
Str::from_static("serialize"); + /// let json = serde_json::to_string(&s).unwrap(); + /// assert_eq!(json, r#""serialize""#); + /// ``` + #[inline] + fn serialize(&self, serializer: S) -> Result + where S: Serializer { + self.as_str().serialize(serializer) + } + } + + impl<'de> Deserialize<'de> for Str { + /// 反序列化为 Counted 变体 + /// + /// ⚠️ **注意**:无法恢复 Static 变体,因为反序列化的字符串 + /// 不具有 `'static` 生命周期。 + /// + /// # Examples + /// + /// ```rust + /// use interned::Str; + /// + /// let json = r#""deserialize""#; + /// let s: Str = serde_json::from_str(json).unwrap(); + /// + /// assert!(!s.is_static()); // 总是 Counted + /// assert_eq!(s.as_str(), "deserialize"); + /// ``` + #[inline] + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + String::deserialize(deserializer).map(Str::from) + } + } +} + +// ============================================================================ +// Testing +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_method_shadowing() { + let s1 = Str::from_static("hello"); + let s2 = Str::new("world"); + + // 验证调用的是覆盖版本(通过编译即可) + assert_eq!(s1.len(), 5); + assert_eq!(s2.len(), 5); + assert!(!s1.is_empty()); + assert_eq!(s1.as_bytes(), b"hello"); + assert_eq!(s1.as_str(), "hello"); + } + + #[test] + fn test_static_vs_counted() { + let s1 = Str::from_static("hello"); + let s2 = Str::new("hello"); + + assert!(s1.is_static()); + assert!(!s2.is_static()); + assert_eq!(s1.ref_count(), None); + assert!(s2.ref_count().is_some()); + assert_eq!(s1, s2); + } + + #[test] + fn test_arcstr_conversions() { + let arc = ArcStr::new("test"); + let count_before = arc.ref_count(); + + // ArcStr -> Str + let s: Str = arc.clone().into(); + assert!(!s.is_static()); + assert_eq!(s.ref_count(), Some(count_before + 1)); + + // Str -> Option + let arc_back = s.into_arc_str(); + assert!(arc_back.is_some()); + assert_eq!(arc_back.unwrap(), arc); + } + + 
#[test] + fn test_arcstr_equality() { + let arc = ArcStr::new("same"); + let s1 = Str::from(arc.clone()); + let s2 = Str::from_static("same"); + + // Counted vs ArcStr: 指针比较 + assert_eq!(s1, arc); + + // Static vs ArcStr: 内容比较 + assert_eq!(s2, arc); + } + + #[test] + fn test_default() { + let s = Str::default(); + assert!(s.is_empty()); + assert!(s.is_static()); + assert_eq!(s.len(), 0); + } + + #[test] + fn test_const_construction() { + const GREETING: Str = Str::from_static("Hello"); + static KEYWORDS: &[Str] = + &[Str::from_static("fn"), Str::from_static("let"), Str::from_static("match")]; + + assert!(GREETING.is_static()); + assert_eq!(KEYWORDS.len(), 3); + assert!(KEYWORDS[0].is_static()); + } + + #[test] + fn test_deref() { + let s = Str::from_static("deref"); + + // 通过 Deref 访问 str 的方法 + assert!(s.starts_with("de")); + assert!(s.contains("ref")); + assert_eq!(s.to_uppercase(), "DEREF"); + } + + #[test] + fn test_ordering() { + let mut strs = vec![Str::new("cherry"), Str::from_static("apple"), Str::new("banana")]; + + strs.sort(); + + assert_eq!(strs[0], "apple"); + assert_eq!(strs[1], "banana"); + assert_eq!(strs[2], "cherry"); + } + + #[test] + fn test_conversions() { + // From implementations + let s1: Str = "literal".into(); + let s2: Str = String::from("owned").into(); + let s3: Str = ArcStr::new("arc").into(); + + assert!(s1.is_static()); + assert!(!s2.is_static()); + assert!(!s3.is_static()); + + // Into implementations + let string: String = s2.clone().into(); + assert_eq!(string, "owned"); + + let boxed: alloc::boxed::Box = s3.into(); + assert_eq!(&*boxed, "arc"); + } + + #[test] + fn test_hash_consistency() { + use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + }; + + let s1 = Str::from_static("test"); + let s2 = Str::new("test"); + + let mut h1 = DefaultHasher::new(); + let mut h2 = DefaultHasher::new(); + + s1.hash(&mut h1); + s2.hash(&mut h2); + + assert_eq!(h1.finish(), h2.finish()); + } +} diff --git 
a/crates/manually_init/Cargo.toml b/crates/manually_init/Cargo.toml new file mode 100644 index 0000000..ad3696d --- /dev/null +++ b/crates/manually_init/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "manually_init" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] + +[features] +default = [] +sync = [] + +[package.metadata.docs.rs] +all-features = true diff --git a/crates/manually_init/src/lib.rs b/crates/manually_init/src/lib.rs new file mode 100644 index 0000000..240549d --- /dev/null +++ b/crates/manually_init/src/lib.rs @@ -0,0 +1,363 @@ +//! # ManuallyInit - Zero-Cost Manual Memory Initialization +//! +//! A minimalist unsafe abstraction for manual memory initialization, designed for experts who need +//! direct memory control without runtime overhead. +//! +//! ## Design Philosophy +//! +//! `ManuallyInit` is not a safe alternative to `std::sync::OnceLock` or `std::cell::OnceCell`. +//! It is a deliberate choice for scenarios where: +//! +//! - You need zero runtime overhead +//! - You have full control over initialization timing and access patterns +//! - You explicitly choose to manage safety invariants manually +//! - You come from C/C++ and want familiar, direct memory semantics +//! +//! ## The Core Contract +//! +//! **By using this crate, you accept complete responsibility for:** +//! +//! 1. **Initialization state tracking** - You must know when values are initialized +//! 2. **Thread safety** - You must ensure no data races in concurrent environments +//! 3. **Aliasing rules** - You must uphold Rust's borrowing rules manually +//! +//! The library provides ergonomic APIs by marking methods as safe, but they are semantically unsafe. +//! This is a deliberate design choice to reduce syntax noise in controlled unsafe contexts. +//! +//! ## Primary Use Pattern: Single-Thread Init, Multi-Thread Read +//! +//! 
The most common and recommended pattern: +//! ```rust,no_run +//! use manually_init::ManuallyInit; +//! +//! static GLOBAL: ManuallyInit = ManuallyInit::new(); +//! +//! // In main thread during startup +//! fn initialize() { +//! GLOBAL.init(Config::load()); +//! } +//! +//! // In any thread after initialization +//! fn use_config() { +//! let config = GLOBAL.get(); +//! // Read-only access is safe +//! } +//! ``` +//! +//! ## Choosing the Right Tool +//! +//! | Strategy | Category | Choose For | What You Get | Cost | +//! |----------|----------|------------|--------------|------| +//! | **`const` item** | Compile-time | True constants that can be inlined | Zero runtime cost | Must be const-evaluable | +//! | **`static` item** | Read-only global | Simple immutable data with fixed address | Zero read cost | Must be `'static` and const | +//! | **`std::sync::LazyLock`** | Lazy immutable | **Default choice** for lazy statics | Automatic thread-safe init | Atomic check per access | +//! | **`std::sync::OnceLock`** | Lazy immutable | Manual control over init timing | Thread-safe one-time init | Atomic read per access | +//! | **`std::sync::Mutex`** | Mutable global | Simple exclusive access | One thread at a time | OS-level lock per access | +//! | **`std::sync::RwLock`** | Mutable global | Multiple readers, single writer | Concurrent reads | Reader/writer lock overhead | +//! | **`core::sync::atomic`** | Lock-free | Primitive types without locks | Wait-free operations | Memory ordering complexity | +//! | **`thread_local!`** | Thread-local | Per-thread state | No synchronization needed | Per-thread initialization | +//! | **`parking_lot::Mutex`** | Mutable global | Faster alternative to std::sync::Mutex | Smaller, faster locks | No poisoning, custom features | +//! | **`parking_lot::RwLock`** | Mutable global | Faster alternative to std::sync::RwLock | Better performance | No poisoning, custom features | +//! 
| **`parking_lot::OnceCell`** | Lazy immutable | Backport/alternative to std version | Same as std, more features | Similar to std version | +//! | **`once_cell::sync::OnceCell`** | Lazy immutable | Pre-1.70 Rust compatibility | Same as std version | Similar to std version | +//! | **`once_cell::sync::Lazy`** | Lazy immutable | Pre-1.80 Rust compatibility | Same as std version | Similar to std version | +//! | **`lazy_static::lazy_static!`** | Lazy immutable | Macro-based lazy statics | Convenient syntax | Extra dependency, macro overhead | +//! | **`crossbeam::atomic::AtomicCell`** | Lock-free | Any `Copy` type atomically | Lock-free for small types | CAS loop overhead | +//! | **`dashmap::DashMap`** | Concurrent map | High-throughput key-value store | Per-shard locking | Higher memory usage | +//! | **`tokio::sync::Mutex`** | Async mutable | Async-aware exclusive access | Works across .await | Async runtime overhead | +//! | **`tokio::sync::RwLock`** | Async mutable | Async-aware read/write lock | Works across .await | Async runtime overhead | +//! | **`tokio::sync::OnceCell`** | Async init | Initialization in async context | Async-aware safety | Async runtime overhead | +//! | **`static_cell::StaticCell`** | Single-thread | no_std mutable statics | Safe mutable statics | Single-threaded only | +//! | **`conquer_once::OnceCell`** | Lock-free init | Wait-free reads after init | no_std compatible | Complex implementation | +//! | **`core::mem::MaybeUninit`** | Unsafe primitive | Building custom abstractions | Maximum control | `unsafe` for every operation | +//! | **`ManuallyInit`** | Unsafe ergonomic | Zero-cost with external safety proof | Safe-looking API, zero overhead | You handle all safety | +//! +//! ## When to Use ManuallyInit +//! +//! Choose `ManuallyInit` only when: +//! - You need absolute zero overhead (no runtime state tracking) +//! - You have complete control over access patterns +//! 
- You're interfacing with C/C++ code that expects raw memory +//! - You're implementing a custom synchronization primitive +//! - You can prove safety through external invariants (e.g., init-before-threads pattern) +//! +//! ## Types Requiring Extra Care +//! +//! When using `ManuallyInit`, be especially careful with: +//! +//! | Type Category | Examples | Risk | Mitigation | +//! |--------------|----------|------|------------| +//! | **Heap Owners** | `String`, `Vec`, `Box` | Memory leaks on re-init | Call `take()` before re-init | +//! | **Reference Counted** | `Rc`, `Arc` | Reference leaks | Proper cleanup required | +//! | **Interior Mutability** | `Cell`, `RefCell` | Complex aliasing rules | Avoid or handle carefully | +//! | **Async Types** | `Future`, `Waker` | Complex lifetime requirements | Not recommended | +//! +//! ## Feature Flags +//! +//! - `sync` - Enables `Sync` implementation. By enabling this feature, you explicitly accept +//! responsibility for preventing data races in concurrent access patterns. + +#![no_std] +#![cfg_attr( + feature = "sync", + doc = "**⚠️ The `sync` feature is enabled. You are responsible for thread safety.**" +)] + +use core::cell::UnsafeCell; +use core::mem::MaybeUninit; + +/// A zero-cost wrapper for manually managed initialization. +/// +/// This type provides direct memory access with no runtime checks. It's designed for +/// experts who need precise control over initialization and memory layout. +/// +/// # Safety Invariants You Must Uphold +/// +/// 1. **Never access uninitialized memory** - Calling `get()`, `deref()`, etc. on an +/// uninitialized instance is undefined behavior. +/// +/// 2. **Track initialization state** - The type does not track whether it's initialized. +/// This is entirely your responsibility. +/// +/// 3. 
**Handle concurrent access** - If `sync` feature is enabled, you must ensure: +/// - Initialization happens before any concurrent access +/// - No concurrent mutations occur +/// - Memory barriers are properly established +/// +/// 4. **Manage memory lifecycle** - `init()` overwrites without dropping. For types +/// that own heap memory, this causes leaks. +/// +/// # Example: Global Configuration +/// +/// ```rust,no_run +/// use manually_init::ManuallyInit; +/// +/// #[derive(Copy, Clone)] +/// struct Config { +/// max_connections: usize, +/// timeout_ms: u64, +/// } +/// +/// static CONFIG: ManuallyInit = ManuallyInit::new(); +/// +/// // Called once during startup +/// fn initialize_config() { +/// CONFIG.init(Config { +/// max_connections: 100, +/// timeout_ms: 5000, +/// }); +/// } +/// +/// // Called from any thread after initialization +/// fn get_timeout() -> u64 { +/// CONFIG.get().timeout_ms +/// } +/// ``` +/// +/// # Example: FFI Pattern +/// +/// ```rust,no_run +/// use manually_init::ManuallyInit; +/// use core::ffi::c_void; +/// +/// static FFI_CONTEXT: ManuallyInit<*mut c_void> = ManuallyInit::new(); +/// +/// extern "C" fn init_library(ctx: *mut c_void) { +/// FFI_CONTEXT.init(ctx); +/// } +/// +/// extern "C" fn use_library() { +/// let ctx = *FFI_CONTEXT.get(); +/// // Use ctx with FFI functions +/// } +/// ``` +#[repr(transparent)] +pub struct ManuallyInit { + value: UnsafeCell>, +} + +impl ManuallyInit { + /// Creates a new uninitialized instance. + /// + /// The memory is not initialized. You must call `init()` before any access. + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// static DATA: ManuallyInit = ManuallyInit::new(); + /// ``` + #[inline] + #[must_use] + #[allow(clippy::new_without_default)] + pub const fn new() -> ManuallyInit { + ManuallyInit { value: UnsafeCell::new(MaybeUninit::uninit()) } + } + + /// Creates a new instance initialized with the given value. 
+ /// + /// The instance is immediately ready for use. + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new_with(42); + /// assert_eq!(*cell.get(), 42); + /// ``` + #[inline] + #[must_use] + pub const fn new_with(value: T) -> ManuallyInit { + ManuallyInit { value: UnsafeCell::new(MaybeUninit::new(value)) } + } + + /// Initializes or overwrites the value. + /// + /// **Critical**: This method does NOT drop the old value. For types that own + /// heap memory (`String`, `Vec`, `Box`, etc.), this causes memory leaks. + /// + /// # Memory Safety + /// + /// - For `Copy` types: Safe to call repeatedly + /// - For heap-owning types: Call `take()` first or track initialization manually + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new(); + /// cell.init(42); + /// + /// // Safe for Copy types + /// cell.init(100); + /// assert_eq!(*cell.get(), 100); + /// ``` + #[inline] + pub const fn init(&self, value: T) { + unsafe { (&mut *self.value.get()).write(value) }; + } + + /// Gets a shared reference to the value. + /// + /// # Safety Contract + /// + /// You must ensure the value is initialized. Calling this on uninitialized + /// memory is undefined behavior. + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new_with(42); + /// let value: &i32 = cell.get(); + /// assert_eq!(*value, 42); + /// ``` + #[inline] + pub const fn get(&self) -> &T { + unsafe { (&*self.value.get()).assume_init_ref() } + } + + /// Gets a raw mutable pointer to the value. + /// + /// This returns a raw pointer to avoid aliasing rule violations. You are + /// responsible for ensuring no aliasing occurs when dereferencing. 
+ /// + /// # Safety Contract + /// + /// - The value must be initialized before dereferencing + /// - You must ensure no other references exist when creating `&mut T` + /// - You must follow Rust's aliasing rules manually + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new_with(42); + /// let ptr = cell.get_ptr(); + /// + /// // You must ensure no other references exist + /// unsafe { + /// *ptr = 100; + /// } + /// + /// assert_eq!(*cell.get(), 100); + /// ``` + #[inline] + pub const fn get_ptr(&self) -> *mut T { + unsafe { (&mut *self.value.get()).as_mut_ptr() } + } + + /// Consumes the cell and returns the inner value. + /// + /// # Safety Contract + /// + /// The value must be initialized. Calling this on uninitialized memory + /// is undefined behavior. + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new_with(String::from("hello")); + /// let s = cell.into_inner(); + /// assert_eq!(s, "hello"); + /// ``` + #[inline] + pub const fn into_inner(self) -> T { + unsafe { self.value.into_inner().assume_init() } + } + + /// Takes the value out, leaving the cell uninitialized. + /// + /// After calling this method, the cell is uninitialized. You must not + /// access it until calling `init()` again. + /// + /// # Safety Contract + /// + /// The value must be initialized. Calling this on uninitialized memory + /// is undefined behavior. + /// + /// # Example + /// ```rust + /// use manually_init::ManuallyInit; + /// + /// let cell = ManuallyInit::new_with(String::from("hello")); + /// let s = cell.take(); + /// assert_eq!(s, "hello"); + /// // cell is now uninitialized! 
+ /// + /// // Must reinitialize before access + /// cell.init(String::from("world")); + /// ``` + #[inline] + pub const fn take(&self) -> T { + unsafe { + let slot = &mut *self.value.get(); + let value = slot.assume_init_read(); + *slot = MaybeUninit::uninit(); + value + } + } +} + +unsafe impl Send for ManuallyInit {} + +#[cfg(feature = "sync")] +unsafe impl Sync for ManuallyInit {} + +impl ::core::ops::Deref for ManuallyInit { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + self.get() + } +} + +impl ::core::ops::DerefMut for ManuallyInit { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + // Safe because we have &mut self, ensuring exclusive access + unsafe { (&mut *self.value.get()).assume_init_mut() } + } +} diff --git a/crates/rep_move/Cargo.toml b/crates/rep_move/Cargo.toml new file mode 100644 index 0000000..17e4fb5 --- /dev/null +++ b/crates/rep_move/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "rep_move" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] diff --git a/crates/rep_move/src/lib.rs b/crates/rep_move/src/lib.rs new file mode 100644 index 0000000..6898753 --- /dev/null +++ b/crates/rep_move/src/lib.rs @@ -0,0 +1,246 @@ +//! Iterator that yields N-1 replications followed by the original value. +//! +//! Optimized for expensive-to-clone types by moving the original on the last iteration. + +#![no_std] +#![feature(const_destruct)] +#![feature(const_trait_impl)] + +use core::{fmt, iter::FusedIterator, marker::Destruct}; + +/// Replication strategy for `RepMove`. +pub trait Replicator { + /// Creates a replica with mutable access to the remaining count. 
+ fn replicate(&mut self, source: &T, remaining: &mut usize) -> T; +} + +// Blanket impl for simple replicators +impl Replicator for F +where F: FnMut(&T) -> T +{ + #[inline] + fn replicate(&mut self, source: &T, remaining: &mut usize) -> T { + let item = self(source); + *remaining = remaining.saturating_sub(1); + item + } +} + +// Note: Additional blanket impls for FnMut(&T, usize) -> T and FnMut(&T, &mut usize) -> T +// would conflict with the above. Users needing state awareness should implement Replicator directly +// or use a wrapper type. + +/// State-aware replicator wrapper for read-only access to remaining count. +pub struct ReadState(pub F); + +impl Replicator for ReadState +where F: FnMut(&T, usize) -> T +{ + #[inline] + fn replicate(&mut self, source: &T, remaining: &mut usize) -> T { + let item = (self.0)(source, *remaining); + *remaining = remaining.saturating_sub(1); + item + } +} + +/// State-aware replicator wrapper for mutable access to remaining count. +pub struct MutState(pub F); + +impl Replicator for MutState +where F: FnMut(&T, &mut usize) -> T +{ + #[inline] + fn replicate(&mut self, source: &T, remaining: &mut usize) -> T { (self.0)(source, remaining) } +} + +enum State { + Active { source: T, remaining: usize, rep_fn: R }, + Done, +} + +/// Iterator yielding N-1 replicas then the original. 
+/// +/// # Examples +/// +/// Simple cloning: +/// ``` +/// # use core::num::NonZeroUsize; +/// # use rep_move::RepMove; +/// let v = vec![1, 2, 3]; +/// let mut iter = RepMove::new(v, Vec::clone, NonZeroUsize::new(3).unwrap()); +/// +/// assert_eq!(iter.next(), Some(vec![1, 2, 3])); +/// assert_eq!(iter.next(), Some(vec![1, 2, 3])); +/// assert_eq!(iter.next(), Some(vec![1, 2, 3])); // moved +/// ``` +/// +/// Read-only state awareness: +/// ``` +/// # use core::num::NonZeroUsize; +/// # use rep_move::{RepMove, ReadState}; +/// let s = String::from("item"); +/// let mut iter = RepMove::new( +/// s, +/// ReadState(|s: &String, n| format!("{}-{}", s, n)), +/// NonZeroUsize::new(3).unwrap() +/// ); +/// +/// assert_eq!(iter.next(), Some("item-2".to_string())); +/// assert_eq!(iter.next(), Some("item-1".to_string())); +/// assert_eq!(iter.next(), Some("item".to_string())); +/// ``` +/// +/// Full control over iteration: +/// ``` +/// # use core::num::NonZeroUsize; +/// # use rep_move::{RepMove, MutState}; +/// let v = vec![1, 2, 3]; +/// let mut iter = RepMove::new( +/// v, +/// MutState(|v: &Vec, remaining: &mut usize| { +/// if v.len() > 10 { +/// *remaining = 0; // Stop early for large vectors +/// } else { +/// *remaining = remaining.saturating_sub(1); +/// } +/// v.clone() +/// }), +/// NonZeroUsize::new(5).unwrap() +/// ); +/// // Will yield fewer items due to the custom logic +/// ``` +pub struct RepMove> { + state: State, +} + +impl> RepMove { + /// Creates a new replicating iterator. 
+ #[inline] + pub const fn new(source: T, rep_fn: R, count: usize) -> Self + where + T: [const] Destruct, + R: [const] Destruct, + { + if count == 0 { + Self { state: State::Done } + } else { + Self { state: State::Active { source, remaining: count - 1, rep_fn } } + } + } +} + +impl> Iterator for RepMove { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + let state = core::mem::replace(&mut self.state, State::Done); + + match state { + State::Active { source, mut remaining, mut rep_fn } => { + if remaining > 0 { + let item = rep_fn.replicate(&source, &mut remaining); + self.state = State::Active { source, remaining, rep_fn }; + Some(item) + } else { + Some(source) + } + } + State::Done => None, + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl> ExactSizeIterator for RepMove { + #[inline] + fn len(&self) -> usize { + match &self.state { + State::Active { remaining, .. } => remaining + 1, + State::Done => 0, + } + } +} + +impl> FusedIterator for RepMove {} + +impl> fmt::Debug for RepMove { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.state { + State::Active { source, remaining, .. 
} => f + .debug_struct("RepMove") + .field("source", source) + .field("remaining", remaining) + .finish_non_exhaustive(), + State::Done => write!(f, "RepMove::Done"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + extern crate alloc; + + use alloc::{ + format, + string::{String, ToString as _}, + vec, + vec::Vec, + }; + + #[test] + fn test_simple_clone() { + let v = vec![1, 2, 3]; + let mut iter = RepMove::new(v, Vec::clone, 3); + + assert_eq!(iter.len(), 3); + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.len(), 2); + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.len(), 1); + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.len(), 0); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_state_aware() { + let s = String::from("test"); + let mut iter = RepMove::new(s, ReadState(|s: &String, n| format!("{}-{}", s, n)), 2); + + assert_eq!(iter.next(), Some("test-1".to_string())); + assert_eq!(iter.next(), Some("test".to_string())); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_mutable_control() { + let v = vec![1, 2, 3]; + let mut iter = RepMove::new( + v, + MutState(|v: &Vec, remaining: &mut usize| { + if *remaining > 1 { + *remaining = 1; // Skip ahead + } else { + *remaining = remaining.saturating_sub(1); + } + v.clone() + }), + 4, + ); + + // Should yield fewer items due to skipping + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.next(), Some(vec![1, 2, 3])); + assert_eq!(iter.next(), None); + } +} diff --git a/patch/chrono-0.4.41/tests/dateutils.rs b/patch/chrono-0.4.41/tests/dateutils.rs deleted file mode 100644 index 849abc7..0000000 --- a/patch/chrono-0.4.41/tests/dateutils.rs +++ /dev/null @@ -1,165 +0,0 @@ -#![cfg(all(unix, feature = "clock", feature = "std"))] - -use std::{path, process, thread}; - -#[cfg(target_os = "linux")] -use chrono::Days; -use chrono::{Datelike, Local, NaiveDate, NaiveDateTime, 
NaiveTime, TimeZone, Timelike}; - -fn verify_against_date_command_local(path: &'static str, dt: NaiveDateTime) { - let output = process::Command::new(path) - .arg("-d") - .arg(format!("{}-{:02}-{:02} {:02}:05:01", dt.year(), dt.month(), dt.day(), dt.hour())) - .arg("+%Y-%m-%d %H:%M:%S %:z") - .output() - .unwrap(); - - let date_command_str = String::from_utf8(output.stdout).unwrap(); - - // The below would be preferred. At this stage neither earliest() or latest() - // seems to be consistent with the output of the `date` command, so we simply - // compare both. - // let local = Local - // .with_ymd_and_hms(year, month, day, hour, 5, 1) - // // looks like the "date" command always returns a given time when it is ambiguous - // .earliest(); - - // if let Some(local) = local { - // assert_eq!(format!("{}\n", local), date_command_str); - // } else { - // // we are in a "Spring forward gap" due to DST, and so date also returns "" - // assert_eq!("", date_command_str); - // } - - // This is used while a decision is made whether the `date` output needs to - // be exactly matched, or whether MappedLocalTime::Ambiguous should be handled - // differently - - let date = NaiveDate::from_ymd_opt(dt.year(), dt.month(), dt.day()).unwrap(); - match Local.from_local_datetime(&date.and_hms_opt(dt.hour(), 5, 1).unwrap()) { - chrono::MappedLocalTime::Ambiguous(a, b) => assert!( - format!("{}\n", a) == date_command_str || format!("{}\n", b) == date_command_str - ), - chrono::MappedLocalTime::Single(a) => { - assert_eq!(format!("{}\n", a), date_command_str); - } - chrono::MappedLocalTime::None => { - assert_eq!("", date_command_str); - } - } -} - -/// path to Unix `date` command. Should work on most Linux and Unixes. Not the -/// path for MacOS (/bin/date) which uses a different version of `date` with -/// different arguments (so it won't run which is okay). 
-/// for testing only -#[allow(dead_code)] -#[cfg(not(target_os = "aix"))] -const DATE_PATH: &str = "/usr/bin/date"; -#[allow(dead_code)] -#[cfg(target_os = "aix")] -const DATE_PATH: &str = "/opt/freeware/bin/date"; - -#[cfg(test)] -/// test helper to sanity check the date command behaves as expected -/// asserts the command succeeded -fn assert_run_date_version() { - // note environment variable `LANG` - match std::env::var_os("LANG") { - Some(lang) => eprintln!("LANG: {:?}", lang), - None => eprintln!("LANG not set"), - } - let out = process::Command::new(DATE_PATH).arg("--version").output().unwrap(); - let stdout = String::from_utf8(out.stdout).unwrap(); - let stderr = String::from_utf8(out.stderr).unwrap(); - // note the `date` binary version - eprintln!("command: {:?} --version\nstdout: {:?}\nstderr: {:?}", DATE_PATH, stdout, stderr); - assert!(out.status.success(), "command failed: {:?} --version", DATE_PATH); -} - -#[test] -fn try_verify_against_date_command() { - if !path::Path::new(DATE_PATH).exists() { - eprintln!("date command {:?} not found, skipping", DATE_PATH); - return; - } - assert_run_date_version(); - - eprintln!( - "Run command {:?} for every hour from 1975 to 2077, skipping some years...", - DATE_PATH, - ); - - let mut children = vec![]; - for year in [1975, 1976, 1977, 2020, 2021, 2022, 2073, 2074, 2075, 2076, 2077].iter() { - children.push(thread::spawn(|| { - let mut date = NaiveDate::from_ymd_opt(*year, 1, 1).unwrap().and_time(NaiveTime::MIN); - let end = NaiveDate::from_ymd_opt(*year + 1, 1, 1).unwrap().and_time(NaiveTime::MIN); - while date <= end { - verify_against_date_command_local(DATE_PATH, date); - date += chrono::TimeDelta::try_hours(1).unwrap(); - } - })); - } - for child in children { - // Wait for the thread to finish. Returns a result. 
- let _ = child.join(); - } -} - -#[cfg(target_os = "linux")] -fn verify_against_date_command_format_local(path: &'static str, dt: NaiveDateTime) { - let required_format = - "d%d D%D F%F H%H I%I j%j k%k l%l m%m M%M q%q S%S T%T u%u U%U w%w W%W X%X y%y Y%Y z%:z"; - // a%a - depends from localization - // A%A - depends from localization - // b%b - depends from localization - // B%B - depends from localization - // h%h - depends from localization - // c%c - depends from localization - // p%p - depends from localization - // r%r - depends from localization - // x%x - fails, date is dd/mm/yyyy, chrono is dd/mm/yy, same as %D - // Z%Z - too many ways to represent it, will most likely fail - - let output = process::Command::new(path) - .env("LANG", "c") - .env("LC_ALL", "c") - .arg("-d") - .arg(format!( - "{}-{:02}-{:02} {:02}:{:02}:{:02}", - dt.year(), - dt.month(), - dt.day(), - dt.hour(), - dt.minute(), - dt.second() - )) - .arg(format!("+{}", required_format)) - .output() - .unwrap(); - - let date_command_str = String::from_utf8(output.stdout).unwrap(); - let date = NaiveDate::from_ymd_opt(dt.year(), dt.month(), dt.day()).unwrap(); - let ldt = Local - .from_local_datetime(&date.and_hms_opt(dt.hour(), dt.minute(), dt.second()).unwrap()) - .unwrap(); - let formatted_date = format!("{}\n", ldt.format(required_format)); - assert_eq!(date_command_str, formatted_date); -} - -#[test] -#[cfg(target_os = "linux")] -fn try_verify_against_date_command_format() { - if !path::Path::new(DATE_PATH).exists() { - eprintln!("date command {:?} not found, skipping", DATE_PATH); - return; - } - assert_run_date_version(); - - let mut date = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_hms_opt(12, 11, 13).unwrap(); - while date.year() < 2008 { - verify_against_date_command_format_local(DATE_PATH, date); - date = date + Days::new(55); - } -} diff --git a/patch/chrono-0.4.41/tests/wasm.rs b/patch/chrono-0.4.41/tests/wasm.rs deleted file mode 100644 index ceb9b3d..0000000 --- 
a/patch/chrono-0.4.41/tests/wasm.rs +++ /dev/null @@ -1,89 +0,0 @@ -//! Run this test with: -//! `env TZ="$(date +%z)" NOW="$(date +%s)" wasm-pack test --node -- --features wasmbind` -//! -//! The `TZ` and `NOW` variables are used to compare the results inside the WASM environment with -//! the host system. -//! The check will fail if the local timezone does not match one of the timezones defined below. - -#![cfg(all( - target_arch = "wasm32", - feature = "wasmbind", - feature = "clock", - not(any(target_os = "emscripten", target_os = "wasi")) -))] - -use chrono::prelude::*; -use wasm_bindgen_test::*; - -#[wasm_bindgen_test] -fn now() { - let utc: DateTime = Utc::now(); - let local: DateTime = Local::now(); - - // Ensure time set by the test script is correct - let now = env!("NOW"); - let actual = NaiveDateTime::parse_from_str(&now, "%s").unwrap().and_utc(); - let diff = utc - actual; - assert!( - diff < chrono::TimeDelta::try_minutes(5).unwrap(), - "expected {} - {} == {} < 5m (env var: {})", - utc, - actual, - diff, - now, - ); - - let tz = env!("TZ"); - eprintln!("testing with tz={}", tz); - - // Ensure offset retrieved when getting local time is correct - let expected_offset = match tz { - "ACST-9:30" => FixedOffset::east_opt(19 * 30 * 60).unwrap(), - "Asia/Katmandu" => FixedOffset::east_opt(23 * 15 * 60).unwrap(), // No DST thankfully - "EDT" | "EST4" | "-0400" => FixedOffset::east_opt(-4 * 60 * 60).unwrap(), - "EST" | "-0500" => FixedOffset::east_opt(-5 * 60 * 60).unwrap(), - "UTC0" | "+0000" => FixedOffset::east_opt(0).unwrap(), - tz => panic!("unexpected TZ {}", tz), - }; - assert_eq!( - &expected_offset, - local.offset(), - "expected: {:?} local: {:?}", - expected_offset, - local.offset(), - ); -} - -#[wasm_bindgen_test] -fn from_is_exact() { - let now = js_sys::Date::new_0(); - - let dt = DateTime::::from(now.clone()); - - assert_eq!(now.get_time() as i64, dt.timestamp_millis()); -} - -#[wasm_bindgen_test] -fn local_from_local_datetime() { - let now = 
Local::now(); - let ndt = now.naive_local(); - let res = match Local.from_local_datetime(&ndt).single() { - Some(v) => v, - None => panic! {"Required for test!"}, - }; - assert_eq!(now, res); -} - -#[wasm_bindgen_test] -fn convert_all_parts_with_milliseconds() { - let time: DateTime = "2020-12-01T03:01:55.974Z".parse().unwrap(); - let js_date = js_sys::Date::from(time); - - assert_eq!(js_date.get_utc_full_year(), 2020); - assert_eq!(js_date.get_utc_month(), 11); // months are numbered 0..=11 - assert_eq!(js_date.get_utc_date(), 1); - assert_eq!(js_date.get_utc_hours(), 3); - assert_eq!(js_date.get_utc_minutes(), 1); - assert_eq!(js_date.get_utc_seconds(), 55); - assert_eq!(js_date.get_utc_milliseconds(), 974); -} diff --git a/patch/chrono-0.4.41/tests/win_bindings.rs b/patch/chrono-0.4.41/tests/win_bindings.rs deleted file mode 100644 index 2e28157..0000000 --- a/patch/chrono-0.4.41/tests/win_bindings.rs +++ /dev/null @@ -1,28 +0,0 @@ -#![cfg(all(windows, feature = "clock", feature = "std"))] - -use std::fs; -use windows_bindgen::bindgen; - -#[test] -fn gen_bindings() { - let input = "src/offset/local/win_bindings.txt"; - let output = "src/offset/local/win_bindings.rs"; - let existing = fs::read_to_string(output).unwrap(); - - bindgen(["--no-deps", "--etc", input]).unwrap(); - - // Check the output is the same as before. - // Depending on the git configuration the file may have been checked out with `\r\n` newlines or - // with `\n`. Compare line-by-line to ignore this difference. 
- let mut new = fs::read_to_string(output).unwrap(); - if existing.contains("\r\n") && !new.contains("\r\n") { - new = new.replace("\n", "\r\n"); - } else if !existing.contains("\r\n") && new.contains("\r\n") { - new = new.replace("\r\n", "\n"); - } - - similar_asserts::assert_eq!(existing, new); - if !new.lines().eq(existing.lines()) { - panic!("generated file `{output}` is changed."); - } -} diff --git a/patch/chrono-0.4.41/CITATION.cff b/patch/chrono-0.4.42/CITATION.cff similarity index 100% rename from patch/chrono-0.4.41/CITATION.cff rename to patch/chrono-0.4.42/CITATION.cff diff --git a/patch/chrono-0.4.41/Cargo.toml b/patch/chrono-0.4.42/Cargo.toml similarity index 88% rename from patch/chrono-0.4.41/Cargo.toml rename to patch/chrono-0.4.42/Cargo.toml index ed4b65b..507b61a 100644 --- a/patch/chrono-0.4.41/Cargo.toml +++ b/patch/chrono-0.4.42/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chrono" -version = "0.4.41" +version = "0.4.42" description = "Date and time library for Rust" homepage = "https://github.com/chronotope/chrono" documentation = "https://docs.rs/chrono/" @@ -25,14 +25,16 @@ winapi = ["windows-link"] std = ["alloc"] clock = ["winapi", "iana-time-zone", "now"] now = ["std"] +core-error = [] oldtime = [] wasmbind = ["wasm-bindgen", "js-sys"] unstable-locales = ["pure-rust-locales"] # Note that rkyv-16, rkyv-32, and rkyv-64 are mutually exclusive. 
+rkyv = ["dep:rkyv", "rkyv/pointer_width_32"] rkyv-16 = ["dep:rkyv", "rkyv?/pointer_width_16"] rkyv-32 = ["dep:rkyv", "rkyv?/pointer_width_32"] rkyv-64 = ["dep:rkyv", "rkyv?/pointer_width_64"] -rkyv-validation = ["rkyv?/bytecheck"] +rkyv-validation = ["rkyv?/validation"] # Features for internal use only: __internal_bench = [] @@ -40,7 +42,7 @@ __internal_bench = [] num-traits = { version = "0.2", default-features = false } serde = { version = "1.0.99", default-features = false, optional = true } pure-rust-locales = { version = "0.8", optional = true } -rkyv = { version = "0.8.10", optional = true, default-features = false, features = ["std"]} +rkyv = { version = "0.8", optional = true, default-features = false } arbitrary = { version = "1.0.0", features = ["derive"], optional = true } [target.'cfg(all(target_arch = "wasm32", not(any(target_os = "emscripten", target_os = "wasi"))))'.dependencies] @@ -48,10 +50,10 @@ wasm-bindgen = { version = "0.2", optional = true } js-sys = { version = "0.3", optional = true } # contains FFI bindings for the JS Date API [target.'cfg(windows)'.dependencies] -windows-link = { version = "0.1", optional = true } +windows-link = { version = "0.2", optional = true } [target.'cfg(windows)'.dev-dependencies] -windows-bindgen = { version = "0.62" } # MSRV is 1.74 +windows-bindgen = { version = "0.63" } # MSRV is 1.74 [target.'cfg(unix)'.dependencies] iana-time-zone = { version = "0.1.45", optional = true, features = ["fallback"] } diff --git a/patch/chrono-0.4.41/LICENSE.txt b/patch/chrono-0.4.42/LICENSE.txt similarity index 100% rename from patch/chrono-0.4.41/LICENSE.txt rename to patch/chrono-0.4.42/LICENSE.txt diff --git a/patch/chrono-0.4.41/README.md b/patch/chrono-0.4.42/README.md similarity index 100% rename from patch/chrono-0.4.41/README.md rename to patch/chrono-0.4.42/README.md diff --git a/patch/chrono-0.4.41/src/date.rs b/patch/chrono-0.4.42/src/date.rs similarity index 99% rename from patch/chrono-0.4.41/src/date.rs rename 
to patch/chrono-0.4.42/src/date.rs index ec05b27..a66882c 100644 --- a/patch/chrono-0.4.41/src/date.rs +++ b/patch/chrono-0.4.42/src/date.rs @@ -10,8 +10,8 @@ use core::cmp::Ordering; use core::ops::{Add, AddAssign, Sub, SubAssign}; use core::{fmt, hash}; -// #[cfg(feature = "rkyv")] -// use rkyv::{Archive, Deserialize, Serialize}; +#[cfg(feature = "rkyv")] +use rkyv::{Archive, Deserialize, Serialize}; #[cfg(all(feature = "unstable-locales", feature = "alloc"))] use crate::format::Locale; @@ -54,7 +54,7 @@ use crate::{DateTime, Datelike, TimeDelta, Weekday}; /// even though the raw calculation between `NaiveDate` and `TimeDelta` may not. #[deprecated(since = "0.4.23", note = "Use `NaiveDate` or `DateTime` instead")] #[derive(Clone)] -// #[cfg_attr(feature = "rkyv", derive(Archive, Deserialize, Serialize))] +#[cfg_attr(feature = "rkyv", derive(Archive, Deserialize, Serialize))] pub struct Date { date: NaiveDate, offset: Tz::Offset, diff --git a/patch/chrono-0.4.41/src/datetime/mod.rs b/patch/chrono-0.4.42/src/datetime/mod.rs similarity index 96% rename from patch/chrono-0.4.41/src/datetime/mod.rs rename to patch/chrono-0.4.42/src/datetime/mod.rs index 8647393..023bede 100644 --- a/patch/chrono-0.4.41/src/datetime/mod.rs +++ b/patch/chrono-0.4.42/src/datetime/mod.rs @@ -31,7 +31,7 @@ use crate::offset::{FixedOffset, LocalResult, Offset, TimeZone, Utc}; use crate::{Datelike, Months, TimeDelta, Timelike, Weekday}; use crate::{expect, try_opt}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; /// documented at re-export site @@ -48,7 +48,7 @@ mod tests; /// [`TimeZone`](./offset/trait.TimeZone.html) implementations. 
#[derive(Clone)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)) )] @@ -713,6 +713,61 @@ impl DateTime { } impl DateTime { + /// Makes a new `DateTime` from the number of non-leap seconds + /// since January 1, 1970 0:00:00 UTC (aka "UNIX timestamp"). + /// + /// This is a convenience wrapper around [`DateTime::from_timestamp`], + /// which is useful in functions like [`Iterator::map`] to avoid a closure. + /// + /// This is guaranteed to round-trip with regard to [`timestamp`](DateTime::timestamp). + /// + /// If you need to create a `DateTime` with a [`TimeZone`] different from [`Utc`], use + /// [`TimeZone::timestamp_opt`] or [`DateTime::with_timezone`]; if you need to create a + /// `DateTime` with more precision, use [`DateTime::from_timestamp_micros`], + /// [`DateTime::from_timestamp_millis`], or [`DateTime::from_timestamp_nanos`]. + /// + /// # Errors + /// + /// Returns `None` on out-of-range number of seconds, + /// otherwise returns `Some(DateTime {...})`. 
+ /// + /// # Examples + /// + /// Using [`Option::and_then`]: + /// + /// ``` + /// # use chrono::DateTime; + /// let maybe_timestamp: Option = Some(1431648000); + /// let maybe_dt = maybe_timestamp.and_then(DateTime::from_timestamp_secs); + /// + /// assert!(maybe_dt.is_some()); + /// assert_eq!(maybe_dt.unwrap().to_string(), "2015-05-15 00:00:00 UTC"); + /// ``` + /// + /// Using [`Iterator::map`]: + /// + /// ``` + /// # use chrono::{DateTime, Utc}; + /// let v = vec![i64::MIN, 1_000_000_000, 1_234_567_890, i64::MAX]; + /// let timestamps: Vec>> = v + /// .into_iter() + /// .map(DateTime::from_timestamp_secs) + /// .collect(); + /// + /// assert_eq!(vec![ + /// None, + /// Some(DateTime::parse_from_rfc3339("2001-09-09 01:46:40Z").unwrap().to_utc()), + /// Some(DateTime::parse_from_rfc3339("2009-02-13 23:31:30Z").unwrap().to_utc()), + /// None, + /// ], timestamps); + /// ``` + /// + #[inline] + #[must_use] + pub const fn from_timestamp_secs(secs: i64) -> Option { + Self::from_timestamp(secs, 0) + } + /// Makes a new `DateTime` from the number of non-leap seconds /// since January 1, 1970 0:00:00 UTC (aka "UNIX timestamp") /// and the number of nanoseconds since the last whole non-leap second. @@ -1930,12 +1985,12 @@ where } } -/// Number of days between Januari 1, 1970 and December 31, 1 BCE which we define to be day 0. +/// Number of days between January 1, 1970 and December 31, 1 BCE which we define to be day 0. 
/// 4 full leap year cycles until December 31, 1600 4 * 146097 = 584388 /// 1 day until January 1, 1601 1 -/// 369 years until Januari 1, 1970 369 * 365 = 134685 +/// 369 years until January 1, 1970 369 * 365 = 134685 /// of which floor(369 / 4) are leap years floor(369 / 4) = 92 /// except for 1700, 1800 and 1900 -3 + /// -------- /// 719163 -const UNIX_EPOCH_DAY: i64 = 719_163; +pub(crate) const UNIX_EPOCH_DAY: i64 = 719_163; diff --git a/patch/chrono-0.4.41/src/datetime/serde.rs b/patch/chrono-0.4.42/src/datetime/serde.rs similarity index 98% rename from patch/chrono-0.4.41/src/datetime/serde.rs rename to patch/chrono-0.4.42/src/datetime/serde.rs index 6e009a3..b08e4e1 100644 --- a/patch/chrono-0.4.41/src/datetime/serde.rs +++ b/patch/chrono-0.4.42/src/datetime/serde.rs @@ -50,8 +50,8 @@ impl ser::Serialize for DateTime { } } -#[doc(hidden)] -#[derive(Debug)] +#[allow(missing_docs)] +#[allow(missing_debug_implementations)] pub struct DateTimeVisitor; impl de::Visitor<'_> for DateTimeVisitor { @@ -244,11 +244,7 @@ pub mod ts_nanoseconds { where E: de::Error, { - DateTime::from_timestamp( - value.div_euclid(1_000_000_000), - (value.rem_euclid(1_000_000_000)) as u32, - ) - .ok_or_else(|| invalid_ts(value)) + Ok(DateTime::from_timestamp_nanos(value)) } /// Deserialize a timestamp in nanoseconds since the epoch @@ -526,11 +522,7 @@ pub mod ts_microseconds { where E: de::Error, { - DateTime::from_timestamp( - value.div_euclid(1_000_000), - (value.rem_euclid(1_000_000) * 1000) as u32, - ) - .ok_or_else(|| invalid_ts(value)) + DateTime::from_timestamp_micros(value).ok_or_else(|| invalid_ts(value)) } /// Deserialize a timestamp in milliseconds since the epoch @@ -1066,7 +1058,7 @@ pub mod ts_seconds { where E: de::Error, { - DateTime::from_timestamp(value, 0).ok_or_else(|| invalid_ts(value)) + DateTime::from_timestamp_secs(value).ok_or_else(|| invalid_ts(value)) } /// Deserialize a timestamp in seconds since the epoch @@ -1077,7 +1069,7 @@ pub mod ts_seconds { if value > 
i64::MAX as u64 { Err(invalid_ts(value)) } else { - DateTime::from_timestamp(value as i64, 0).ok_or_else(|| invalid_ts(value)) + DateTime::from_timestamp_secs(value as i64).ok_or_else(|| invalid_ts(value)) } } } diff --git a/patch/chrono-0.4.41/src/datetime/tests.rs b/patch/chrono-0.4.42/src/datetime/tests.rs similarity index 98% rename from patch/chrono-0.4.41/src/datetime/tests.rs rename to patch/chrono-0.4.42/src/datetime/tests.rs index f96d46b..67bee10 100644 --- a/patch/chrono-0.4.41/src/datetime/tests.rs +++ b/patch/chrono-0.4.42/src/datetime/tests.rs @@ -154,7 +154,10 @@ fn test_datetime_from_timestamp_millis() { // that of `from_timestamp_opt`. let secs_test = [0, 1, 2, 1000, 1234, 12345678, -1, -2, -1000, -12345678]; for secs in secs_test.iter().cloned() { - assert_eq!(DateTime::from_timestamp_millis(secs * 1000), DateTime::from_timestamp(secs, 0)); + assert_eq!( + DateTime::from_timestamp_millis(secs * 1000), + DateTime::from_timestamp_secs(secs) + ); } } @@ -191,7 +194,7 @@ fn test_datetime_from_timestamp_micros() { for secs in secs_test.iter().copied() { assert_eq!( DateTime::from_timestamp_micros(secs * 1_000_000), - DateTime::from_timestamp(secs, 0) + DateTime::from_timestamp_secs(secs) ); } } @@ -242,24 +245,34 @@ fn test_datetime_from_timestamp_nanos() { for secs in secs_test.iter().copied() { assert_eq!( Some(DateTime::from_timestamp_nanos(secs * 1_000_000_000)), - DateTime::from_timestamp(secs, 0) + DateTime::from_timestamp_secs(secs) ); } } +#[test] +fn test_datetime_from_timestamp_secs() { + let valid = [-2208936075, 0, 119731017, 1234567890, 2034061609]; + + for timestamp_secs in valid.iter().copied() { + let datetime = DateTime::from_timestamp_secs(timestamp_secs).unwrap(); + assert_eq!(timestamp_secs, datetime.timestamp()); + assert_eq!(DateTime::from_timestamp(timestamp_secs, 0).unwrap(), datetime); + } +} + #[test] fn test_datetime_from_timestamp() { - let from_timestamp = |secs| DateTime::from_timestamp(secs, 0); let ymdhms = |y, m, d, h, 
n, s| { NaiveDate::from_ymd_opt(y, m, d).unwrap().and_hms_opt(h, n, s).unwrap().and_utc() }; - assert_eq!(from_timestamp(-1), Some(ymdhms(1969, 12, 31, 23, 59, 59))); - assert_eq!(from_timestamp(0), Some(ymdhms(1970, 1, 1, 0, 0, 0))); - assert_eq!(from_timestamp(1), Some(ymdhms(1970, 1, 1, 0, 0, 1))); - assert_eq!(from_timestamp(1_000_000_000), Some(ymdhms(2001, 9, 9, 1, 46, 40))); - assert_eq!(from_timestamp(0x7fffffff), Some(ymdhms(2038, 1, 19, 3, 14, 7))); - assert_eq!(from_timestamp(i64::MIN), None); - assert_eq!(from_timestamp(i64::MAX), None); + assert_eq!(DateTime::from_timestamp_secs(-1), Some(ymdhms(1969, 12, 31, 23, 59, 59))); + assert_eq!(DateTime::from_timestamp_secs(0), Some(ymdhms(1970, 1, 1, 0, 0, 0))); + assert_eq!(DateTime::from_timestamp_secs(1), Some(ymdhms(1970, 1, 1, 0, 0, 1))); + assert_eq!(DateTime::from_timestamp_secs(1_000_000_000), Some(ymdhms(2001, 9, 9, 1, 46, 40))); + assert_eq!(DateTime::from_timestamp_secs(0x7fffffff), Some(ymdhms(2038, 1, 19, 3, 14, 7))); + assert_eq!(DateTime::from_timestamp_secs(i64::MIN), None); + assert_eq!(DateTime::from_timestamp_secs(i64::MAX), None); } #[test] @@ -1034,7 +1047,7 @@ fn test_parse_datetime_utc() { Ok(d) => d, Err(e) => panic!("parsing `{s}` has failed: {e}"), }; - let s_ = format!("{:?}", d); + let s_ = format!("{d:?}"); // `s` and `s_` may differ, but `s.parse()` and `s_.parse()` must be same let d_ = match s_.parse::>() { Ok(d) => d, diff --git a/patch/chrono-0.4.41/src/format/formatting.rs b/patch/chrono-0.4.42/src/format/formatting.rs similarity index 97% rename from patch/chrono-0.4.41/src/format/formatting.rs rename to patch/chrono-0.4.42/src/format/formatting.rs index 3b37a15..79ed694 100644 --- a/patch/chrono-0.4.41/src/format/formatting.rs +++ b/patch/chrono-0.4.42/src/format/formatting.rs @@ -109,7 +109,7 @@ impl<'a, I: Iterator + Clone, B: Borrow>> DelayedFormat { /// let mut buffer = String::new(); /// let _ = df.write_to(&mut buffer); /// ``` - pub fn write_to(&self, w: &mut impl 
Write) -> fmt::Result { + pub fn write_to(&self, w: &mut (impl Write + ?Sized)) -> fmt::Result { for item in self.items.clone() { match *item.borrow() { Item::Literal(s) | Item::Space(s) => w.write_str(s), @@ -124,14 +124,19 @@ impl<'a, I: Iterator + Clone, B: Borrow>> DelayedFormat { } #[cfg(feature = "alloc")] - fn format_numeric(&self, w: &mut impl Write, spec: &Numeric, pad: Pad) -> fmt::Result { + fn format_numeric( + &self, + w: &mut (impl Write + ?Sized), + spec: &Numeric, + pad: Pad, + ) -> fmt::Result { use self::Numeric::*; - fn write_one(w: &mut impl Write, v: u8) -> fmt::Result { + fn write_one(w: &mut (impl Write + ?Sized), v: u8) -> fmt::Result { w.write_char((b'0' + v) as char) } - fn write_two(w: &mut impl Write, v: u8, pad: Pad) -> fmt::Result { + fn write_two(w: &mut (impl Write + ?Sized), v: u8, pad: Pad) -> fmt::Result { let ones = b'0' + v % 10; match (v / 10, pad) { (0, Pad::None) => {} @@ -142,7 +147,7 @@ impl<'a, I: Iterator + Clone, B: Borrow>> DelayedFormat { } #[inline] - fn write_year(w: &mut impl Write, year: i32, pad: Pad) -> fmt::Result { + fn write_year(w: &mut (impl Write + ?Sized), year: i32, pad: Pad) -> fmt::Result { if (1000..=9999).contains(&year) { // fast path write_hundreds(w, (year / 100) as u8)?; @@ -153,7 +158,7 @@ impl<'a, I: Iterator + Clone, B: Borrow>> DelayedFormat { } fn write_n( - w: &mut impl Write, + w: &mut (impl Write + ?Sized), n: usize, v: i64, pad: Pad, @@ -214,7 +219,7 @@ impl<'a, I: Iterator + Clone, B: Borrow>> DelayedFormat { } #[cfg(feature = "alloc")] - fn format_fixed(&self, w: &mut impl Write, spec: &Fixed) -> fmt::Result { + fn format_fixed(&self, w: &mut (impl Write + ?Sized), spec: &Fixed) -> fmt::Result { use Fixed::*; use InternalInternal::*; @@ -387,7 +392,7 @@ pub fn format_item( #[cfg(any(feature = "alloc", feature = "serde"))] impl OffsetFormat { /// Writes an offset from UTC with the format defined by `self`. 
- fn format(&self, w: &mut impl Write, off: FixedOffset) -> fmt::Result { + fn format(&self, w: &mut (impl Write + ?Sized), off: FixedOffset) -> fmt::Result { let off = off.local_minus_utc(); if self.allow_zulu && off == 0 { w.write_char('Z')?; @@ -495,8 +500,8 @@ pub enum SecondsFormat { /// Writes the date, time and offset to the string. same as `%Y-%m-%dT%H:%M:%S%.f%:z` #[inline] #[cfg(any(feature = "alloc", feature = "serde"))] -pub(crate) fn write_rfc3339( - w: &mut impl Write, +pub fn write_rfc3339( + w: &mut (impl Write + ?Sized), dt: NaiveDateTime, off: FixedOffset, secform: SecondsFormat, @@ -560,7 +565,7 @@ pub(crate) fn write_rfc3339( #[cfg(feature = "alloc")] /// write datetimes like `Tue, 1 Jul 2003 10:52:37 +0200`, same as `%a, %d %b %Y %H:%M:%S %z` pub(crate) fn write_rfc2822( - w: &mut impl Write, + w: &mut (impl Write + ?Sized), dt: NaiveDateTime, off: FixedOffset, ) -> fmt::Result { @@ -605,7 +610,7 @@ pub(crate) fn write_rfc2822( } /// Equivalent to `{:02}` formatting for n < 100. 
-pub(crate) fn write_hundreds(w: &mut impl Write, n: u8) -> fmt::Result { +pub(crate) fn write_hundreds(w: &mut (impl Write + ?Sized), n: u8) -> fmt::Result { if n >= 100 { return Err(fmt::Error); } diff --git a/patch/chrono-0.4.41/src/format/locales.rs b/patch/chrono-0.4.42/src/format/locales.rs similarity index 100% rename from patch/chrono-0.4.41/src/format/locales.rs rename to patch/chrono-0.4.42/src/format/locales.rs diff --git a/patch/chrono-0.4.41/src/format/mod.rs b/patch/chrono-0.4.42/src/format/mod.rs similarity index 99% rename from patch/chrono-0.4.41/src/format/mod.rs rename to patch/chrono-0.4.42/src/format/mod.rs index 241be7a..f90314a 100644 --- a/patch/chrono-0.4.41/src/format/mod.rs +++ b/patch/chrono-0.4.42/src/format/mod.rs @@ -33,6 +33,8 @@ #[cfg(all(feature = "alloc", not(feature = "std"), not(test)))] use alloc::boxed::Box; +#[cfg(all(feature = "core-error", not(feature = "std")))] +use core::error::Error; use core::fmt; use core::str::FromStr; #[cfg(feature = "std")] @@ -59,7 +61,7 @@ pub(crate) use formatting::write_hundreds; #[cfg(feature = "alloc")] pub(crate) use formatting::write_rfc2822; #[cfg(any(feature = "alloc", feature = "serde"))] -pub(crate) use formatting::write_rfc3339; +pub use formatting::write_rfc3339; #[cfg(feature = "alloc")] #[allow(deprecated)] pub use formatting::{DelayedFormat, format, format_item}; @@ -450,7 +452,7 @@ impl fmt::Display for ParseError { } } -#[cfg(feature = "std")] +#[cfg(any(feature = "core-error", feature = "std"))] impl Error for ParseError { #[allow(deprecated)] fn description(&self) -> &str { diff --git a/patch/chrono-0.4.41/src/format/parse.rs b/patch/chrono-0.4.42/src/format/parse.rs similarity index 99% rename from patch/chrono-0.4.41/src/format/parse.rs rename to patch/chrono-0.4.42/src/format/parse.rs index b982617..2e3b298 100644 --- a/patch/chrono-0.4.41/src/format/parse.rs +++ b/patch/chrono-0.4.42/src/format/parse.rs @@ -1878,7 +1878,7 @@ mod tests { if dt != checkdate { // check for 
expected result panic!( - "Date conversion failed for {date}\nReceived: {dt:?}\nExpected: checkdate{:?}" + "Date conversion failed for {date}\nReceived: {dt:?}\nExpected: {checkdate:?}" ); } } diff --git a/patch/chrono-0.4.41/src/format/parsed.rs b/patch/chrono-0.4.42/src/format/parsed.rs similarity index 99% rename from patch/chrono-0.4.41/src/format/parsed.rs rename to patch/chrono-0.4.42/src/format/parsed.rs index fd5008f..8524daa 100644 --- a/patch/chrono-0.4.41/src/format/parsed.rs +++ b/patch/chrono-0.4.42/src/format/parsed.rs @@ -832,7 +832,7 @@ impl Parsed { // reconstruct date and time fields from timestamp let ts = timestamp.checked_add(i64::from(offset)).ok_or(OUT_OF_RANGE)?; - let mut datetime = DateTime::from_timestamp(ts, 0).ok_or(OUT_OF_RANGE)?.naive_utc(); + let mut datetime = DateTime::from_timestamp_secs(ts).ok_or(OUT_OF_RANGE)?.naive_utc(); // fill year, ordinal, hour, minute and second fields from timestamp. // if existing fields are consistent, this will allow the full date/time reconstruction. 
diff --git a/patch/chrono-0.4.41/src/format/scan.rs b/patch/chrono-0.4.42/src/format/scan.rs similarity index 100% rename from patch/chrono-0.4.41/src/format/scan.rs rename to patch/chrono-0.4.42/src/format/scan.rs diff --git a/patch/chrono-0.4.41/src/format/strftime.rs b/patch/chrono-0.4.42/src/format/strftime.rs similarity index 79% rename from patch/chrono-0.4.41/src/format/strftime.rs rename to patch/chrono-0.4.42/src/format/strftime.rs index 1449eee..5dc1180 100644 --- a/patch/chrono-0.4.41/src/format/strftime.rs +++ b/patch/chrono-0.4.42/src/format/strftime.rs @@ -253,8 +253,7 @@ impl<'a> StrftimeItems<'a> { /// const ITEMS: &[Item<'static>] = &[ /// Item::Numeric(Numeric::Year, Pad::Zero), /// Item::Literal("-"), - /// Item::Literal("%"), - /// Item::Literal("Q"), + /// Item::Literal("%Q"), /// ]; /// println!("{:?}", strftime_parser.clone().collect::>()); /// assert!(strftime_parser.eq(ITEMS.iter().cloned())); @@ -425,9 +424,247 @@ impl<'a> StrftimeItems<'a> { }) .collect() } -} -const HAVE_ALTERNATES: &str = "z"; + fn parse_next_item(&mut self, mut remainder: &'a str) -> Option<(&'a str, Item<'a>)> { + use InternalInternal::*; + use Item::{Literal, Space}; + use Numeric::*; + + let (original, mut remainder) = match remainder.chars().next()? 
{ + // the next item is a specifier + '%' => (remainder, &remainder[1..]), + + // the next item is space + c if c.is_whitespace() => { + // `%` is not a whitespace, so `c != '%'` is redundant + let nextspec = + remainder.find(|c: char| !c.is_whitespace()).unwrap_or(remainder.len()); + assert!(nextspec > 0); + let item = Space(&remainder[..nextspec]); + remainder = &remainder[nextspec..]; + return Some((remainder, item)); + } + + // the next item is literal + _ => { + let nextspec = remainder + .find(|c: char| c.is_whitespace() || c == '%') + .unwrap_or(remainder.len()); + assert!(nextspec > 0); + let item = Literal(&remainder[..nextspec]); + remainder = &remainder[nextspec..]; + return Some((remainder, item)); + } + }; + + macro_rules! next { + () => { + match remainder.chars().next() { + Some(x) => { + remainder = &remainder[x.len_utf8()..]; + x + } + None => return Some((remainder, self.error(original, remainder))), // premature end of string + } + }; + } + + let spec = next!(); + let pad_override = match spec { + '-' => Some(Pad::None), + '0' => Some(Pad::Zero), + '_' => Some(Pad::Space), + _ => None, + }; + + let is_alternate = spec == '#'; + let spec = if pad_override.is_some() || is_alternate { next!() } else { spec }; + if is_alternate && !HAVE_ALTERNATES.contains(spec) { + return Some((remainder, self.error(original, remainder))); + } + + macro_rules! queue { + [$head:expr, $($tail:expr),+ $(,)*] => ({ + const QUEUE: &'static [Item<'static>] = &[$($tail),+]; + self.queue = QUEUE; + $head + }) + } + + #[cfg(not(feature = "unstable-locales"))] + macro_rules! 
queue_from_slice { + ($slice:expr) => {{ + self.queue = &$slice[1..]; + $slice[0].clone() + }}; + } + + let item = match spec { + 'A' => fixed(Fixed::LongWeekdayName), + 'B' => fixed(Fixed::LongMonthName), + 'C' => num0(YearDiv100), + 'D' => { + queue![num0(Month), Literal("/"), num0(Day), Literal("/"), num0(YearMod100)] + } + 'F' => queue![num0(Year), Literal("-"), num0(Month), Literal("-"), num0(Day)], + 'G' => num0(IsoYear), + 'H' => num0(Hour), + 'I' => num0(Hour12), + 'M' => num0(Minute), + 'P' => fixed(Fixed::LowerAmPm), + 'R' => queue![num0(Hour), Literal(":"), num0(Minute)], + 'S' => num0(Second), + 'T' => { + queue![num0(Hour), Literal(":"), num0(Minute), Literal(":"), num0(Second)] + } + 'U' => num0(WeekFromSun), + 'V' => num0(IsoWeek), + 'W' => num0(WeekFromMon), + #[cfg(not(feature = "unstable-locales"))] + 'X' => queue_from_slice!(T_FMT), + #[cfg(feature = "unstable-locales")] + 'X' => self.switch_to_locale_str(locales::t_fmt, T_FMT), + 'Y' => num0(Year), + 'Z' => fixed(Fixed::TimezoneName), + 'a' => fixed(Fixed::ShortWeekdayName), + 'b' | 'h' => fixed(Fixed::ShortMonthName), + #[cfg(not(feature = "unstable-locales"))] + 'c' => queue_from_slice!(D_T_FMT), + #[cfg(feature = "unstable-locales")] + 'c' => self.switch_to_locale_str(locales::d_t_fmt, D_T_FMT), + 'd' => num0(Day), + 'e' => nums(Day), + 'f' => num0(Nanosecond), + 'g' => num0(IsoYearMod100), + 'j' => num0(Ordinal), + 'k' => nums(Hour), + 'l' => nums(Hour12), + 'm' => num0(Month), + 'n' => Space("\n"), + 'p' => fixed(Fixed::UpperAmPm), + 'q' => num(Quarter), + #[cfg(not(feature = "unstable-locales"))] + 'r' => queue_from_slice!(T_FMT_AMPM), + #[cfg(feature = "unstable-locales")] + 'r' => { + if self.locale.is_some() && locales::t_fmt_ampm(self.locale.unwrap()).is_empty() { + // 12-hour clock not supported by this locale. Switch to 24-hour format. 
+ self.switch_to_locale_str(locales::t_fmt, T_FMT) + } else { + self.switch_to_locale_str(locales::t_fmt_ampm, T_FMT_AMPM) + } + } + 's' => num(Timestamp), + 't' => Space("\t"), + 'u' => num(WeekdayFromMon), + 'v' => { + queue![ + nums(Day), + Literal("-"), + fixed(Fixed::ShortMonthName), + Literal("-"), + num0(Year) + ] + } + 'w' => num(NumDaysFromSun), + #[cfg(not(feature = "unstable-locales"))] + 'x' => queue_from_slice!(D_FMT), + #[cfg(feature = "unstable-locales")] + 'x' => self.switch_to_locale_str(locales::d_fmt, D_FMT), + 'y' => num0(YearMod100), + 'z' => { + if is_alternate { + internal_fixed(TimezoneOffsetPermissive) + } else { + fixed(Fixed::TimezoneOffset) + } + } + '+' => fixed(Fixed::RFC3339), + ':' => { + if remainder.starts_with("::z") { + remainder = &remainder[3..]; + fixed(Fixed::TimezoneOffsetTripleColon) + } else if remainder.starts_with(":z") { + remainder = &remainder[2..]; + fixed(Fixed::TimezoneOffsetDoubleColon) + } else if remainder.starts_with('z') { + remainder = &remainder[1..]; + fixed(Fixed::TimezoneOffsetColon) + } else { + self.error(original, remainder) + } + } + '.' => match next!() { + '3' => match next!() { + 'f' => fixed(Fixed::Nanosecond3), + _ => self.error(original, remainder), + }, + '6' => match next!() { + 'f' => fixed(Fixed::Nanosecond6), + _ => self.error(original, remainder), + }, + '9' => match next!() { + 'f' => fixed(Fixed::Nanosecond9), + _ => self.error(original, remainder), + }, + 'f' => fixed(Fixed::Nanosecond), + _ => self.error(original, remainder), + }, + '3' => match next!() { + 'f' => internal_fixed(Nanosecond3NoDot), + _ => self.error(original, remainder), + }, + '6' => match next!() { + 'f' => internal_fixed(Nanosecond6NoDot), + _ => self.error(original, remainder), + }, + '9' => match next!() { + 'f' => internal_fixed(Nanosecond9NoDot), + _ => self.error(original, remainder), + }, + '%' => Literal("%"), + _ => self.error(original, remainder), + }; + + // Adjust `item` if we have any padding modifier. 
+ // Not allowed on non-numeric items or on specifiers composed out of multiple + // formatting items. + if let Some(new_pad) = pad_override { + match item { + Item::Numeric(ref kind, _pad) if self.queue.is_empty() => { + Some((remainder, Item::Numeric(kind.clone(), new_pad))) + } + _ => Some((remainder, self.error(original, remainder))), + } + } else { + Some((remainder, item)) + } + } + + fn error<'b>(&mut self, original: &'b str, remainder: &'b str) -> Item<'b> { + match self.lenient { + false => Item::Error, + true => Item::Literal(&original[..original.len() - remainder.len()]), + } + } + + #[cfg(feature = "unstable-locales")] + fn switch_to_locale_str( + &mut self, + localized_fmt_str: impl Fn(Locale) -> &'static str, + fallback: &'static [Item<'static>], + ) -> Item<'a> { + if let Some(locale) = self.locale { + assert!(self.locale_str.is_empty()); + let (fmt_str, item) = self.parse_next_item(localized_fmt_str(locale)).unwrap(); + self.locale_str = fmt_str; + item + } else { + self.queue = &fallback[1..]; + fallback[0].clone() + } + } +} impl<'a> Iterator for StrftimeItems<'a> { type Item = Item<'a>; @@ -454,330 +691,46 @@ impl<'a> Iterator for StrftimeItems<'a> { } } -impl<'a> StrftimeItems<'a> { - fn error<'b>( - &mut self, - original: &'b str, - error_len: &mut usize, - ch: Option, - ) -> (&'b str, Item<'b>) { - if !self.lenient { - return (&original[*error_len..], Item::Error); - } +static D_FMT: &[Item<'static>] = &[ + num0(Numeric::Month), + Item::Literal("/"), + num0(Numeric::Day), + Item::Literal("/"), + num0(Numeric::YearMod100), +]; +static D_T_FMT: &[Item<'static>] = &[ + fixed(Fixed::ShortWeekdayName), + Item::Space(" "), + fixed(Fixed::ShortMonthName), + Item::Space(" "), + nums(Numeric::Day), + Item::Space(" "), + num0(Numeric::Hour), + Item::Literal(":"), + num0(Numeric::Minute), + Item::Literal(":"), + num0(Numeric::Second), + Item::Space(" "), + num0(Numeric::Year), +]; +static T_FMT: &[Item<'static>] = &[ + num0(Numeric::Hour), + 
Item::Literal(":"), + num0(Numeric::Minute), + Item::Literal(":"), + num0(Numeric::Second), +]; +static T_FMT_AMPM: &[Item<'static>] = &[ + num0(Numeric::Hour12), + Item::Literal(":"), + num0(Numeric::Minute), + Item::Literal(":"), + num0(Numeric::Second), + Item::Space(" "), + fixed(Fixed::UpperAmPm), +]; - if let Some(c) = ch { - *error_len -= c.len_utf8(); - } - (&original[*error_len..], Item::Literal(&original[..*error_len])) - } - - fn parse_next_item(&mut self, mut remainder: &'a str) -> Option<(&'a str, Item<'a>)> { - use InternalInternal::*; - use Item::{Literal, Space}; - use Numeric::*; - - static D_FMT: &[Item<'static>] = - &[num0(Month), Literal("/"), num0(Day), Literal("/"), num0(YearMod100)]; - static D_T_FMT: &[Item<'static>] = &[ - fixed(Fixed::ShortWeekdayName), - Space(" "), - fixed(Fixed::ShortMonthName), - Space(" "), - nums(Day), - Space(" "), - num0(Hour), - Literal(":"), - num0(Minute), - Literal(":"), - num0(Second), - Space(" "), - num0(Year), - ]; - static T_FMT: &[Item<'static>] = - &[num0(Hour), Literal(":"), num0(Minute), Literal(":"), num0(Second)]; - static T_FMT_AMPM: &[Item<'static>] = &[ - num0(Hour12), - Literal(":"), - num0(Minute), - Literal(":"), - num0(Second), - Space(" "), - fixed(Fixed::UpperAmPm), - ]; - - match remainder.chars().next() { - // we are done - None => None, - - // the next item is a specifier - Some('%') => { - let original = remainder; - remainder = &remainder[1..]; - let mut error_len = 0; - if self.lenient { - error_len += 1; - } - - macro_rules! 
next { - () => { - match remainder.chars().next() { - Some(x) => { - remainder = &remainder[x.len_utf8()..]; - if self.lenient { - error_len += x.len_utf8(); - } - x - } - None => return Some(self.error(original, &mut error_len, None)), // premature end of string - } - }; - } - - let spec = next!(); - let pad_override = match spec { - '-' => Some(Pad::None), - '0' => Some(Pad::Zero), - '_' => Some(Pad::Space), - _ => None, - }; - let is_alternate = spec == '#'; - let spec = if pad_override.is_some() || is_alternate { next!() } else { spec }; - if is_alternate && !HAVE_ALTERNATES.contains(spec) { - return Some(self.error(original, &mut error_len, Some(spec))); - } - - macro_rules! queue { - [$head:expr, $($tail:expr),+ $(,)*] => ({ - const QUEUE: &'static [Item<'static>] = &[$($tail),+]; - self.queue = QUEUE; - $head - }) - } - #[cfg(not(feature = "unstable-locales"))] - macro_rules! queue_from_slice { - ($slice:expr) => {{ - self.queue = &$slice[1..]; - $slice[0].clone() - }}; - } - - let item = match spec { - 'A' => fixed(Fixed::LongWeekdayName), - 'B' => fixed(Fixed::LongMonthName), - 'C' => num0(YearDiv100), - 'D' => { - queue![num0(Month), Literal("/"), num0(Day), Literal("/"), num0(YearMod100)] - } - 'F' => queue![num0(Year), Literal("-"), num0(Month), Literal("-"), num0(Day)], - 'G' => num0(IsoYear), - 'H' => num0(Hour), - 'I' => num0(Hour12), - 'M' => num0(Minute), - 'P' => fixed(Fixed::LowerAmPm), - 'R' => queue![num0(Hour), Literal(":"), num0(Minute)], - 'S' => num0(Second), - 'T' => { - queue![num0(Hour), Literal(":"), num0(Minute), Literal(":"), num0(Second)] - } - 'U' => num0(WeekFromSun), - 'V' => num0(IsoWeek), - 'W' => num0(WeekFromMon), - #[cfg(not(feature = "unstable-locales"))] - 'X' => queue_from_slice!(T_FMT), - #[cfg(feature = "unstable-locales")] - 'X' => self.switch_to_locale_str(locales::t_fmt, T_FMT), - 'Y' => num0(Year), - 'Z' => fixed(Fixed::TimezoneName), - 'a' => fixed(Fixed::ShortWeekdayName), - 'b' | 'h' => 
fixed(Fixed::ShortMonthName), - #[cfg(not(feature = "unstable-locales"))] - 'c' => queue_from_slice!(D_T_FMT), - #[cfg(feature = "unstable-locales")] - 'c' => self.switch_to_locale_str(locales::d_t_fmt, D_T_FMT), - 'd' => num0(Day), - 'e' => nums(Day), - 'f' => num0(Nanosecond), - 'g' => num0(IsoYearMod100), - 'j' => num0(Ordinal), - 'k' => nums(Hour), - 'l' => nums(Hour12), - 'm' => num0(Month), - 'n' => Space("\n"), - 'p' => fixed(Fixed::UpperAmPm), - 'q' => num(Quarter), - #[cfg(not(feature = "unstable-locales"))] - 'r' => queue_from_slice!(T_FMT_AMPM), - #[cfg(feature = "unstable-locales")] - 'r' => { - if self.locale.is_some() - && locales::t_fmt_ampm(self.locale.unwrap()).is_empty() - { - // 12-hour clock not supported by this locale. Switch to 24-hour format. - self.switch_to_locale_str(locales::t_fmt, T_FMT) - } else { - self.switch_to_locale_str(locales::t_fmt_ampm, T_FMT_AMPM) - } - } - 's' => num(Timestamp), - 't' => Space("\t"), - 'u' => num(WeekdayFromMon), - 'v' => { - queue![ - nums(Day), - Literal("-"), - fixed(Fixed::ShortMonthName), - Literal("-"), - num0(Year) - ] - } - 'w' => num(NumDaysFromSun), - #[cfg(not(feature = "unstable-locales"))] - 'x' => queue_from_slice!(D_FMT), - #[cfg(feature = "unstable-locales")] - 'x' => self.switch_to_locale_str(locales::d_fmt, D_FMT), - 'y' => num0(YearMod100), - 'z' => { - if is_alternate { - internal_fixed(TimezoneOffsetPermissive) - } else { - fixed(Fixed::TimezoneOffset) - } - } - '+' => fixed(Fixed::RFC3339), - ':' => { - if remainder.starts_with("::z") { - remainder = &remainder[3..]; - fixed(Fixed::TimezoneOffsetTripleColon) - } else if remainder.starts_with(":z") { - remainder = &remainder[2..]; - fixed(Fixed::TimezoneOffsetDoubleColon) - } else if remainder.starts_with('z') { - remainder = &remainder[1..]; - fixed(Fixed::TimezoneOffsetColon) - } else { - self.error(original, &mut error_len, None).1 - } - } - '.' 
=> match next!() { - '3' => match next!() { - 'f' => fixed(Fixed::Nanosecond3), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '6' => match next!() { - 'f' => fixed(Fixed::Nanosecond6), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '9' => match next!() { - 'f' => fixed(Fixed::Nanosecond9), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - 'f' => fixed(Fixed::Nanosecond), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '3' => match next!() { - 'f' => internal_fixed(Nanosecond3NoDot), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '6' => match next!() { - 'f' => internal_fixed(Nanosecond6NoDot), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '9' => match next!() { - 'f' => internal_fixed(Nanosecond9NoDot), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }, - '%' => Literal("%"), - c => { - let res = self.error(original, &mut error_len, Some(c)); - remainder = res.0; - res.1 - } - }; - - // Adjust `item` if we have any padding modifier. - // Not allowed on non-numeric items or on specifiers composed out of multiple - // formatting items. 
- if let Some(new_pad) = pad_override { - match item { - Item::Numeric(ref kind, _pad) if self.queue.is_empty() => { - Some((remainder, Item::Numeric(kind.clone(), new_pad))) - } - _ => Some(self.error(original, &mut error_len, None)), - } - } else { - Some((remainder, item)) - } - } - - // the next item is space - Some(c) if c.is_whitespace() => { - // `%` is not a whitespace, so `c != '%'` is redundant - let nextspec = - remainder.find(|c: char| !c.is_whitespace()).unwrap_or(remainder.len()); - assert!(nextspec > 0); - let item = Space(&remainder[..nextspec]); - remainder = &remainder[nextspec..]; - Some((remainder, item)) - } - - // the next item is literal - _ => { - let nextspec = remainder - .find(|c: char| c.is_whitespace() || c == '%') - .unwrap_or(remainder.len()); - assert!(nextspec > 0); - let item = Literal(&remainder[..nextspec]); - remainder = &remainder[nextspec..]; - Some((remainder, item)) - } - } - } - - #[cfg(feature = "unstable-locales")] - fn switch_to_locale_str( - &mut self, - localized_fmt_str: impl Fn(Locale) -> &'static str, - fallback: &'static [Item<'static>], - ) -> Item<'a> { - if let Some(locale) = self.locale { - assert!(self.locale_str.is_empty()); - let (fmt_str, item) = self.parse_next_item(localized_fmt_str(locale)).unwrap(); - self.locale_str = fmt_str; - item - } else { - self.queue = &fallback[1..]; - fallback[0].clone() - } - } -} +const HAVE_ALTERNATES: &str = "z"; #[cfg(test)] mod tests { @@ -1246,4 +1199,18 @@ mod tests { "2014-05-07T12:34:56+0000%Q%.2f%%" ); } + + /// Regression test for https://github.com/chronotope/chrono/issues/1725 + #[test] + #[cfg(any(feature = "alloc", feature = "std"))] + fn test_finite() { + let mut i = 0; + for item in StrftimeItems::new("%2f") { + println!("{:?}", item); + i += 1; + if i > 10 { + panic!("infinite loop"); + } + } + } } diff --git a/patch/chrono-0.4.41/src/lib.rs b/patch/chrono-0.4.42/src/lib.rs similarity index 98% rename from patch/chrono-0.4.41/src/lib.rs rename to 
patch/chrono-0.4.42/src/lib.rs index e3c607b..32a1b82 100644 --- a/patch/chrono-0.4.41/src/lib.rs +++ b/patch/chrono-0.4.42/src/lib.rs @@ -380,7 +380,7 @@ //! use chrono::{DateTime, Utc}; //! //! // Construct a datetime from epoch: -//! let dt: DateTime = DateTime::from_timestamp(1_500_000_000, 0).unwrap(); +//! let dt: DateTime = DateTime::from_timestamp_secs(1_500_000_000).unwrap(); //! assert_eq!(dt.to_rfc2822(), "Fri, 14 Jul 2017 02:40:00 +0000"); //! //! // Get epoch value from a datetime: @@ -512,8 +512,8 @@ extern crate alloc; mod time_delta; -#[cfg(feature = "std")] #[doc(no_inline)] +#[cfg(any(feature = "std", feature = "core-error"))] pub use time_delta::OutOfRangeError; pub use time_delta::TimeDelta; @@ -644,7 +644,7 @@ pub mod serde { /// Zero-copy serialization/deserialization with rkyv. /// /// This module re-exports the `Archived*` versions of chrono's types. -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] pub mod rkyv { pub use crate::datetime::ArchivedDateTime; pub use crate::month::ArchivedMonth; @@ -690,6 +690,9 @@ impl fmt::Debug for OutOfRange { #[cfg(feature = "std")] impl std::error::Error for OutOfRange {} +#[cfg(all(not(feature = "std"), feature = "core-error"))] +impl core::error::Error for OutOfRange {} + /// Workaround because `?` is not (yet) available in const context. 
#[macro_export] #[doc(hidden)] diff --git a/patch/chrono-0.4.41/src/month.rs b/patch/chrono-0.4.42/src/month.rs similarity index 98% rename from patch/chrono-0.4.41/src/month.rs rename to patch/chrono-0.4.42/src/month.rs index 22ee719..3d12115 100644 --- a/patch/chrono-0.4.41/src/month.rs +++ b/patch/chrono-0.4.42/src/month.rs @@ -1,6 +1,6 @@ use core::fmt; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; use crate::OutOfRange; @@ -31,7 +31,7 @@ use crate::naive::NaiveDate; // Actual implementation is zero-indexed, API intended as 1-indexed for more intuitive behavior. #[derive(PartialEq, Eq, Copy, Clone, Debug, Hash, PartialOrd, Ord)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) @@ -272,6 +272,9 @@ pub struct ParseMonthError { #[cfg(feature = "std")] impl std::error::Error for ParseMonthError {} +#[cfg(all(not(feature = "std"), feature = "core-error"))] +impl core::error::Error for ParseMonthError {} + impl fmt::Display for ParseMonthError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ParseMonthError {{ .. 
}}") diff --git a/patch/chrono-0.4.41/src/naive/date/mod.rs b/patch/chrono-0.4.42/src/naive/date/mod.rs similarity index 97% rename from patch/chrono-0.4.41/src/naive/date/mod.rs rename to patch/chrono-0.4.42/src/naive/date/mod.rs index bd325c0..a0d410d 100644 --- a/patch/chrono-0.4.41/src/naive/date/mod.rs +++ b/patch/chrono-0.4.42/src/naive/date/mod.rs @@ -20,13 +20,15 @@ use core::num::NonZeroI32; use core::ops::{Add, AddAssign, Sub, SubAssign}; use core::{fmt, str}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; /// L10n locales. #[cfg(all(feature = "unstable-locales", feature = "alloc"))] use pure_rust_locales::Locale; +use super::internals::{Mdf, YearFlags}; +use crate::datetime::UNIX_EPOCH_DAY; #[cfg(feature = "alloc")] use crate::format::DelayedFormat; use crate::format::{ @@ -38,8 +40,6 @@ use crate::naive::{Days, IsoWeek, NaiveDateTime, NaiveTime, NaiveWeek}; use crate::{Datelike, TimeDelta, Weekday}; use crate::{expect, try_opt}; -use super::internals::{Mdf, YearFlags}; - #[cfg(test)] mod tests; @@ -93,7 +93,7 @@ mod tests; /// [proleptic Gregorian date]: crate::NaiveDate#calendar-date #[derive(PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) @@ -384,6 +384,35 @@ impl NaiveDate { NaiveDate::from_ordinal_and_flags(year_div_400 * 400 + year_mod_400 as i32, ordinal, flags) } + /// Makes a new `NaiveDate` from a day's number in the proleptic Gregorian calendar, with + /// January 1, 1970 being day 0. 
+ /// + /// # Errors + /// + /// Returns `None` if the date is out of range. + /// + /// # Example + /// + /// ``` + /// use chrono::NaiveDate; + /// + /// let from_ndays_opt = NaiveDate::from_epoch_days; + /// let from_ymd = |y, m, d| NaiveDate::from_ymd_opt(y, m, d).unwrap(); + /// + /// assert_eq!(from_ndays_opt(-719_162), Some(from_ymd(1, 1, 1))); + /// assert_eq!(from_ndays_opt(1), Some(from_ymd(1970, 1, 2))); + /// assert_eq!(from_ndays_opt(0), Some(from_ymd(1970, 1, 1))); + /// assert_eq!(from_ndays_opt(-1), Some(from_ymd(1969, 12, 31))); + /// assert_eq!(from_ndays_opt(13036), Some(from_ymd(2005, 9, 10))); + /// assert_eq!(from_ndays_opt(100_000_000), None); + /// assert_eq!(from_ndays_opt(-100_000_000), None); + /// ``` + #[must_use] + pub const fn from_epoch_days(days: i32) -> Option { + let ce_days = try_opt!(days.checked_add(UNIX_EPOCH_DAY as i32)); + NaiveDate::from_num_days_from_ce_opt(ce_days) + } + /// Makes a new `NaiveDate` by counting the number of occurrences of a particular day-of-week /// since the beginning of the given month. For instance, if you want the 2nd Friday of March /// 2017, you would use `NaiveDate::from_weekday_of_month(2017, 3, Weekday::Fri, 2)`. @@ -1407,6 +1436,23 @@ impl NaiveDate { ndays + self.ordinal() as i32 } + /// Counts the days in the proleptic Gregorian calendar, with January 1, Year 1970 as day 0. + /// + /// # Example + /// + /// ``` + /// use chrono::NaiveDate; + /// + /// let from_ymd = |y, m, d| NaiveDate::from_ymd_opt(y, m, d).unwrap(); + /// + /// assert_eq!(from_ymd(1, 1, 1).to_epoch_days(), -719162); + /// assert_eq!(from_ymd(1970, 1, 1).to_epoch_days(), 0); + /// assert_eq!(from_ymd(2005, 9, 10).to_epoch_days(), 13036); + /// ``` + pub const fn to_epoch_days(&self) -> i32 { + self.num_days_from_ce() - UNIX_EPOCH_DAY as i32 + } + /// Create a new `NaiveDate` from a raw year-ordinal-flags `i32`. /// /// In a valid value an ordinal is never `0`, and neither are the year flags. 
This method diff --git a/patch/chrono-0.4.41/src/naive/date/tests.rs b/patch/chrono-0.4.42/src/naive/date/tests.rs similarity index 93% rename from patch/chrono-0.4.41/src/naive/date/tests.rs rename to patch/chrono-0.4.42/src/naive/date/tests.rs index 516e4cc..0b47ca4 100644 --- a/patch/chrono-0.4.41/src/naive/date/tests.rs +++ b/patch/chrono-0.4.42/src/naive/date/tests.rs @@ -301,6 +301,39 @@ fn test_date_from_num_days_from_ce() { assert_eq!(from_ndays_from_ce(i32::MAX), None); } +#[test] +fn test_date_from_epoch_days() { + let from_epoch_days = NaiveDate::from_epoch_days; + assert_eq!(from_epoch_days(-719_162), Some(NaiveDate::from_ymd_opt(1, 1, 1).unwrap())); + assert_eq!(from_epoch_days(0), Some(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap())); + assert_eq!(from_epoch_days(1), Some(NaiveDate::from_ymd_opt(1970, 1, 2).unwrap())); + assert_eq!(from_epoch_days(2), Some(NaiveDate::from_ymd_opt(1970, 1, 3).unwrap())); + assert_eq!(from_epoch_days(30), Some(NaiveDate::from_ymd_opt(1970, 1, 31).unwrap())); + assert_eq!(from_epoch_days(31), Some(NaiveDate::from_ymd_opt(1970, 2, 1).unwrap())); + assert_eq!(from_epoch_days(58), Some(NaiveDate::from_ymd_opt(1970, 2, 28).unwrap())); + assert_eq!(from_epoch_days(59), Some(NaiveDate::from_ymd_opt(1970, 3, 1).unwrap())); + assert_eq!(from_epoch_days(364), Some(NaiveDate::from_ymd_opt(1970, 12, 31).unwrap())); + assert_eq!(from_epoch_days(365), Some(NaiveDate::from_ymd_opt(1971, 1, 1).unwrap())); + assert_eq!(from_epoch_days(365 * 2), Some(NaiveDate::from_ymd_opt(1972, 1, 1).unwrap())); + assert_eq!(from_epoch_days(365 * 3 + 1), Some(NaiveDate::from_ymd_opt(1973, 1, 1).unwrap())); + assert_eq!(from_epoch_days(365 * 4 + 1), Some(NaiveDate::from_ymd_opt(1974, 1, 1).unwrap())); + assert_eq!(from_epoch_days(13036), Some(NaiveDate::from_ymd_opt(2005, 9, 10).unwrap())); + assert_eq!(from_epoch_days(-365), Some(NaiveDate::from_ymd_opt(1969, 1, 1).unwrap())); + assert_eq!(from_epoch_days(-366), Some(NaiveDate::from_ymd_opt(1968, 12, 
31).unwrap())); + + for days in (-9999..10001).map(|x| x * 100) { + assert_eq!(from_epoch_days(days).map(|d| d.to_epoch_days()), Some(days)); + } + + assert_eq!(from_epoch_days(NaiveDate::MIN.to_epoch_days()), Some(NaiveDate::MIN)); + assert_eq!(from_epoch_days(NaiveDate::MIN.to_epoch_days() - 1), None); + assert_eq!(from_epoch_days(NaiveDate::MAX.to_epoch_days()), Some(NaiveDate::MAX)); + assert_eq!(from_epoch_days(NaiveDate::MAX.to_epoch_days() + 1), None); + + assert_eq!(from_epoch_days(i32::MIN), None); + assert_eq!(from_epoch_days(i32::MAX), None); +} + #[test] fn test_date_from_weekday_of_month_opt() { let ymwd = NaiveDate::from_weekday_of_month_opt; @@ -423,6 +456,18 @@ fn test_date_num_days_from_ce() { } } +#[test] +fn test_date_to_epoch_days() { + assert_eq!(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().to_epoch_days(), 0); + + for year in -9999..10001 { + assert_eq!( + NaiveDate::from_ymd_opt(year, 1, 1).unwrap().to_epoch_days(), + NaiveDate::from_ymd_opt(year - 1, 12, 31).unwrap().to_epoch_days() + 1 + ); + } +} + #[test] fn test_date_succ() { let ymd = |y, m, d| NaiveDate::from_ymd_opt(y, m, d).unwrap(); diff --git a/patch/chrono-0.4.41/src/naive/datetime/mod.rs b/patch/chrono-0.4.42/src/naive/datetime/mod.rs similarity index 99% rename from patch/chrono-0.4.41/src/naive/datetime/mod.rs rename to patch/chrono-0.4.42/src/naive/datetime/mod.rs index 856e8eb..3441a91 100644 --- a/patch/chrono-0.4.41/src/naive/datetime/mod.rs +++ b/patch/chrono-0.4.42/src/naive/datetime/mod.rs @@ -10,7 +10,7 @@ use core::ops::{Add, AddAssign, Sub, SubAssign}; use core::time::Duration; use core::{fmt, str}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; #[cfg(feature = "alloc")] @@ -66,7 +66,7 @@ pub const MAX_DATETIME: NaiveDateTime = NaiveDateTime::MAX; /// ``` #[derive(PartialEq, Eq, Hash, PartialOrd, Ord, 
Copy, Clone)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) diff --git a/patch/chrono-0.4.41/src/naive/datetime/serde.rs b/patch/chrono-0.4.42/src/naive/datetime/serde.rs similarity index 99% rename from patch/chrono-0.4.41/src/naive/datetime/serde.rs rename to patch/chrono-0.4.42/src/naive/datetime/serde.rs index 85fcb94..6ebdefa 100644 --- a/patch/chrono-0.4.41/src/naive/datetime/serde.rs +++ b/patch/chrono-0.4.42/src/naive/datetime/serde.rs @@ -955,7 +955,7 @@ pub mod ts_seconds { /// } /// /// let my_s: S = serde_json::from_str(r#"{ "time": 1431684000 }"#)?; - /// let expected = DateTime::from_timestamp(1431684000, 0).unwrap().naive_utc(); + /// let expected = DateTime::from_timestamp_secs(1431684000).unwrap().naive_utc(); /// assert_eq!(my_s, S { time: expected }); /// # Ok::<(), serde_json::Error>(()) /// ``` @@ -979,7 +979,7 @@ pub mod ts_seconds { where E: de::Error, { - DateTime::from_timestamp(value, 0) + DateTime::from_timestamp_secs(value) .map(|dt| dt.naive_utc()) .ok_or_else(|| invalid_ts(value)) } @@ -991,7 +991,7 @@ pub mod ts_seconds { if value > i64::MAX as u64 { Err(invalid_ts(value)) } else { - DateTime::from_timestamp(value as i64, 0) + DateTime::from_timestamp_secs(value as i64) .map(|dt| dt.naive_utc()) .ok_or_else(|| invalid_ts(value)) } @@ -1080,7 +1080,7 @@ pub mod ts_seconds_option { /// } /// /// let my_s: S = serde_json::from_str(r#"{ "time": 1431684000 }"#)?; - /// let expected = DateTime::from_timestamp(1431684000, 0).unwrap().naive_utc(); + /// let expected = DateTime::from_timestamp_secs(1431684000).unwrap().naive_utc(); /// assert_eq!(my_s, S { time: Some(expected) }); /// # Ok::<(), serde_json::Error>(()) /// ``` diff --git 
a/patch/chrono-0.4.41/src/naive/datetime/tests.rs b/patch/chrono-0.4.42/src/naive/datetime/tests.rs similarity index 100% rename from patch/chrono-0.4.41/src/naive/datetime/tests.rs rename to patch/chrono-0.4.42/src/naive/datetime/tests.rs diff --git a/patch/chrono-0.4.41/src/naive/internals.rs b/patch/chrono-0.4.42/src/naive/internals.rs similarity index 100% rename from patch/chrono-0.4.41/src/naive/internals.rs rename to patch/chrono-0.4.42/src/naive/internals.rs diff --git a/patch/chrono-0.4.41/src/naive/isoweek.rs b/patch/chrono-0.4.42/src/naive/isoweek.rs similarity index 95% rename from patch/chrono-0.4.41/src/naive/isoweek.rs rename to patch/chrono-0.4.42/src/naive/isoweek.rs index 1788af7..efe0722 100644 --- a/patch/chrono-0.4.41/src/naive/isoweek.rs +++ b/patch/chrono-0.4.42/src/naive/isoweek.rs @@ -7,7 +7,7 @@ use core::fmt; use super::internals::YearFlags; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; /// ISO 8601 week. @@ -18,7 +18,7 @@ use rkyv::{Archive, Deserialize, Serialize}; /// via the [`Datelike::iso_week`](../trait.Datelike.html#tymethod.iso_week) method. 
#[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) @@ -176,13 +176,13 @@ mod tests { assert_eq!(minweek.week(), 1); assert_eq!(minweek.week0(), 0); #[cfg(feature = "alloc")] - assert_eq!(format!("{:?}", minweek), NaiveDate::MIN.format("%G-W%V").to_string()); + assert_eq!(format!("{minweek:?}"), NaiveDate::MIN.format("%G-W%V").to_string()); assert_eq!(maxweek.year(), date::MAX_YEAR + 1); assert_eq!(maxweek.week(), 1); assert_eq!(maxweek.week0(), 0); #[cfg(feature = "alloc")] - assert_eq!(format!("{:?}", maxweek), NaiveDate::MAX.format("%G-W%V").to_string()); + assert_eq!(format!("{maxweek:?}"), NaiveDate::MAX.format("%G-W%V").to_string()); } #[test] diff --git a/patch/chrono-0.4.41/src/naive/mod.rs b/patch/chrono-0.4.42/src/naive/mod.rs similarity index 100% rename from patch/chrono-0.4.41/src/naive/mod.rs rename to patch/chrono-0.4.42/src/naive/mod.rs diff --git a/patch/chrono-0.4.41/src/naive/time/mod.rs b/patch/chrono-0.4.42/src/naive/time/mod.rs similarity index 99% rename from patch/chrono-0.4.41/src/naive/time/mod.rs rename to patch/chrono-0.4.42/src/naive/time/mod.rs index 4a6f193..dbf1db9 100644 --- a/patch/chrono-0.4.41/src/naive/time/mod.rs +++ b/patch/chrono-0.4.42/src/naive/time/mod.rs @@ -9,7 +9,7 @@ use core::ops::{Add, AddAssign, Sub, SubAssign}; use core::time::Duration; use core::{fmt, str}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; #[cfg(feature = "alloc")] @@ -211,7 +211,7 @@ mod tests; /// **there is absolutely no guarantee that the leap second 
read has actually happened**. #[derive(PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) diff --git a/patch/chrono-0.4.41/src/naive/time/serde.rs b/patch/chrono-0.4.42/src/naive/time/serde.rs similarity index 100% rename from patch/chrono-0.4.41/src/naive/time/serde.rs rename to patch/chrono-0.4.42/src/naive/time/serde.rs diff --git a/patch/chrono-0.4.41/src/naive/time/tests.rs b/patch/chrono-0.4.42/src/naive/time/tests.rs similarity index 96% rename from patch/chrono-0.4.41/src/naive/time/tests.rs rename to patch/chrono-0.4.42/src/naive/time/tests.rs index a8754ae..d1df20c 100644 --- a/patch/chrono-0.4.41/src/naive/time/tests.rs +++ b/patch/chrono-0.4.42/src/naive/time/tests.rs @@ -283,26 +283,23 @@ fn test_time_from_str() { "23:59:60.373929310237", ]; for &s in &valid { - eprintln!("test_time_parse_from_str valid {:?}", s); + eprintln!("test_time_parse_from_str valid {s:?}"); let d = match s.parse::() { Ok(d) => d, - Err(e) => panic!("parsing `{}` has failed: {}", s, e), + Err(e) => panic!("parsing `{s}` has failed: {e}"), }; - let s_ = format!("{:?}", d); + let s_ = format!("{d:?}"); // `s` and `s_` may differ, but `s.parse()` and `s_.parse()` must be same let d_ = match s_.parse::() { Ok(d) => d, Err(e) => { - panic!("`{}` is parsed into `{:?}`, but reparsing that has failed: {}", s, d, e) + panic!("`{s}` is parsed into `{d:?}`, but reparsing that has failed: {e}") } }; assert!( d == d_, - "`{}` is parsed into `{:?}`, but reparsed result \ - `{:?}` does not match", - s, - d, - d_ + "`{s}` is parsed into `{d:?}`, but reparsed result \ + `{d_:?}` does not match" ); } @@ -329,7 +326,7 @@ fn test_time_from_str() { "09:08:00000000007", // 
invalid second / invalid fraction format ]; for &s in &invalid { - eprintln!("test_time_parse_from_str invalid {:?}", s); + eprintln!("test_time_parse_from_str invalid {s:?}"); assert!(s.parse::().is_err()); } } diff --git a/patch/chrono-0.4.41/src/offset/fixed.rs b/patch/chrono-0.4.42/src/offset/fixed.rs similarity index 97% rename from patch/chrono-0.4.41/src/offset/fixed.rs rename to patch/chrono-0.4.42/src/offset/fixed.rs index e1fbbf0..2c04537 100644 --- a/patch/chrono-0.4.41/src/offset/fixed.rs +++ b/patch/chrono-0.4.42/src/offset/fixed.rs @@ -6,7 +6,7 @@ use core::fmt; use core::str::FromStr; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; use super::{MappedLocalTime, Offset, TimeZone}; @@ -21,7 +21,7 @@ use crate::naive::{NaiveDate, NaiveDateTime}; /// [`west_opt`](#method.west_opt) methods for examples. #[derive(PartialEq, Eq, Hash, Copy, Clone)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, Hash, Debug))) diff --git a/patch/chrono-0.4.41/src/offset/local/mod.rs b/patch/chrono-0.4.42/src/offset/local/mod.rs similarity index 98% rename from patch/chrono-0.4.41/src/offset/local/mod.rs rename to patch/chrono-0.4.42/src/offset/local/mod.rs index cfe3faf..ee68571 100644 --- a/patch/chrono-0.4.41/src/offset/local/mod.rs +++ b/patch/chrono-0.4.42/src/offset/local/mod.rs @@ -6,7 +6,7 @@ #[cfg(windows)] use std::cmp::Ordering; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; use super::fixed::FixedOffset; 
@@ -115,10 +115,10 @@ mod tz_info; /// ``` #[derive(Copy, Clone, Debug)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), - rkyv(compare(PartialEq)), - rkyv(attr(derive(Clone, Copy, Debug))) + archive(compare(PartialEq)), + archive_attr(derive(Clone, Copy, Debug)) )] #[cfg_attr(feature = "rkyv-validation", archive(check_bytes))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -219,7 +219,7 @@ impl Transition { #[cfg(windows)] impl PartialOrd for Transition { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.transition_utc.cmp(&other.transition_utc)) + Some(self.cmp(other)) } } @@ -343,8 +343,7 @@ mod tests { // but there are only two sensible options. assert!( timestr == "15:02:60" || timestr == "15:03:00", - "unexpected timestr {:?}", - timestr + "unexpected timestr {timestr:?}" ); } @@ -352,8 +351,7 @@ mod tests { let timestr = dt.time().to_string(); assert!( timestr == "15:02:03.234" || timestr == "15:02:04.234", - "unexpected timestr {:?}", - timestr + "unexpected timestr {timestr:?}" ); } } diff --git a/patch/chrono-0.4.41/src/offset/local/tz_data.rs b/patch/chrono-0.4.42/src/offset/local/tz_data.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/tz_data.rs rename to patch/chrono-0.4.42/src/offset/local/tz_data.rs diff --git a/patch/chrono-0.4.41/src/offset/local/tz_info/mod.rs b/patch/chrono-0.4.42/src/offset/local/tz_info/mod.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/tz_info/mod.rs rename to patch/chrono-0.4.42/src/offset/local/tz_info/mod.rs diff --git a/patch/chrono-0.4.41/src/offset/local/tz_info/parser.rs b/patch/chrono-0.4.42/src/offset/local/tz_info/parser.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/tz_info/parser.rs rename to patch/chrono-0.4.42/src/offset/local/tz_info/parser.rs 
diff --git a/patch/chrono-0.4.41/src/offset/local/tz_info/rule.rs b/patch/chrono-0.4.42/src/offset/local/tz_info/rule.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/tz_info/rule.rs rename to patch/chrono-0.4.42/src/offset/local/tz_info/rule.rs diff --git a/patch/chrono-0.4.41/src/offset/local/tz_info/timezone.rs b/patch/chrono-0.4.42/src/offset/local/tz_info/timezone.rs similarity index 99% rename from patch/chrono-0.4.41/src/offset/local/tz_info/timezone.rs rename to patch/chrono-0.4.42/src/offset/local/tz_info/timezone.rs index a25be5c..7749799 100644 --- a/patch/chrono-0.4.41/src/offset/local/tz_info/timezone.rs +++ b/patch/chrono-0.4.42/src/offset/local/tz_info/timezone.rs @@ -134,7 +134,7 @@ impl TimeZone { } /// Returns a reference to the time zone - fn as_ref(&'_ self) -> TimeZoneRef<'_> { + fn as_ref(&self) -> TimeZoneRef<'_> { TimeZoneRef { transitions: &self.transitions, local_time_types: &self.local_time_types, diff --git a/patch/chrono-0.4.41/src/offset/local/unix.rs b/patch/chrono-0.4.42/src/offset/local/unix.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/unix.rs rename to patch/chrono-0.4.42/src/offset/local/unix.rs diff --git a/patch/chrono-0.4.41/src/offset/local/win_bindings.rs b/patch/chrono-0.4.42/src/offset/local/win_bindings.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/win_bindings.rs rename to patch/chrono-0.4.42/src/offset/local/win_bindings.rs diff --git a/patch/chrono-0.4.41/src/offset/local/win_bindings.txt b/patch/chrono-0.4.42/src/offset/local/win_bindings.txt similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/win_bindings.txt rename to patch/chrono-0.4.42/src/offset/local/win_bindings.txt diff --git a/patch/chrono-0.4.41/src/offset/local/windows.rs b/patch/chrono-0.4.42/src/offset/local/windows.rs similarity index 100% rename from patch/chrono-0.4.41/src/offset/local/windows.rs rename to patch/chrono-0.4.42/src/offset/local/windows.rs 
diff --git a/patch/chrono-0.4.41/src/offset/mod.rs b/patch/chrono-0.4.42/src/offset/mod.rs similarity index 99% rename from patch/chrono-0.4.41/src/offset/mod.rs rename to patch/chrono-0.4.42/src/offset/mod.rs index 3cd2fd6..38380c1 100644 --- a/patch/chrono-0.4.41/src/offset/mod.rs +++ b/patch/chrono-0.4.42/src/offset/mod.rs @@ -673,7 +673,7 @@ mod tests { MappedLocalTime::Single(dt) => { assert_eq!(dt.to_string(), *expected); } - e => panic!("Got {:?} instead of an okay answer", e), + e => panic!("Got {e:?} instead of an okay answer"), } } } diff --git a/patch/chrono-0.4.41/src/offset/utc.rs b/patch/chrono-0.4.42/src/offset/utc.rs similarity index 95% rename from patch/chrono-0.4.41/src/offset/utc.rs rename to patch/chrono-0.4.42/src/offset/utc.rs index d8f7b18..be2b520 100644 --- a/patch/chrono-0.4.41/src/offset/utc.rs +++ b/patch/chrono-0.4.42/src/offset/utc.rs @@ -14,7 +14,7 @@ use core::fmt; ))] use std::time::{SystemTime, UNIX_EPOCH}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; use super::{FixedOffset, MappedLocalTime, Offset, TimeZone}; @@ -42,7 +42,7 @@ use crate::{Date, DateTime}; /// ``` #[derive(Copy, Clone, PartialEq, Eq, Hash)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, Debug, Hash))) diff --git a/patch/chrono-0.4.41/src/round.rs b/patch/chrono-0.4.42/src/round.rs similarity index 98% rename from patch/chrono-0.4.41/src/round.rs rename to patch/chrono-0.4.42/src/round.rs index f5ca2d0..9f575a6 100644 --- a/patch/chrono-0.4.41/src/round.rs +++ b/patch/chrono-0.4.42/src/round.rs @@ -109,7 +109,11 @@ pub trait DurationRound: Sized { type Err: 
std::error::Error; /// Error that can occur in rounding or truncating - #[cfg(not(feature = "std"))] + #[cfg(all(not(feature = "std"), feature = "core-error"))] + type Err: core::error::Error; + + /// Error that can occur in rounding or truncating + #[cfg(all(not(feature = "std"), not(feature = "core-error")))] type Err: fmt::Debug + fmt::Display; /// Return a copy rounded by TimeDelta. @@ -362,6 +366,14 @@ impl std::error::Error for RoundingError { } } +#[cfg(all(not(feature = "std"), feature = "core-error"))] +impl core::error::Error for RoundingError { + #[allow(deprecated)] + fn description(&self) -> &str { + "error from rounding or truncating with DurationRound" + } +} + #[cfg(test)] mod tests { use super::{DurationRound, RoundingError, SubsecRound, TimeDelta}; diff --git a/patch/chrono-0.4.41/src/time_delta.rs b/patch/chrono-0.4.42/src/time_delta.rs similarity index 99% rename from patch/chrono-0.4.41/src/time_delta.rs rename to patch/chrono-0.4.42/src/time_delta.rs index 39950f0..0b467cc 100644 --- a/patch/chrono-0.4.41/src/time_delta.rs +++ b/patch/chrono-0.4.42/src/time_delta.rs @@ -10,6 +10,8 @@ //! Temporal quantification +#[cfg(all(not(feature = "std"), feature = "core-error"))] +use core::error::Error; use core::fmt; use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; use core::time::Duration; @@ -18,7 +20,7 @@ use std::error::Error; use crate::{expect, try_opt}; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; /// The number of nanoseconds in a microsecond. @@ -51,7 +53,7 @@ const SECS_PER_WEEK: i64 = 604_800; /// instance `abs()` can be called without any checks. 
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq, PartialOrd)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash))) @@ -630,7 +632,7 @@ impl fmt::Display for OutOfRangeError { } } -#[cfg(feature = "std")] +#[cfg(any(feature = "std", feature = "core-error"))] impl Error for OutOfRangeError { #[allow(deprecated)] fn description(&self) -> &str { diff --git a/patch/chrono-0.4.41/src/traits.rs b/patch/chrono-0.4.42/src/traits.rs similarity index 98% rename from patch/chrono-0.4.41/src/traits.rs rename to patch/chrono-0.4.42/src/traits.rs index ada73b2..450f0a2 100644 --- a/patch/chrono-0.4.41/src/traits.rs +++ b/patch/chrono-0.4.42/src/traits.rs @@ -366,7 +366,7 @@ mod tests { /// /// Panics if `div` is not positive. fn in_between(start: i32, end: i32, div: i32) -> i32 { - assert!(div > 0, "in_between: nonpositive div = {}", div); + assert!(div > 0, "in_between: nonpositive div = {div}"); let start = (start.div_euclid(div), start.rem_euclid(div)); let end = (end.div_euclid(div), end.rem_euclid(div)); // The lowest multiple of `div` greater than or equal to `start`, divided. 
@@ -390,16 +390,10 @@ mod tests { assert_eq!( jan1_year.num_days_from_ce(), num_days_from_ce(&jan1_year), - "on {:?}", - jan1_year + "on {jan1_year:?}" ); let mid_year = jan1_year + Days::new(133); - assert_eq!( - mid_year.num_days_from_ce(), - num_days_from_ce(&mid_year), - "on {:?}", - mid_year - ); + assert_eq!(mid_year.num_days_from_ce(), num_days_from_ce(&mid_year), "on {mid_year:?}"); } } diff --git a/patch/chrono-0.4.41/src/weekday.rs b/patch/chrono-0.4.42/src/weekday.rs similarity index 97% rename from patch/chrono-0.4.41/src/weekday.rs rename to patch/chrono-0.4.42/src/weekday.rs index 16f2bce..fcb4183 100644 --- a/patch/chrono-0.4.41/src/weekday.rs +++ b/patch/chrono-0.4.42/src/weekday.rs @@ -1,6 +1,6 @@ use core::fmt; -#[cfg(any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] +#[cfg(any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"))] use rkyv::{Archive, Deserialize, Serialize}; use crate::OutOfRange; @@ -31,7 +31,7 @@ use crate::OutOfRange; /// ``` #[derive(PartialEq, Eq, Copy, Clone, Debug, Hash)] #[cfg_attr( - any(feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), + any(feature = "rkyv", feature = "rkyv-16", feature = "rkyv-32", feature = "rkyv-64"), derive(Archive, Deserialize, Serialize), rkyv(compare(PartialEq)), rkyv(attr(derive(Clone, Copy, PartialEq, Eq, Debug, Hash))) @@ -238,6 +238,9 @@ pub struct ParseWeekdayError { pub(crate) _dummy: (), } +#[cfg(all(not(feature = "std"), feature = "core-error"))] +impl core::error::Error for ParseWeekdayError {} + #[cfg(feature = "std")] impl std::error::Error for ParseWeekdayError {} diff --git a/patch/chrono-0.4.41/src/weekday_set.rs b/patch/chrono-0.4.42/src/weekday_set.rs similarity index 100% rename from patch/chrono-0.4.41/src/weekday_set.rs rename to patch/chrono-0.4.42/src/weekday_set.rs diff --git a/patch/dotenvy-0.15.7/src/iter.rs b/patch/dotenvy-0.15.7/src/iter.rs index bfd8793..f7bafc8 100644 --- a/patch/dotenvy-0.15.7/src/iter.rs 
+++ b/patch/dotenvy-0.15.7/src/iter.rs @@ -11,10 +11,11 @@ use crate::parse; pub struct LoadResult { /// Number of successfully loaded variables pub loaded: usize, - /// Number of variables that were skipped (for `load` method only) - pub skipped: usize, - /// Number of variables that were overridden (for `load_override` method only) - pub overridden: usize, + // /// Number of variables that were skipped (for `load` method only) + // pub skipped: usize, + // /// Number of variables that were overridden (for `load_override` method only) + // pub overridden: usize, + pub skipped_or_overridden: usize, } pub struct Iter { @@ -57,8 +58,7 @@ impl Iter { Ok(LoadResult { loaded, - skipped, - overridden: 0, + skipped_or_overridden: skipped, }) } @@ -86,8 +86,7 @@ impl Iter { Ok(LoadResult { loaded, - skipped: 0, - overridden, + skipped_or_overridden: overridden, }) } diff --git a/patch/macros/Cargo.toml b/patch/macros/Cargo.toml new file mode 100644 index 0000000..c6c71cf --- /dev/null +++ b/patch/macros/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "macros" +version = "0.1.0" +edition = "2024" +authors = ["wisdgod "] +license = "MIT OR Apache-2.0" +description = "A Proto3 file dependency analyzer and optimizer" +repository = "https://github.com/wisdgod/ppp" + +[dependencies] diff --git a/patch/macros/src/lib.rs b/patch/macros/src/lib.rs new file mode 100644 index 0000000..10bab0b --- /dev/null +++ b/patch/macros/src/lib.rs @@ -0,0 +1,105 @@ +/// Batch define constants of the same type with shared attributes. +/// +/// # Examples +/// +/// ``` +/// define_typed_constants! { +/// pub u32 => { +/// MAX_CONNECTIONS = 1024, +/// DEFAULT_TIMEOUT = 30, +/// MIN_BUFFER_SIZE = 256, +/// } +/// +/// #[allow(dead_code)] +/// &'static str => { +/// APP_NAME = "server", +/// VERSION = "1.0.0", +/// } +/// } +/// ``` +#[macro_export] +macro_rules! 
define_typed_constants { + // Entry point: process type group with first constant + ( + $(#[$group_attr:meta])* + $vis:vis $ty:ty => { + $(#[$attr:meta])* + $name:ident = $value:expr, + $($inner_rest:tt)* + } + $($rest:tt)* + ) => { + $(#[$attr])* + $(#[$group_attr])* + $vis const $name: $ty = $value; + + $crate::define_typed_constants! { + @same_type + $(#[$group_attr])* + $vis $ty => { + $($inner_rest)* + } + } + + $crate::define_typed_constants! { + $($rest)* + } + }; + + // Process remaining constants of the same type + ( + @same_type + $(#[$group_attr:meta])* + $vis:vis $ty:ty => { + $(#[$attr:meta])* + $name:ident = $value:expr, + $($rest:tt)* + } + ) => { + $(#[$attr])* + $(#[$group_attr])* + $vis const $name: $ty = $value; + + $crate::define_typed_constants! { + @same_type + $(#[$group_attr])* + $vis $ty => { + $($rest)* + } + } + }; + + // Last constant in type group (no trailing comma) + ( + @same_type + $(#[$group_attr:meta])* + $vis:vis $ty:ty => { + $(#[$attr:meta])* + $name:ident = $value:expr + } + ) => { + $(#[$attr])* + $(#[$group_attr])* + $vis const $name: $ty = $value; + }; + + // Empty type group + (@same_type $(#[$group_attr:meta])* $vis:vis $ty:ty => {}) => {}; + + // Terminal case + () => {}; +} + +#[macro_export] +macro_rules! transmute_unchecked { + ($x:expr) => { + unsafe { ::core::intrinsics::transmute_unchecked($x) } + }; +} + +#[macro_export] +macro_rules! unwrap_unchecked { + ($x:expr) => { + unsafe { $x.unwrap_unchecked() } + }; +} diff --git a/patch/prost-0.14.1/Cargo.toml b/patch/prost-0.14.1/Cargo.toml new file mode 100644 index 0000000..8c0bceb --- /dev/null +++ b/patch/prost-0.14.1/Cargo.toml @@ -0,0 +1,87 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. 
+# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +rust-version = "1.71.1" +name = "prost" +version = "0.14.1" +authors = [ + "Dan Burkert ", + "Lucio Franco ", + "Casper Meijn ", + "Tokio Contributors ", +] +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "A Protocol Buffers implementation for the Rust Language." +readme = "README.md" +keywords = [ + "protobuf", + "serialization", +] +categories = ["encoding"] +license = "Apache-2.0" +repository = "https://github.com/tokio-rs/prost" + +[features] +default = [ + "derive", + "std", +] +derive = ["dep:prost-derive"] +no-recursion-limit = [] +std = [] +indexmap = ["dep:indexmap"] + +[lib] +name = "prost" +path = "src/lib.rs" +bench = false + +[dependencies.bytes] +version = "1" +default-features = false + +[dependencies.prost-derive] +path = "../prost-derive" +optional = true + +[dependencies.macros] +path = "../macros" + +[dependencies.indexmap] +version = "2" +optional = true + +[dependencies.cfg-if] +version = "1.0" + +[dependencies.any_all_workaround] +version = "0.1" + +[dependencies.serde] +version = "1" +default-features = false + +[dev-dependencies.criterion] +version = "0.7" +default-features = false + +[dev-dependencies.proptest] +version = "1" + +[dev-dependencies.rand] +version = "0.9" diff --git a/patch/prost-0.14.1/Cargo.toml.orig b/patch/prost-0.14.1/Cargo.toml.orig new file mode 100644 index 0000000..138475a --- /dev/null +++ b/patch/prost-0.14.1/Cargo.toml.orig @@ -0,0 +1,35 @@ +[package] +name = "prost" +readme = "README.md" +description = "A Protocol Buffers implementation for the Rust Language." 
+keywords = ["protobuf", "serialization"] +categories = ["encoding"] +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lib] +# https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options +bench = false + +[features] +default = ["derive", "std"] +derive = ["dep:prost-derive"] +no-recursion-limit = [] +std = [] + +[dependencies] +bytes = { version = "1", default-features = false } +prost-derive = { version = "0.14.1", path = "../prost-derive", optional = true } + +[dev-dependencies] +criterion = { version = "0.7", default-features = false } +proptest = "1" +rand = "0.9" + +[[bench]] +name = "varint" +harness = false diff --git a/patch/prost-0.14.1/LICENSE b/patch/prost-0.14.1/LICENSE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/patch/prost-0.14.1/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/patch/prost-0.14.1/README.md b/patch/prost-0.14.1/README.md new file mode 100644 index 0000000..36dd5f1 --- /dev/null +++ b/patch/prost-0.14.1/README.md @@ -0,0 +1,507 @@ +[![continuous integration](https://github.com/tokio-rs/prost/actions/workflows/ci.yml/badge.svg?branch=master)](https://github.com/tokio-rs/prost/actions/workflows/ci.yml?query=branch%3Amaster) +[![Documentation](https://docs.rs/prost/badge.svg)](https://docs.rs/prost/) +[![Crate](https://img.shields.io/crates/v/prost.svg)](https://crates.io/crates/prost) +[![Dependency Status](https://deps.rs/repo/github/tokio-rs/prost/status.svg)](https://deps.rs/repo/github/tokio-rs/prost) +[![Discord](https://img.shields.io/discord/500028886025895936)](https://discord.gg/tokio) + +# *PROST!* + +`prost` is a [Protocol Buffers](https://developers.google.com/protocol-buffers/) +implementation for the [Rust Language](https://www.rust-lang.org/). `prost` +generates simple, idiomatic Rust code from `proto2` and `proto3` files. + +Compared to other Protocol Buffers implementations, `prost` + +* Generates simple, idiomatic, and readable Rust types by taking advantage of + Rust `derive` attributes. +* Retains comments from `.proto` files in generated Rust code. +* Allows existing Rust types (not generated from a `.proto`) to be serialized + and deserialized by adding attributes. 
+* Uses the [`bytes::{Buf, BufMut}`](https://github.com/carllerche/bytes) + abstractions for serialization instead of `std::io::{Read, Write}`. +* Respects the Protobuf `package` specifier when organizing generated code + into Rust modules. +* Preserves unknown enum values during deserialization. +* Does not include support for runtime reflection or message descriptors. + +## Using `prost` in a Cargo Project + +First, add `prost` and its public dependencies to your `Cargo.toml`: + +```ignore +[dependencies] +prost = "0.14" +# Only necessary if using Protobuf well-known types: +prost-types = "0.14" +``` + +The recommended way to add `.proto` compilation to a Cargo project is to use the +`prost-build` library. See the [`prost-build` documentation][prost-build] for +more details and examples. + +See the [snazzy repository][snazzy] for a simple start-to-finish example. + +[prost-build]: https://docs.rs/prost-build/latest/prost_build/ +[snazzy]: https://github.com/danburkert/snazzy + +### MSRV + +`prost` follows the `tokio-rs` project's MSRV model and supports 1.70. For more +information on the tokio msrv policy you can check it out [here][tokio msrv] + +[tokio msrv]: https://github.com/tokio-rs/tokio/#supported-rust-versions + +## Generated Code + +`prost` generates Rust code from source `.proto` files using the `proto2` or +`proto3` syntax. `prost`'s goal is to make the generated code as simple as +possible. + +### `protoc` + +With `prost-build` v0.11 release, `protoc` will be required to invoke +`compile_protos` (unless `skip_protoc` is enabled). Prost will no longer provide +bundled `protoc` or attempt to compile `protoc` for users. For install +instructions for `protoc`, please check out the [protobuf install] instructions. + +[protobuf install]: https://github.com/protocolbuffers/protobuf#protobuf-compiler-installation + + +### Packages + +Prost can now generate code for `.proto` files that don't have a package spec. 
+`prost` will translate the Protobuf package into +a Rust module. For example, given the `package` specifier: + +[package]: https://developers.google.com/protocol-buffers/docs/proto#packages + +```protobuf,ignore +package foo.bar; +``` + +All Rust types generated from the file will be in the `foo::bar` module. + +### Messages + +Given a simple message declaration: + +```protobuf,ignore +// Sample message. +message Foo { +} +``` + +`prost` will generate the following Rust struct: + +```rust,ignore +/// Sample message. +#[derive(Clone, Debug, PartialEq, Message)] +pub struct Foo { +} +``` + +### Fields + +Fields in Protobuf messages are translated into Rust as public struct fields of the +corresponding type. + +#### Scalar Values + +Scalar value types are converted as follows: + +| Protobuf Type | Rust Type | +| --- | --- | +| `double` | `f64` | +| `float` | `f32` | +| `int32` | `i32` | +| `int64` | `i64` | +| `uint32` | `u32` | +| `uint64` | `u64` | +| `sint32` | `i32` | +| `sint64` | `i64` | +| `fixed32` | `u32` | +| `fixed64` | `u64` | +| `sfixed32` | `i32` | +| `sfixed64` | `i64` | +| `bool` | `bool` | +| `string` | `String` | +| `bytes` | `Vec` | + +#### Enumerations + +All `.proto` enumeration types convert to the Rust `i32` type. Additionally, +each enumeration type gets a corresponding Rust `enum` type. For example, this +`proto` enum: + +```protobuf,ignore +enum PhoneType { + MOBILE = 0; + HOME = 1; + WORK = 2; +} +``` + +gets this corresponding Rust enum [^1]: + +```rust,ignore +pub enum PhoneType { + Mobile = 0, + Home = 1, + Work = 2, +} +``` + +[^1]: Annotations have been elided for clarity. See below for a full example. + +You can convert a `PhoneType` value to an `i32` by doing: + +```rust,ignore +PhoneType::Mobile as i32 +``` + +The `#[derive(::prost::Enumeration)]` annotation added to the generated +`PhoneType` adds these associated functions to the type: + +```rust,ignore +impl PhoneType { + pub fn is_valid(value: i32) -> bool { ... 
}
+    #[deprecated]
+    pub fn from_i32(value: i32) -> Option<PhoneType> { ... }
+}
+```
+
+It also adds an `impl TryFrom<i32> for PhoneType`, so you can convert an `i32` to its corresponding `PhoneType` value by doing,
+for example:
+
+```rust,ignore
+let phone_type = 2i32;
+
+match PhoneType::try_from(phone_type) {
+    Ok(PhoneType::Mobile) => ...,
+    Ok(PhoneType::Home) => ...,
+    Ok(PhoneType::Work) => ...,
+    Err(_) => ...,
+}
+```
+
+Additionally, wherever a `proto` enum is used as a field in a `Message`, the
+message will have 'accessor' methods to get/set the value of the field as the
+Rust enum type. For instance, this proto `PhoneNumber` message that has a field
+named `type` of type `PhoneType`:
+
+```protobuf,ignore
+message PhoneNumber {
+    string number = 1;
+    PhoneType type = 2;
+}
+```
+
+will become the following Rust type [^2] with methods `type` and `set_type`:
+
+```rust,ignore
+pub struct PhoneNumber {
+    pub number: String,
+    pub r#type: i32, // the `r#` is needed because `type` is a Rust keyword
+}
+
+impl PhoneNumber {
+    pub fn r#type(&self) -> PhoneType { ... }
+    pub fn set_type(&mut self, value: PhoneType) { ... }
+}
+```
+
+Note that the getter methods will return the Rust enum's default value if the
+field has an invalid `i32` value.
+
+The `enum` type isn't used directly as a field, because the Protobuf spec
+mandates that enumeration values are 'open', and decoding unrecognized
+enumeration values must be possible.
+
+[^2]: Annotations have been elided for clarity. See below for a full example.
+
+#### Field Modifiers
+
+Protobuf scalar value and enumeration message fields can have a modifier
+depending on the Protobuf version. 
Modifiers change the corresponding type of
+the Rust field:
+
+| `.proto` Version | Modifier | Rust Type |
+| --- | --- | --- |
+| `proto2` | `optional` | `Option<T>` |
+| `proto2` | `required` | `T` |
+| `proto3` | default | `T` for scalar types, `Option<T>` otherwise |
+| `proto3` | `optional` | `Option<T>` |
+| `proto2`/`proto3` | `repeated` | `Vec<T>` |
+
+Note that in `proto3` the default representation for all user-defined message
+types is `Option<T>`, and for scalar types just `T` (during decoding, a missing
+value is populated by `T::default()`). If you need a witness of the presence of
+a scalar type `T`, use the `optional` modifier to enforce an `Option<T>`
+representation in the generated Rust struct.
+
+#### Map Fields
+
+Map fields are converted to a Rust `HashMap` with key and value type converted
+from the Protobuf key and value types.
+
+#### Message Fields
+
+Message fields are converted to the corresponding struct type. The table of
+field modifiers above applies to message fields, except that `proto3` message
+fields without a modifier (the default) will be wrapped in an `Option<T>`.
+Typically message fields are unboxed. `prost` will automatically box a message
+field if the field type and the parent type are recursively nested in order to
+avoid an infinite sized struct.
+
+#### Oneof Fields
+
+Oneof fields convert to a Rust enum. Protobuf `oneof`s types are not named, so
+`prost` uses the name of the `oneof` field for the resulting Rust enum, and
+defines the enum in a module under the struct. For example, a `proto3` message
+such as:
+
+```protobuf,ignore
+message Foo {
+    oneof widget {
+        int32 quux = 1;
+        string bar = 2;
+    }
+}
+```
+
+generates the following Rust[^3]:
+
+```rust,ignore
+pub struct Foo {
+    pub widget: Option<foo::Widget>,
+}
+pub mod foo {
+    pub enum Widget {
+        Quux(i32),
+        Bar(String),
+    }
+}
+```
+
+`oneof` fields are always wrapped in an `Option`.
+
+[^3]: Annotations have been elided for clarity. See below for a full example. 
+ +### Services + +`prost-build` allows a custom code-generator to be used for processing `service` +definitions. This can be used to output Rust traits according to an +application's specific needs. + +### Generated Code Example + +Example `.proto` file: + +```protobuf,ignore +syntax = "proto3"; +package tutorial; + +message Person { + string name = 1; + int32 id = 2; // Unique ID number for this person. + string email = 3; + + enum PhoneType { + MOBILE = 0; + HOME = 1; + WORK = 2; + } + + message PhoneNumber { + string number = 1; + PhoneType type = 2; + } + + repeated PhoneNumber phones = 4; +} + +// Our address book file is just one of these. +message AddressBook { + repeated Person people = 1; +} +``` + +and the generated Rust code (`tutorial.rs`): + +```rust,ignore +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Person { + #[prost(string, tag="1")] + pub name: ::prost::alloc::string::String, + /// Unique ID number for this person. + #[prost(int32, tag="2")] + pub id: i32, + #[prost(string, tag="3")] + pub email: ::prost::alloc::string::String, + #[prost(message, repeated, tag="4")] + pub phones: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `Person`. +pub mod person { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct PhoneNumber { + #[prost(string, tag="1")] + pub number: ::prost::alloc::string::String, + #[prost(enumeration="PhoneType", tag="2")] + pub r#type: i32, + } + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum PhoneType { + Mobile = 0, + Home = 1, + Work = 2, + } +} +/// Our address book file is just one of these. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AddressBook { + #[prost(message, repeated, tag="1")] + pub people: ::prost::alloc::vec::Vec, +} +``` + +## Accessing the `protoc` `FileDescriptorSet` + +The `prost_build::Config::file_descriptor_set_path` option can be used to emit a file descriptor set +during the build & code generation step. When used in conjunction with the `std::include_bytes` +macro and the `prost_types::FileDescriptorSet` type, applications and libraries using Prost can +implement introspection capabilities requiring details from the original `.proto` files. + +## Using `prost` in a `no_std` Crate + +`prost` is compatible with `no_std` crates. To enable `no_std` support, disable +the `std` features in `prost` and `prost-types`: + +```ignore +[dependencies] +prost = { version = "0.14.1", default-features = false, features = ["derive"] } +# Only necessary if using Protobuf well-known types: +prost-types = { version = "0.14.1", default-features = false } +``` + +Additionally, configure `prost-build` to output `BTreeMap`s instead of `HashMap`s +for all Protobuf `map` fields in your `build.rs`: + +```rust,ignore +let mut config = prost_build::Config::new(); +config.btree_map(&["."]); +``` + +When using edition 2015, it may be necessary to add an `extern crate core;` +directive to the crate which includes `prost`-generated code. + +## Serializing Existing Types + +`prost` uses a custom derive macro to handle encoding and decoding types, which +means that if your existing Rust type is compatible with Protobuf types, you can +serialize and deserialize it by adding the appropriate derive and field +annotations. + +Currently the best documentation on adding annotations is to look at the +generated code examples above. + +### Tag Inference for Existing Types + +Prost automatically infers tags for the struct. + +Fields are tagged sequentially in the order they +are specified, starting with `1`. 
+ +You may skip tags which have been reserved, or where there are gaps between +sequentially occurring tag values by specifying the tag number to skip to with +the `tag` attribute on the first field after the gap. The following fields will +be tagged sequentially starting from the next number. + +```rust,ignore +use prost; +use prost::{Enumeration, Message}; + +#[derive(Clone, PartialEq, Message)] +struct Person { + #[prost(string, tag = "1")] + pub id: String, // tag=1 + // NOTE: Old "name" field has been removed + // pub name: String, // tag=2 (Removed) + #[prost(string, tag = "6")] + pub given_name: String, // tag=6 + #[prost(string)] + pub family_name: String, // tag=7 + #[prost(string)] + pub formatted_name: String, // tag=8 + #[prost(uint32, tag = "3")] + pub age: u32, // tag=3 + #[prost(uint32)] + pub height: u32, // tag=4 + #[prost(enumeration = "Gender")] + pub gender: i32, // tag=5 + // NOTE: Skip to less commonly occurring fields + #[prost(string, tag = "16")] + pub name_prefix: String, // tag=16 (eg. mr/mrs/ms) + #[prost(string)] + pub name_suffix: String, // tag=17 (eg. jr/esq) + #[prost(string)] + pub maiden_name: String, // tag=18 +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Enumeration)] +pub enum Gender { + Unknown = 0, + Female = 1, + Male = 2, +} +``` + +## Nix + +The prost project maintains flakes support for local development. Once you have +nix and nix flakes setup you can just run `nix develop` to get a shell +configured with the required dependencies to compile the whole project. + +## Feature Flags +- `std`: Enable integration with standard library. Disable this feature for `no_std` support. This feature is enabled by default. +- `derive`: Enable integration with `prost-derive`. Disable this feature to reduce compile times. This feature is enabled by default. +- `prost-derive`: Deprecated. Alias for `derive` feature. +- `no-recursion-limit`: Disable the recursion limit. The recursion limit is 100 and cannot be customized. 
+
+## FAQ
+
+1. **Could `prost` be implemented as a serializer for [Serde](https://serde.rs/)?**
+
+   Probably not, however I would like to hear from a Serde expert on the matter.
+   There are two complications with trying to serialize Protobuf messages with
+   Serde:
+
+   - Protobuf fields require a numbered tag, and currently there appears to be no
+     mechanism suitable for this in `serde`.
+   - The mapping of Protobuf type to Rust type is not 1-to-1. As a result,
+     trait-based approaches to dispatching don't work very well. Example: six
+     different Protobuf field types correspond to a Rust `Vec<i32>`: `repeated
+     int32`, `repeated sint32`, `repeated sfixed32`, and their packed
+     counterparts.
+
+   But it is possible to place `serde` derive tags onto the generated types, so
+   the same structure can support both `prost` and `Serde`.
+
+2. **I get errors when trying to run `cargo test` on MacOS**
+
+   If the errors are about missing `autoreconf` or similar, you can probably fix
+   them by running
+
+   ```ignore
+   brew install automake
+   brew install libtool
+   ```
+
+## License
+
+`prost` is distributed under the terms of the Apache License (Version 2.0).
+
+See [LICENSE](https://github.com/tokio-rs/prost/blob/master/LICENSE) for details. 
+ +Copyright 2022 Dan Burkert & Tokio Contributors diff --git a/patch/prost-0.14.1/src/byte_str.rs b/patch/prost-0.14.1/src/byte_str.rs new file mode 100644 index 0000000..aa89bdf --- /dev/null +++ b/patch/prost-0.14.1/src/byte_str.rs @@ -0,0 +1,366 @@ +use core::borrow::Borrow; +use core::{fmt, ops, str}; +use core::str::pattern::{Pattern, ReverseSearcher, Searcher as _}; +#[cfg(not(feature = "std"))] +use alloc::string::String; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +use bytes::Bytes; + +#[allow(unused)] +struct BytesUnsafeView { + ptr: *const u8, + len: usize, + // inlined "trait object" + data: core::sync::atomic::AtomicPtr<()>, + vtable: &'static Vtable, +} + +#[allow(unused)] +struct Vtable { + /// fn(data, ptr, len) + clone: unsafe fn(&core::sync::atomic::AtomicPtr<()>, *const u8, usize) -> Bytes, + /// fn(data, ptr, len) + /// + /// takes `Bytes` to value + to_vec: unsafe fn(&core::sync::atomic::AtomicPtr<()>, *const u8, usize) -> Vec, + to_mut: unsafe fn(&core::sync::atomic::AtomicPtr<()>, *const u8, usize) -> bytes::BytesMut, + /// fn(data) + is_unique: unsafe fn(&core::sync::atomic::AtomicPtr<()>) -> bool, + /// fn(data, ptr, len) + drop: unsafe fn(&mut core::sync::atomic::AtomicPtr<()>, *const u8, usize), +} + +impl BytesUnsafeView { + #[inline] + const fn from(src: Bytes) -> Self { unsafe { ::core::intrinsics::transmute(src) } } + #[inline] + const fn to(self) -> Bytes { unsafe { ::core::intrinsics::transmute(self) } } +} + +#[repr(transparent)] +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub struct ByteStr { + // Invariant: bytes contains valid UTF-8 + bytes: Bytes, +} + +impl ByteStr { + #[inline] + pub fn new() -> ByteStr { + ByteStr { + // Invariant: the empty slice is trivially valid UTF-8. + bytes: Bytes::new(), + } + } + + #[inline] + pub const fn from_static(val: &'static str) -> ByteStr { + ByteStr { + // Invariant: val is a str so contains valid UTF-8. 
+ bytes: Bytes::from_static(val.as_bytes()), + } + } + + #[inline] + /// ## Panics + /// In a debug build this will panic if `bytes` is not valid UTF-8. + /// + /// ## Safety + /// `bytes` must contain valid UTF-8. In a release build it is undefined + /// behavior to call this with `bytes` that is not valid UTF-8. + pub unsafe fn from_utf8_unchecked(bytes: Bytes) -> ByteStr { + if cfg!(debug_assertions) { + match str::from_utf8(&bytes.as_ref()) { + Ok(_) => (), + Err(err) => panic!( + "ByteStr::from_utf8_unchecked() with invalid bytes; error = {err}, bytes = {bytes:?}", + ), + } + } + // Invariant: assumed by the safety requirements of this function. + ByteStr { bytes } + } + + #[inline(always)] + pub fn from_utf8(bytes: Bytes) -> Result { + str::from_utf8(&bytes)?; + // Invariant: just checked is utf8 + Ok(ByteStr { bytes }) + } + + #[inline] + pub const fn len(&self) -> usize { self.bytes.len() } + + #[must_use] + #[inline(always)] + pub const fn as_bytes(&self) -> &Bytes { &self.bytes } + + #[must_use] + #[inline] + pub unsafe fn slice_unchecked(&self, range: impl core::ops::RangeBounds) -> Self { + use core::ops::Bound; + + let len = self.len(); + + let begin = match range.start_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n + 1, + Bound::Unbounded => 0, + }; + + let end = match range.end_bound() { + Bound::Included(&n) => n + 1, + Bound::Excluded(&n) => n, + Bound::Unbounded => len, + }; + + if end == begin { + return ByteStr::new(); + } + + let mut ret = BytesUnsafeView::from(self.bytes.clone()); + + ret.len = end - begin; + ret.ptr = unsafe { ret.ptr.add(begin) }; + + Self { bytes: ret.to() } + } + + #[inline] + pub fn split_once(&self, delimiter: P) -> Option<(ByteStr, ByteStr)> { + let (start, end) = delimiter.into_searcher(self).next_match()?; + // SAFETY: `Searcher` is known to return valid indices. 
+ unsafe { Some((self.slice_unchecked(..start), self.slice_unchecked(end..))) } + } + + #[inline] + pub fn rsplit_once(&self, delimiter: P) -> Option<(ByteStr, ByteStr)> + where for<'a> P::Searcher<'a>: ReverseSearcher<'a> { + let (start, end) = delimiter.into_searcher(self).next_match_back()?; + // SAFETY: `Searcher` is known to return valid indices. + unsafe { Some((self.slice_unchecked(..start), self.slice_unchecked(end..))) } + } + + #[must_use] + #[inline(always)] + pub const unsafe fn as_bytes_mut(&mut self) -> &mut Bytes { &mut self.bytes } + + #[inline] + pub fn clear(&mut self) { self.bytes.clear() } +} + +unsafe impl Send for ByteStr {} +unsafe impl Sync for ByteStr {} + +impl Clone for ByteStr { + #[inline] + fn clone(&self) -> ByteStr { Self { bytes: self.bytes.clone() } } +} + +impl bytes::Buf for ByteStr { + #[inline] + fn remaining(&self) -> usize { self.bytes.remaining() } + + #[inline] + fn chunk(&self) -> &[u8] { self.bytes.chunk() } + + #[inline] + fn advance(&mut self, cnt: usize) { self.bytes.advance(cnt) } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { self.bytes.copy_to_bytes(len) } +} + +impl fmt::Debug for ByteStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&**self, f) } +} + +impl fmt::Display for ByteStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&**self, f) } +} + +impl ops::Deref for ByteStr { + type Target = str; + + #[inline] + fn deref(&self) -> &str { + let b: &[u8] = self.bytes.as_ref(); + // Safety: the invariant of `bytes` is that it contains valid UTF-8. 
+ unsafe { str::from_utf8_unchecked(b) } + } +} + +impl AsRef for ByteStr { + #[inline] + fn as_ref(&self) -> &str { self } +} + +impl AsRef<[u8]> for ByteStr { + #[inline] + fn as_ref(&self) -> &[u8] { self.bytes.as_ref() } +} + +impl core::hash::Hash for ByteStr { + #[inline] + fn hash(&self, state: &mut H) + where H: core::hash::Hasher { + self.bytes.hash(state) + } +} + +impl Borrow for ByteStr { + #[inline] + fn borrow(&self) -> &str { &**self } +} + +impl Borrow<[u8]> for ByteStr { + #[inline] + fn borrow(&self) -> &[u8] { self.bytes.borrow() } +} + +impl PartialEq for ByteStr { + #[inline] + fn eq(&self, other: &str) -> bool { &**self == other } +} + +impl PartialEq<&str> for ByteStr { + #[inline] + fn eq(&self, other: &&str) -> bool { &**self == *other } +} + +impl PartialEq for str { + #[inline] + fn eq(&self, other: &ByteStr) -> bool { self == &**other } +} + +impl PartialEq for &str { + #[inline] + fn eq(&self, other: &ByteStr) -> bool { *self == &**other } +} + +impl PartialEq for ByteStr { + #[inline] + fn eq(&self, other: &String) -> bool { &**self == other.as_str() } +} + +impl PartialEq<&String> for ByteStr { + #[inline] + fn eq(&self, other: &&String) -> bool { &**self == other.as_str() } +} + +impl PartialEq for String { + #[inline] + fn eq(&self, other: &ByteStr) -> bool { self.as_str() == &**other } +} + +impl PartialEq for &String { + #[inline] + fn eq(&self, other: &ByteStr) -> bool { self.as_str() == &**other } +} + +// impl From + +impl Default for ByteStr { + #[inline] + fn default() -> ByteStr { ByteStr::new() } +} + +impl From for ByteStr { + #[inline] + fn from(src: String) -> ByteStr { + ByteStr { + // Invariant: src is a String so contains valid UTF-8. + bytes: Bytes::from(src), + } + } +} + +impl<'a> From<&'a str> for ByteStr { + #[inline] + fn from(src: &'a str) -> ByteStr { + ByteStr { + // Invariant: src is a str so contains valid UTF-8. 
+ bytes: Bytes::copy_from_slice(src.as_bytes()), + } + } +} + +impl From for Bytes { + #[inline(always)] + fn from(src: ByteStr) -> Self { src.bytes } +} + +impl serde::Serialize for ByteStr { + #[inline] + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + serializer.serialize_str(&**self) + } +} + +struct ByteStrVisitor; + +impl<'de> serde::de::Visitor<'de> for ByteStrVisitor { + type Value = ByteStr; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a UTF-8 string") + } + + #[inline] + fn visit_str(self, v: &str) -> Result + where E: serde::de::Error { + Ok(ByteStr::from(v)) + } + + #[inline] + fn visit_string(self, v: String) -> Result + where E: serde::de::Error { + Ok(ByteStr::from(v)) + } + + #[inline] + fn visit_bytes(self, v: &[u8]) -> Result + where E: serde::de::Error { + match str::from_utf8(v) { + Ok(s) => Ok(ByteStr::from(s)), + Err(e) => Err(E::custom(format_args!("invalid UTF-8: {e}"))), + } + } + + #[inline] + fn visit_byte_buf(self, v: Vec) -> Result + where E: serde::de::Error { + match String::from_utf8(v) { + Ok(s) => Ok(ByteStr::from(s)), + Err(e) => Err(E::custom(format_args!("invalid UTF-8: {}", e.utf8_error()))), + } + } + + #[inline] + fn visit_seq(self, mut seq: V) -> Result + where V: serde::de::SeqAccess<'de> { + use serde::de::Error as _; + let len = core::cmp::min(seq.size_hint().unwrap_or(0), 4096); + let mut bytes: Vec = Vec::with_capacity(len); + + while let Some(value) = seq.next_element()? 
{ + bytes.push(value); + } + + match String::from_utf8(bytes) { + Ok(s) => Ok(ByteStr::from(s)), + Err(e) => Err(V::Error::custom(format_args!("invalid UTF-8: {}", e.utf8_error()))), + } + } +} + +impl<'de> serde::Deserialize<'de> for ByteStr { + #[inline] + fn deserialize(deserializer: D) -> Result + where D: serde::Deserializer<'de> { + deserializer.deserialize_string(ByteStrVisitor) + } +} diff --git a/patch/prost-0.14.1/src/encoding.rs b/patch/prost-0.14.1/src/encoding.rs new file mode 100644 index 0000000..0c196b0 --- /dev/null +++ b/patch/prost-0.14.1/src/encoding.rs @@ -0,0 +1,1471 @@ +//! Utility functions and types for encoding and decoding Protobuf types. +//! +//! This module contains the encoding and decoding primatives for Protobuf as described in +//! . +//! +//! This module is `pub`, but is only for prost internal use. The `prost-derive` crate needs access for its `Message` implementations. + +#![allow(clippy::implicit_hasher, clippy::ptr_arg)] + +use alloc::collections::BTreeMap; +#[cfg(not(feature = "std"))] +use alloc::{string::String, vec::Vec}; +use core::any::{Any, TypeId}; +use core::num::NonZeroU32; + +use ::bytes::{Buf, BufMut, Bytes}; + +use crate::{DecodeError, Message, ByteStr}; + +pub mod varint; +pub use varint::usize::{decode_varint, encode_varint, encoded_len_varint}; + +pub mod length_delimiter; +pub use length_delimiter::{ + decode_length_delimiter, encode_length_delimiter, length_delimiter_len, +}; + +pub mod wire_type; +pub use wire_type::{WireType, check_wire_type}; + +pub mod fixed_width; + +pub mod utf8; +pub use utf8::is_vaild_utf8; + +#[macro_export] +macro_rules! 
field_ { + (0) => {}; + ($n:expr) => { + unsafe { ::core::num::NonZeroU32::new_unchecked($n) } + }; +} + +define_typed_constants!( + #[allow(non_upper_case_globals)] + u32 => { + WireTypeBits = 3, + WireTypeMask = 7, + } + #[allow(non_upper_case_globals)] + pub NonZeroU32 => { + MaxFieldNumber = field_!((1 << 29) - 1), + FieldNumber1 = field_!(1), + FieldNumber2 = field_!(2), + } + #[allow(non_upper_case_globals)] + TypeId => { + __bytes__BytesMut = TypeId::of::<::bytes::BytesMut>(), + __alloc__vec__Vec_u8_ = TypeId::of::<::alloc::vec::Vec>(), + } +); + +/// Retrieves the `TypeId` of a potentially non-'static type `T`. +#[inline] +fn type_id_of() -> TypeId { + use ::core::marker::PhantomData; + + trait NonStaticAny { + fn get_type_id(&self) -> TypeId + where + Self: 'static; + } + + impl NonStaticAny for PhantomData { + fn get_type_id(&self) -> TypeId + where + Self: 'static, + { + TypeId::of::() + } + } + + let phantom_data = PhantomData::; + // Safety: `TypeId` is a function of the type structure, not its data or lifetime. + // Transmuting to satisfy the `'static` bound for this specific purpose is sound. + NonStaticAny::get_type_id(unsafe { + ::core::intrinsics::transmute_unchecked::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>( + &phantom_data, + ) + }) +} + +/// Performs a downcast from `&mut V` to `&mut T`, relying on a pre-computed type equality check. +/// +/// This is an optimized internal helper that avoids performing the type check itself. Its safety +/// depends on the caller upholding the `_eq` parameter contract. +#[inline(always)] +unsafe fn downcast_mut_prechecked(_val: &mut V, _eq: bool) -> Option<&mut T> { + if _eq { + // Safety: The caller guarantees via the `_eq` parameter that `V` is the same type as `T`. + // This makes the pointer type cast valid. + unsafe { Some(::core::mem::transmute(_val)) } + } else { + None + } +} + +/// Additional information passed to every decode/merge function. 
+/// +/// The context should be passed by value and can be freely cloned. When passing +/// to a function which is decoding a nested object, then use `enter_recursion`. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "no-recursion-limit", derive(Default))] +pub struct DecodeContext { + /// How many times we can recurse in the current decode stack before we hit + /// the recursion limit. + /// + /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be + /// customized. The recursion limit can be ignored by building the Prost + /// crate with the `no-recursion-limit` feature. + #[cfg(not(feature = "no-recursion-limit"))] + recurse_count: u32, +} + +#[cfg(not(feature = "no-recursion-limit"))] +impl Default for DecodeContext { + #[inline] + fn default() -> DecodeContext { + DecodeContext { + recurse_count: crate::RECURSION_LIMIT, + } + } +} + +impl DecodeContext { + /// Call this function before recursively decoding. + /// + /// There is no `exit` function since this function creates a new `DecodeContext` + /// to be used at the next level of recursion. Continue to use the old context + // at the previous level of recursion. + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { + DecodeContext { + recurse_count: self.recurse_count - 1, + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { DecodeContext {} } + + /// Checks whether the recursion limit has been reached in the stack of + /// decodes described by the `DecodeContext` at `self.ctx`. + /// + /// Returns `Ok<()>` if it is ok to continue recursing. + /// Returns `Err` if the recursion limit has been reached. 
+ #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { + if self.recurse_count == 0 { + Err(DecodeError::new("recursion limit reached")) + } else { + Ok(()) + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { Ok(()) } +} + +/// Encodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline] +pub fn encode_tag(number: NonZeroU32, wire_type: WireType, buf: &mut impl BufMut) { + debug_assert!(number <= MaxFieldNumber); + let tag = (number.get() << WireTypeBits) | wire_type as u32; + varint::encode_varint32(tag, buf); +} + +/// Decodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline(always)] +pub fn decode_tag(buf: &mut impl Buf) -> Result<(NonZeroU32, WireType), DecodeError> { + let tag = varint::decode_varint32(buf)?; + let (wire_type, number) = WireType::try_from_tag(tag)?; + if let Some(number) = NonZeroU32::new(number) { + Ok((number, wire_type)) + } else { + Err(DecodeError::new("invalid field number: 0")) + } +} + +/// Returns the width of an encoded Protobuf field tag with the given field number. +/// The returned width will be between 1 and 5 bytes (inclusive). +#[inline] +pub const fn tag_len(number: NonZeroU32) -> usize { varint::encoded_len_varint32(number.get() << WireTypeBits) } + +/// Helper function which abstracts reading a length delimiter prefix followed +/// by decoding values until the length of bytes is exhausted. 
+pub fn merge_loop( + value: &mut T, + buf: &mut B, + ctx: DecodeContext, + mut merge: M, +) -> Result<(), DecodeError> +where + M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>, + B: Buf, +{ + let len = decode_varint(buf)?; + let remaining = buf.remaining(); + if len > remaining { + return Err(DecodeError::new("buffer underflow")); + } + + let limit = remaining - len; + while buf.remaining() > limit { + merge(value, buf, ctx.clone())?; + } + + if buf.remaining() != limit { + return Err(DecodeError::new("delimited length exceeded")); + } + Ok(()) +} + +pub fn skip_field( + wire_type: WireType, + number: NonZeroU32, + buf: &mut impl Buf, + ctx: DecodeContext, +) -> Result<(), DecodeError> { + ctx.limit_reached()?; + let len = match wire_type { + WireType::Varint => decode_varint(buf).map(|_| 0)?, + WireType::ThirtyTwoBit => 4, + WireType::SixtyFourBit => 8, + WireType::LengthDelimited => decode_varint(buf)?, + WireType::StartGroup => loop { + let (inner_number, inner_wire_type) = decode_tag(buf)?; + match inner_wire_type { + WireType::EndGroup => { + if inner_number != number { + return Err(DecodeError::new("unexpected end group tag")); + } + break 0; + } + _ => skip_field(inner_wire_type, inner_number, buf, ctx.enter_recursion())?, + } + }, + WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")), + }; + + if len > buf.remaining() { + return Err(DecodeError::new("buffer underflow")); + } + + buf.advance(len); + Ok(()) +} + +/// Helper macro which emits an `encode_repeated` function for the type. +macro_rules! encode_repeated { + ($ty:ty) => { + pub fn encode_repeated(tag: NonZeroU32, values: &[$ty], buf: &mut impl BufMut) { + for value in values { + encode(tag, value, buf); + } + } + }; +} + +/// Helper macro which emits a `merge_repeated` function for the numeric type. +macro_rules! 
merge_repeated_numeric { + ($ty:ty, $wire_type:expr, $merge:ident) => { + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if wire_type == WireType::LengthDelimited { + // Packed. + merge_loop(values, buf, ctx, |values, buf, _ctx| { + let mut value = Default::default(); + $merge(&mut value, buf)?; + values.push(value); + Ok(()) + }) + } else { + // Unpacked. + check_wire_type($wire_type, wire_type)?; + let mut value = Default::default(); + $merge(&mut value, buf)?; + values.push(value); + Ok(()) + } + } + }; +} + +/// Macro which emits a module containing a set of encoding functions for a +/// variable width numeric type. +macro_rules! varint { + ($ty:ty, $proto_ty:ident) => { + pub mod $proto_ty { + use crate::encoding::varint::usize; + use crate::encoding::varint::$proto_ty::*; + use crate::encoding::wire_type::{WireType, check_wire_type}; + use crate::encoding::{ + __alloc__vec__Vec_u8_, __bytes__BytesMut, downcast_mut_prechecked, encode_tag, merge_loop, + tag_len, type_id_of, DecodeContext, + }; + use crate::error::DecodeError; + #[cfg(not(feature = "std"))] + use ::alloc::vec::Vec; + use ::bytes::{Buf, BufMut}; + use core::num::NonZeroU32; + + pub fn encode(number: NonZeroU32, value: &$ty, buf: &mut impl BufMut) { + encode_tag(number, WireType::Varint, buf); + encode_varint(*value, buf); + } + + pub fn merge( + wire_type: WireType, + value: &mut $ty, + buf: &mut impl Buf, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> { + check_wire_type(WireType::Varint, wire_type)?; + merge_unchecked(value, buf) + } + + #[inline(always)] + fn merge_unchecked(value: &mut $ty, buf: &mut impl Buf) -> Result<(), DecodeError> { + *value = decode_varint(buf)?; + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed(number: NonZeroU32, values: &[$ty], buf: &mut B) { + if values.is_empty() { + return; + } + + encode_tag(number, WireType::LengthDelimited, buf); + + let 
_id = type_id_of::(); + + if let Some(buf) = unsafe { downcast_mut_prechecked::<::bytes::BytesMut, B>(buf, _id == __bytes__BytesMut) } { + encode_packed_fast(values, buf); + } else if let Some(buf) = unsafe { downcast_mut_prechecked::, B>(buf, _id == __alloc__vec__Vec_u8_) } + { + encode_packed_fast(values, buf); + } else { + let len = values + .iter() + .map(|&value| encoded_len_varint(value)) + .sum::(); + usize::encode_varint(len, buf); + + for &value in values { + encode_varint(value, buf); + } + } + } + + merge_repeated_numeric!($ty, WireType::Varint, merge_unchecked); + + #[inline] + pub fn encoded_len(number: NonZeroU32, value: &$ty) -> usize { + tag_len(number) + encoded_len_varint(*value) + } + + #[inline] + pub fn encoded_len_repeated(number: NonZeroU32, values: &[$ty]) -> usize { + tag_len(number) * values.len() + + values + .iter() + .map(|&value| encoded_len_varint(value)) + .sum::() + } + + #[inline] + pub fn encoded_len_packed(number: NonZeroU32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = values + .iter() + .map(|&value| encoded_len_varint(value)) + .sum::(); + tag_len(number) + usize::encoded_len_varint(len) + len + } + } + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use crate::encoding::{ + test::{check_collection_type, check_type}, + $proto_ty::*, + }; + + proptest! 
{ + #[test] + fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::Varint, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(value, tag, WireType::Varint, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + #[test] + fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::LengthDelimited, + encode_packed, merge_repeated, + encoded_len_packed)?; + } + } + } + } + }; +} +varint!(bool, bool); +varint!(i32, int32); +varint!(i64, int64); +varint!(u32, uint32); +varint!(u64, uint64); +varint!(i32, sint32); +varint!(i64, sint64); + +/// Macro which emits a module containing a set of encoding functions for a +/// fixed width numeric type. +macro_rules! fixed_size { + ($ty:ty, $proto_ty:ident) => { + pub mod $proto_ty { + use crate::encoding::fixed_width::$proto_ty::*; + use crate::encoding::varint::usize; + use crate::encoding::wire_type::{WireType, check_wire_type}; + use crate::encoding::{encode_tag, merge_loop, tag_len, DecodeContext}; + use crate::error::DecodeError; + #[cfg(not(feature = "std"))] + use ::alloc::vec::Vec; + use ::bytes::{Buf, BufMut}; + use core::num::NonZeroU32; + + pub fn encode(number: NonZeroU32, value: &$ty, buf: &mut impl BufMut) { + encode_tag(number, WIRE_TYPE, buf); + encode_fixed(*value, buf); + } + + pub fn merge(wire_type: WireType, value: &mut $ty, buf: &mut impl Buf, _ctx: DecodeContext) -> Result<(), DecodeError> { + check_wire_type(WIRE_TYPE, wire_type)?; + merge_unchecked(value, buf) + } + + #[inline(always)] + fn merge_unchecked(value: &mut $ty, buf: &mut impl Buf) -> Result<(), DecodeError> { + *value = decode_fixed(buf)?; + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed(number: NonZeroU32, values: &[$ty], buf: &mut impl BufMut) { + if values.is_empty() { + return; + } + + encode_tag(number, WireType::LengthDelimited, buf); + 
usize::encode_varint(values.len() * SIZE, buf); + + for &value in values { + encode_fixed(value, buf); + } + } + + merge_repeated_numeric!($ty, WIRE_TYPE, merge_unchecked); + + #[inline] + pub fn encoded_len(number: NonZeroU32, _: &$ty) -> usize { tag_len(number) + SIZE } + + #[inline] + pub fn encoded_len_repeated(number: NonZeroU32, values: &[$ty]) -> usize { + (tag_len(number) + SIZE) * values.len() + } + + #[inline] + pub fn encoded_len_packed(number: NonZeroU32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = SIZE * values.len(); + tag_len(number) + usize::encoded_len_varint(len) + len + } + } + } + }; +} +fixed_size!(f32, float); +fixed_size!(f64, double); +fixed_size!(u32, fixed32); +fixed_size!(u64, fixed64); +fixed_size!(i32, sfixed32); +fixed_size!(i64, sfixed64); + +/// Macro which emits encoding functions for a length-delimited type. +macro_rules! length_delimited { + ($ty:ty) => { + encode_repeated!($ty); + + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut value = Default::default(); + merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + + #[inline] + #[allow(clippy::ptr_arg)] + pub fn encoded_len(number: NonZeroU32, value: &$ty) -> usize { + tag_len(number) + encoded_len_varint(value.len()) + value.len() + } + + #[inline] + pub fn encoded_len_repeated(number: NonZeroU32, values: &[$ty]) -> usize { + tag_len(number) * values.len() + + values + .iter() + .map(|value| encoded_len_varint(value.len()) + value.len()) + .sum::() + } + }; +} + +mod sealed { + use super::{Buf, BufMut}; + + pub trait BytesAdapter: Default + Sized + 'static { + fn len(&self) -> usize; + + /// Replace contents of this buffer with the contents of another buffer. 
+ fn replace_with(&mut self, buf: impl Buf); + + /// Appends this buffer to the (contents of) other buffer. + fn append_to(&self, buf: &mut impl BufMut); + + /// Merges a specified number of bytes from a buffer into `self`. + /// + /// This method encapsulates the type-specific optimal merge strategy. + fn merge_from_buf(&mut self, buf: &mut impl Buf, len: usize); + + fn clear(&mut self); + } + + pub trait StringAdapter: Default + Sized + 'static { + type Inner: super::BytesAdapter + AsRef<[u8]>; + + fn len(&self) -> usize; + + fn as_bytes(&self) -> &[u8]; + + unsafe fn as_mut(&mut self) -> &mut Self::Inner; + } +} + +pub trait StringAdapter: sealed::StringAdapter {} + +impl StringAdapter for ByteStr {} + +impl sealed::StringAdapter for ByteStr { + type Inner = Bytes; + + #[inline] + fn len(&self) -> usize { self.len() } + + #[inline] + fn as_bytes(&self) -> &[u8] { &self.as_bytes() } + + #[inline] + unsafe fn as_mut(&mut self) -> &mut Self::Inner { self.as_bytes_mut() } +} + +impl StringAdapter for String {} + +impl sealed::StringAdapter for String { + type Inner = Vec; + + #[inline] + fn len(&self) -> usize { self.len() } + + #[inline] + fn as_bytes(&self) -> &[u8] { self.as_bytes() } + + #[inline] + unsafe fn as_mut(&mut self) -> &mut Self::Inner { self.as_mut_vec() } +} + +pub mod string { + use super::*; + + pub fn encode(number: NonZeroU32, value: &impl StringAdapter, buf: &mut impl BufMut) { + encode_tag(number, WireType::LengthDelimited, buf); + encode_varint(value.len(), buf); + buf.put_slice(value.as_bytes()); + } + + pub fn merge( + wire_type: WireType, + value: &mut S, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + // ## Unsafety + // + // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 + // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the + // string is cleared, so as to avoid leaking a string field with invalid data. 
+ // + // This implementation uses the `StringAdapter` trait which provides access to the underlying + // byte storage through `as_mut()`. This allows for efficient in-place modification while + // maintaining the invariant that the string must contain valid UTF-8. + // + // To ensure that invalid UTF-8 data is never exposed through the StringAdapter, even in the + // event of a panic in `bytes::merge` or in the buf implementation, a drop guard is used + // that will clear the underlying storage if the function exits abnormally. + + struct DropGuard<'a, S: StringAdapter>(&'a mut ::Inner); + impl Drop for DropGuard<'_, S> { + #[inline] + fn drop(&mut self) { super::sealed::BytesAdapter::clear(self.0) } + } + let drop_guard = unsafe { DropGuard::(value.as_mut()) }; + super::bytes::merge(wire_type, drop_guard.0, buf, ctx)?; + let s = drop_guard.0.as_ref(); + if super::utf8::utf8_valid_up_to(s) == s.len() { + // Success; do not clear the bytes. + ::core::mem::forget(drop_guard); + Ok(()) + } else { + Err(DecodeError::new( + "invalid string value: data is not UTF-8 encoded", + )) + } + } + + length_delimited!(impl StringAdapter); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::{ + super::test::{check_collection_type, check_type}, + *, + }; + + proptest! 
{ + #[test] + fn check(value: String, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub trait BytesAdapter: sealed::BytesAdapter {} + +impl BytesAdapter for Bytes {} + +impl sealed::BytesAdapter for Bytes { + #[inline] + fn len(&self) -> usize { ::bytes::Bytes::len(self) } + + #[inline] + fn replace_with(&mut self, mut buf: impl Buf) { *self = buf.copy_to_bytes(buf.remaining()); } + + #[inline] + fn append_to(&self, buf: &mut impl BufMut) { buf.put(self.clone()) } + + #[inline] + fn merge_from_buf(&mut self, buf: &mut impl Buf, len: usize) { + // Strategy for Bytes: use `copy_to_bytes` for potential zero-copy. + *self = buf.copy_to_bytes(len); + } + + #[inline] + fn clear(&mut self) { self.clear() } +} + +impl BytesAdapter for Vec {} + +impl sealed::BytesAdapter for Vec { + #[inline] + fn len(&self) -> usize { ::alloc::vec::Vec::len(self) } + + #[inline] + fn replace_with(&mut self, buf: impl Buf) { + self.clear(); + self.put(buf); + } + + #[inline] + fn append_to(&self, buf: &mut impl BufMut) { buf.put(self.as_slice()) } + + #[inline] + fn merge_from_buf(&mut self, buf: &mut impl Buf, len: usize) { + // Strategy for Vec: use `take` to ensure single-copy. 
+ self.clear(); + self.put(buf.take(len)); + } + + #[inline] + fn clear(&mut self) { self.clear(); } +} + +pub mod bytes { + use super::*; + + pub fn encode(number: NonZeroU32, value: &impl BytesAdapter, buf: &mut impl BufMut) { + encode_tag(number, WireType::LengthDelimited, buf); + encode_varint(value.len(), buf); + value.append_to(buf); + } + + pub fn merge( + wire_type: WireType, + value: &mut impl BytesAdapter, + buf: &mut impl Buf, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let len = decode_varint(buf)?; + if len > buf.remaining() { + return Err(DecodeError::new( + "insufficient bytes for length-delimited field", + )); + } + + // Clear the existing value. This follows from the following rule in the encoding guide[1]: + // + // > Normally, an encoded message would never have more than one instance of a non-repeated + // > field. However, parsers are expected to handle the case in which they do. For numeric + // > types and strings, if the same field appears multiple times, the parser accepts the + // > last value it sees. + // + // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional + // + // This is intended for A and B both being Bytes so it is zero-copy. + // Some combinations of A and B types may cause a double-copy, + // in which case merge_one_copy() should be used instead. + value.merge_from_buf(buf, len); + Ok(()) + } + + length_delimited!(impl BytesAdapter); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::{ + super::test::{check_collection_type, check_type}, + *, + }; + + proptest! 
{ + #[test] + fn check_vec(value: Vec, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type::, Vec>(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_bytes(value: Vec, tag in MIN_TAG..=MAX_TAG) { + let value = Bytes::from(value); + super::test::check_type::(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_repeated_vec(value: Vec>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + + #[test] + fn check_repeated_bytes(value: Vec>, tag in MIN_TAG..=MAX_TAG) { + let value = value.into_iter().map(Bytes::from).collect(); + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub mod message { + use super::*; + + pub fn encode(number: NonZeroU32, msg: &M, buf: &mut impl BufMut) + where + M: Message, + { + encode_tag(number, WireType::LengthDelimited, buf); + encode_varint(msg.encoded_len(), buf); + msg.encode_raw(buf); + } + + pub fn merge( + wire_type: WireType, + msg: &mut M, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + ctx.limit_reached()?; + merge_loop( + msg, + buf, + ctx.enter_recursion(), + |msg: &mut M, buf: &mut B, ctx| { + let (number, wire_type) = decode_tag(buf)?; + msg.merge_field(number, wire_type, buf, ctx) + }, + ) + } + + pub fn encode_repeated(number: NonZeroU32, messages: &[M], buf: &mut impl BufMut) + where + M: Message, + { + for msg in messages { + encode(number, msg, buf); + } + } + + pub fn merge_repeated( + wire_type: WireType, + messages: &mut Vec, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + 
let mut msg = M::default(); + merge(WireType::LengthDelimited, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len(number: NonZeroU32, msg: &M) -> usize + where + M: Message, + { + let len = msg.encoded_len(); + tag_len(number) + encoded_len_varint(len) + len + } + + #[inline] + pub fn encoded_len_repeated(number: NonZeroU32, messages: &[M]) -> usize + where + M: Message, + { + tag_len(number) * messages.len() + + messages + .iter() + .map(Message::encoded_len) + .map(|len| len + encoded_len_varint(len)) + .sum::() + } +} + +pub mod group { + use super::*; + + pub fn encode(number: NonZeroU32, msg: &M, buf: &mut impl BufMut) + where + M: Message, + { + encode_tag(number, WireType::StartGroup, buf); + msg.encode_raw(buf); + encode_tag(number, WireType::EndGroup, buf); + } + + pub fn merge( + number: NonZeroU32, + wire_type: WireType, + msg: &mut M, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + { + check_wire_type(WireType::StartGroup, wire_type)?; + + ctx.limit_reached()?; + loop { + let (field_number, field_wire_type) = decode_tag(buf)?; + if field_wire_type == WireType::EndGroup { + if field_number != number { + return Err(DecodeError::new("unexpected end group tag")); + } + return Ok(()); + } + + M::merge_field(msg, field_number, field_wire_type, buf, ctx.enter_recursion())?; + } + } + + pub fn encode_repeated(number: NonZeroU32, messages: &[M], buf: &mut impl BufMut) + where + M: Message, + { + for msg in messages { + encode(number, msg, buf); + } + } + + pub fn merge_repeated( + number: NonZeroU32, + wire_type: WireType, + messages: &mut Vec, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + { + check_wire_type(WireType::StartGroup, wire_type)?; + let mut msg = M::default(); + merge(number, WireType::StartGroup, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len(number: 
NonZeroU32, msg: &M) -> usize + where + M: Message, + { + 2 * tag_len(number) + msg.encoded_len() + } + + #[inline] + pub fn encoded_len_repeated(number: NonZeroU32, messages: &[M]) -> usize + where + M: Message, + { + 2 * tag_len(number) * messages.len() + messages.iter().map(Message::encoded_len).sum::() + } +} + +/// Rust doesn't have a `Map` trait, so macros are currently the best way to be +/// generic over `HashMap` and `BTreeMap`. +macro_rules! map { + ($map_ty:ident) => { + use crate::encoding::*; + use core::hash::Hash; + + /// Generic protobuf map encode function. + pub fn encode( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + number: NonZeroU32, + values: &$map_ty, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + B: BufMut, + KE: Fn(NonZeroU32, &K, &mut B), + KL: Fn(NonZeroU32, &K) -> usize, + VE: Fn(NonZeroU32, &V, &mut B), + VL: Fn(NonZeroU32, &V) -> usize, + { + encode_with_default( + key_encode, + key_encoded_len, + val_encode, + val_encoded_len, + &V::default(), + number, + values, + buf, + ) + } + + /// Generic protobuf map merge function. + pub fn merge( + key_merge: KM, + val_merge: VM, + values: &mut $map_ty, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + V: Default, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) + } + + /// Generic protobuf map encode function. 
+ pub fn encoded_len( + key_encoded_len: KL, + val_encoded_len: VL, + number: NonZeroU32, + values: &$map_ty, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + KL: Fn(NonZeroU32, &K) -> usize, + VL: Fn(NonZeroU32, &V) -> usize, + { + encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), number, values) + } + + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encode_with_default( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + val_default: &V, + number: NonZeroU32, + values: &$map_ty, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: PartialEq, + B: BufMut, + KE: Fn(NonZeroU32, &K, &mut B), + KL: Fn(NonZeroU32, &K) -> usize, + VE: Fn(NonZeroU32, &V, &mut B), + VL: Fn(NonZeroU32, &V) -> usize, + { + for (key, val) in values.iter() { + let skip_key = key == &K::default(); + let skip_val = val == val_default; + + let len = (if skip_key { 0 } else { key_encoded_len(FieldNumber1, key) }) + + (if skip_val { 0 } else { val_encoded_len(FieldNumber2, val) }); + + encode_tag(number, WireType::LengthDelimited, buf); + encode_varint(len, buf); + if !skip_key { + key_encode(FieldNumber1, key, buf); + } + if !skip_val { + val_encode(FieldNumber2, val, buf); + } + } + } + + /// Generic protobuf map merge function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. 
+ pub fn merge_with_default( + key_merge: KM, + val_merge: VM, + val_default: V, + values: &mut $map_ty, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + let mut key = Default::default(); + let mut val = val_default; + ctx.limit_reached()?; + merge_loop( + &mut (&mut key, &mut val), + buf, + ctx.enter_recursion(), + |&mut (ref mut key, ref mut val), buf, ctx| { + let (number, wire_type) = decode_tag(buf)?; + #[allow(non_upper_case_globals)] + match number { + FieldNumber1 => key_merge(wire_type, key, buf, ctx), + FieldNumber2 => val_merge(wire_type, val, buf, ctx), + _ => skip_field(wire_type, number, buf, ctx), + } + }, + )?; + values.insert(key, val); + + Ok(()) + } + + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. 
+ pub fn encoded_len_with_default( + key_encoded_len: KL, + val_encoded_len: VL, + val_default: &V, + number: NonZeroU32, + values: &$map_ty, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: PartialEq, + KL: Fn(NonZeroU32, &K) -> usize, + VL: Fn(NonZeroU32, &V) -> usize, + { + tag_len(number) * values.len() + + values + .iter() + .map(|(key, val)| { + let len = (if key == &K::default() { + 0 + } else { + key_encoded_len(FieldNumber1, key) + }) + (if val == val_default { + 0 + } else { + val_encoded_len(FieldNumber2, val) + }); + encoded_len_varint(len) + len + }) + .sum::() + } + }; +} + +#[cfg(feature = "std")] +pub mod hash_map { + use std::collections::HashMap; + map!(HashMap); +} + +pub mod btree_map { + map!(BTreeMap); +} + +#[cfg(feature = "indexmap")] +pub mod index_map { + use indexmap::IndexMap; + map!(IndexMap); +} + +#[cfg(test)] +mod test { + #[cfg(not(feature = "std"))] + use alloc::string::ToString; + use core::{borrow::Borrow, fmt::Debug}; + + use ::bytes::BytesMut; + use proptest::{prelude::*, test_runner::TestCaseResult}; + + use super::*; + + pub fn check_type( + value: T, + number: NonZeroU32, + wire_type: WireType, + encode: fn(u32, &B, &mut BytesMut), + merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + encoded_len: fn(u32, &B) -> usize, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow, + B: ?Sized, + { + prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag)); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + if !buf.has_remaining() { + // Short circuit for empty packed values. 
+ return Ok(()); + } + + let (decoded_number, decoded_wire_type) = + decode_tag(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + prop_assert_eq!( + tag, + decoded_number, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_number + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type, + ); + + match wire_type { + WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!( + "64bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!( + "32bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + _ => Ok(()), + }?; + + let mut roundtrip_value = T::default(); + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert!( + !buf.has_remaining(), + "expected buffer to be empty, remaining: {}", + buf.remaining() + ); + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + pub fn check_collection_type( + value: T, + number: NonZeroU32, + wire_type: WireType, + encode: E, + mut merge: M, + encoded_len: L, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow, + B: ?Sized, + E: FnOnce(u32, &B, &mut BytesMut), + M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + L: FnOnce(u32, &B) -> usize, + { + prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag)); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + let mut roundtrip_value = 
Default::default(); + while buf.has_remaining() { + let (decoded_number, decoded_wire_type) = + decode_tag(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert_eq!( + tag, + decoded_number, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_number + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type + ); + + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + } + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + #[test] + fn string_merge_invalid_utf8() { + let mut s = String::new(); + let buf = b"\x02\x80\x80"; + + let r = string::merge( + WireType::LengthDelimited, + &mut s, + &mut &buf[..], + DecodeContext::default(), + ); + r.expect_err("must be an error"); + assert!(s.is_empty()); + } + + /// This big bowl o' macro soup generates an encoding property test for each combination of map + /// type, scalar map key, and value type. + /// TODO: these tests take a long time to compile, can this be improved? + #[cfg(feature = "std")] + macro_rules! map_tests { + (keys: $keys:tt, + vals: $vals:tt) => { + mod hash_map { + map_tests!(@private HashMap, hash_map, $keys, $vals); + } + mod btree_map { + map_tests!(@private BTreeMap, btree_map, $keys, $vals); + } + }; + + (@private $map_type:ident, + $mod_name:ident, + [$(($key_ty:ty, $key_proto:ident)),*], + $vals:tt) => { + $( + mod $key_proto { + use std::collections::$map_type; + + use proptest::prelude::*; + + use crate::encoding::*; + use crate::encoding::test::check_collection_type; + + map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals); + } + )* + }; + + (@private $map_type:ident, + $mod_name:ident, + ($key_ty:ty, $key_proto:ident), + [$(($val_ty:ty, $val_proto:ident)),*]) => { + $( + proptest! 
{ + #[test] + fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(values, tag, WireType::LengthDelimited, + |tag, values, buf| { + $mod_name::encode($key_proto::encode, + $key_proto::encoded_len, + $val_proto::encode, + $val_proto::encoded_len, + tag, + values, + buf) + }, + |wire_type, values, buf, ctx| { + check_wire_type(WireType::LengthDelimited, wire_type)?; + $mod_name::merge($key_proto::merge, + $val_proto::merge, + values, + buf, + ctx) + }, + |tag, values| { + $mod_name::encoded_len($key_proto::encoded_len, + $val_proto::encoded_len, + tag, + values) + })?; + } + } + )* + }; + } + + #[cfg(feature = "std")] + map_tests!(keys: [ + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string) + ], + vals: [ + (f32, float), + (f64, double), + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string), + (Vec, bytes) + ]); +} diff --git a/patch/prost-0.14.1/src/encoding/fixed_width.rs b/patch/prost-0.14.1/src/encoding/fixed_width.rs new file mode 100644 index 0000000..08a0d1d --- /dev/null +++ b/patch/prost-0.14.1/src/encoding/fixed_width.rs @@ -0,0 +1,31 @@ +use ::bytes::{Buf, BufMut}; + +use super::wire_type::WireType; +use crate::error::DecodeError; +use alloc::string::ToString as _; + +macro_rules! 
fixed { + ($ty:ty, $proto_ty:ident, $wire_type:ident, $put:ident, $try_get:ident) => { + pub mod $proto_ty { + use super::*; + + pub const WIRE_TYPE: WireType = WireType::$wire_type; + pub const SIZE: usize = core::mem::size_of::<$ty>(); + + #[inline(always)] + pub fn encode_fixed(value: $ty, buf: &mut impl BufMut) { buf.$put(value); } + + #[inline(always)] + pub fn decode_fixed(buf: &mut impl Buf) -> Result<$ty, DecodeError> { + buf.$try_get().map_err(|e| DecodeError::new(e.to_string())) + } + } + }; +} + +fixed!(f32, float, ThirtyTwoBit, put_f32_le, try_get_f32_le); +fixed!(f64, double, SixtyFourBit, put_f64_le, try_get_f64_le); +fixed!(u32, fixed32, ThirtyTwoBit, put_u32_le, try_get_u32_le); +fixed!(u64, fixed64, SixtyFourBit, put_u64_le, try_get_u64_le); +fixed!(i32, sfixed32, ThirtyTwoBit, put_i32_le, try_get_i32_le); +fixed!(i64, sfixed64, SixtyFourBit, put_i64_le, try_get_i64_le); diff --git a/patch/prost-0.14.1/src/encoding/length_delimiter.rs b/patch/prost-0.14.1/src/encoding/length_delimiter.rs new file mode 100644 index 0000000..34e70b6 --- /dev/null +++ b/patch/prost-0.14.1/src/encoding/length_delimiter.rs @@ -0,0 +1,46 @@ +pub use crate::{ + error::{DecodeError, EncodeError, UnknownEnumValue}, + message::Message, + // name::Name, +}; + +use ::bytes::{Buf, BufMut}; + +use crate::encoding::varint::usize::{decode_varint, encode_varint, encoded_len_varint}; + +/// Encodes a length delimiter to the buffer. +/// +/// See [Message.encode_length_delimited] for more info. +/// +/// An error will be returned if the buffer does not have sufficient capacity to encode the +/// delimiter. +pub fn encode_length_delimiter(length: usize, buf: &mut impl BufMut) -> Result<(), EncodeError> { + let required = encoded_len_varint(length); + let remaining = buf.remaining_mut(); + if required > remaining { + return Err(EncodeError::new(required, remaining)); + } + encode_varint(length, buf); + Ok(()) +} + +/// Returns the encoded length of a length delimiter. 
+/// +/// Applications may use this method to ensure sufficient buffer capacity before calling +/// `encode_length_delimiter`. The returned size will be between 1 and 10, inclusive. +#[inline] +pub fn length_delimiter_len(length: usize) -> usize { encoded_len_varint(length) } + +/// Decodes a length delimiter from the buffer. +/// +/// This method allows the length delimiter to be decoded independently of the message, when the +/// message is encoded with [Message.encode_length_delimited]. +/// +/// An error may be returned in two cases: +/// +/// * If the supplied buffer contains fewer than 10 bytes, then an error indicates that more +/// input is required to decode the full delimiter. +/// * If the supplied buffer contains 10 bytes or more, then the buffer contains an invalid +/// delimiter, and typically the buffer should be considered corrupt. +#[inline] +pub fn decode_length_delimiter(mut buf: impl Buf) -> Result { decode_varint(&mut buf) } diff --git a/patch/prost-0.14.1/src/encoding/utf8.rs b/patch/prost-0.14.1/src/encoding/utf8.rs new file mode 100644 index 0000000..747bc03 --- /dev/null +++ b/patch/prost-0.14.1/src/encoding/utf8.rs @@ -0,0 +1,216 @@ +#![allow(unused)] + +mod ascii; + +#[cfg(any( + target_feature = "sse2", + all(target_endian = "little", target_arch = "aarch64"), + all(target_endian = "little", target_feature = "neon") +))] +mod simd_funcs; + +use ascii::validate_ascii; +use ::core::intrinsics::likely; + +#[inline(always)] +fn in_inclusive_range8(i: u8, start: u8, end: u8) -> bool { + i.wrapping_sub(start) <= (end - start) +} + +#[repr(align(64))] // Align to cache lines +pub struct Utf8Data { + pub table: [u8; 384], +} + +// BEGIN GENERATED CODE. PLEASE DO NOT EDIT. 
+// Instead, please regenerate using generate-encoding-data.py + +pub static UTF8_DATA: Utf8Data = Utf8Data { + table: [ + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 148, 148, 148, + 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 148, 164, 164, 164, 164, 164, + 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, 164, + 164, 164, 164, 164, 164, 164, 164, 164, 164, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, + 252, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 16, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 32, 8, 8, 64, 8, 8, 8, 128, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + ], +}; + +// END GENERATED CODE + +pub fn utf8_valid_up_to(src: &[u8]) -> usize { + let mut read = 0; + 'outer: loop { + let mut byte = { + let src_remaining = &src[read..]; + match 
validate_ascii(src_remaining) { + None => { + return src.len(); + } + Some((non_ascii, consumed)) => { + read += consumed; + non_ascii + } + } + }; + // Check for the longest sequence to avoid checking twice for the + // multi-byte sequences. This can't overflow with 64-bit address space, + // because full 64 bits aren't in use. In the 32-bit PAE case, for this + // to overflow would mean that the source slice would be so large that + // the address space of the process would not have space for any code. + // Therefore, the slice cannot be so long that this would overflow. + if likely(read + 4 <= src.len()) { + 'inner: loop { + // At this point, `byte` is not included in `read`, because we + // don't yet know that a) the UTF-8 sequence is valid and b) that there + // is output space if it is an astral sequence. + // Inspecting the lead byte directly is faster than what the + // std lib does! + if likely(in_inclusive_range8(byte, 0xC2, 0xDF)) { + // Two-byte + let second = unsafe { *(src.get_unchecked(read + 1)) }; + if !in_inclusive_range8(second, 0x80, 0xBF) { + break 'outer; + } + read += 2; + + // Next lead (manually inlined) + if likely(read + 4 <= src.len()) { + byte = unsafe { *(src.get_unchecked(read)) }; + if byte < 0x80 { + read += 1; + continue 'outer; + } + continue 'inner; + } + break 'inner; + } + if likely(byte < 0xF0) { + 'three: loop { + // Three-byte + let second = unsafe { *(src.get_unchecked(read + 1)) }; + let third = unsafe { *(src.get_unchecked(read + 2)) }; + if ((UTF8_DATA.table[usize::from(second)] + & unsafe { *(UTF8_DATA.table.get_unchecked(byte as usize + 0x80)) }) + | (third >> 6)) + != 2 + { + break 'outer; + } + read += 3; + + // Next lead (manually inlined) + if likely(read + 4 <= src.len()) { + byte = unsafe { *(src.get_unchecked(read)) }; + if in_inclusive_range8(byte, 0xE0, 0xEF) { + continue 'three; + } + if likely(byte < 0x80) { + read += 1; + continue 'outer; + } + continue 'inner; + } + break 'inner; + } + } + // Four-byte + 
let second = unsafe { *(src.get_unchecked(read + 1)) }; + let third = unsafe { *(src.get_unchecked(read + 2)) }; + let fourth = unsafe { *(src.get_unchecked(read + 3)) }; + if (u16::from( + UTF8_DATA.table[usize::from(second)] + & unsafe { *(UTF8_DATA.table.get_unchecked(byte as usize + 0x80)) }, + ) | u16::from(third >> 6) + | (u16::from(fourth & 0xC0) << 2)) + != 0x202 + { + break 'outer; + } + read += 4; + + // Next lead + if likely(read + 4 <= src.len()) { + byte = unsafe { *(src.get_unchecked(read)) }; + if byte < 0x80 { + read += 1; + continue 'outer; + } + continue 'inner; + } + break 'inner; + } + } + // We can't have a complete 4-byte sequence, but we could still have + // one to three shorter sequences. + 'tail: loop { + // >= is better for bound check elision than == + if read >= src.len() { + break 'outer; + } + byte = src[read]; + // At this point, `byte` is not included in `read`, because we + // don't yet know that a) the UTF-8 sequence is valid and b) that there + // is output space if it is an astral sequence. + // Inspecting the lead byte directly is faster than what the + // std lib does! + if byte < 0x80 { + read += 1; + continue 'tail; + } + if in_inclusive_range8(byte, 0xC2, 0xDF) { + // Two-byte + let new_read = read + 2; + if new_read > src.len() { + break 'outer; + } + let second = src[read + 1]; + if !in_inclusive_range8(second, 0x80, 0xBF) { + break 'outer; + } + read += 2; + continue 'tail; + } + // We need to exclude valid four byte lead bytes, because + // `UTF8_DATA.second_mask` covers + if byte < 0xF0 { + // Three-byte + let new_read = read + 3; + if new_read > src.len() { + break 'outer; + } + let second = src[read + 1]; + let third = src[read + 2]; + if ((UTF8_DATA.table[usize::from(second)] + & unsafe { *(UTF8_DATA.table.get_unchecked(byte as usize + 0x80)) }) + | (third >> 6)) + != 2 + { + break 'outer; + } + read += 3; + // `'tail` handles sequences shorter than 4, so + // there can't be another sequence after this one. 
+                break 'outer;
+            }
+            break 'outer;
+        }
+    }
+    read
+}
+
+#[inline(always)]
+pub fn is_vaild_utf8(v:&[u8]) -> bool { utf8_valid_up_to(v) == v.len() }
\ No newline at end of file
diff --git a/patch/prost-0.14.1/src/encoding/utf8/ascii.rs b/patch/prost-0.14.1/src/encoding/utf8/ascii.rs
new file mode 100644
index 0000000..3a6b959
--- /dev/null
+++ b/patch/prost-0.14.1/src/encoding/utf8/ascii.rs
@@ -0,0 +1,1847 @@
+// Copyright Mozilla Foundation. See the COPYRIGHT
+// file at the top-level directory of this distribution.
+//
+// Licensed under the Apache License, Version 2.0 or the MIT license
+// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0> / <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// It's assumed that in due course Rust will have explicit SIMD but will not
+// be good at run-time selection of SIMD vs. no-SIMD. In such a future,
+// x86_64 will always use SSE2 and 32-bit x86 will use SSE2 when compiled with
+// a Mozilla-shipped rustc. SIMD support and especially detection on ARM is a
+// mess. Under the circumstances, it seems to make sense to optimize the ALU
+// case for ARMv7 rather than x86. Annoyingly, I was unable to get useful
+// numbers of the actual ARMv7 CPU I have access to, because (thermal?)
+// throttling kept interfering. Since Raspberry Pi 3 (ARMv8 core but running
+// ARMv7 code) produced reproducible performance numbers, that's the ARM
+// computer that this code ended up being optimized for in the ALU case.
+// Less popular CPU architectures simply get the approach that was chosen based
+// on Raspberry Pi 3 measurements. The UTF-16 and UTF-8 ALU cases take
+// different approaches based on benchmarking on Raspberry Pi 3.
+ +#[cfg(any( + target_feature = "sse2", + all(target_endian = "little", target_arch = "aarch64"), + all(target_endian = "little", target_feature = "neon") +))] +use super::simd_funcs::*; + +#[allow(unused_imports)] +use ::core::intrinsics::unlikely; +#[allow(unused_imports)] +use ::core::intrinsics::likely; + +use cfg_if::cfg_if; + +// Safety invariants for masks: data & mask = 0 for valid ASCII or basic latin utf-16 + +// `as` truncates, so works on 32-bit, too. +#[allow(dead_code)] +pub const ASCII_MASK: usize = 0x8080_8080_8080_8080u64 as usize; + +// `as` truncates, so works on 32-bit, too. +#[allow(dead_code)] +pub const BASIC_LATIN_MASK: usize = 0xFF80_FF80_FF80_FF80u64 as usize; + +#[allow(unused_macros)] +macro_rules! ascii_naive { + ($name:ident, $src_unit:ty, $dst_unit:ty) => { + /// Safety: src and dst must have len_unit elements and be aligned + /// Safety-usable invariant: will return Some() when it fails + /// to convert. The first value will be a u8 that is > 127. + #[inline(always)] + pub unsafe fn $name( + src: *const $src_unit, + dst: *mut $dst_unit, + len: usize, + ) -> Option<($src_unit, usize)> { + // Yes, manually omitting the bound check here matters + // a lot for perf. + for i in 0..len { + // Safety: len invariant used here + let code_unit = *(src.add(i)); + // Safety: Upholds safety-usable invariant here + if code_unit > 127 { + return Some((code_unit, i)); + } + // Safety: len invariant used here + *(dst.add(i)) = code_unit as $dst_unit; + } + return None; + } + }; +} + +#[allow(unused_macros)] +macro_rules! 
ascii_alu { + ($name:ident, + // safety invariant: src/dst MUST be u8 + $src_unit:ty, + $dst_unit:ty, + // Safety invariant: stride_fn must consume and produce two usizes, and return the index of the first non-ascii when it fails + $stride_fn:ident) => { + /// Safety: src and dst must have len elements, src is valid for read, dst is valid for + /// write + /// Safety-usable invariant: will return Some() when it fails + /// to convert. The first value will be a u8 that is > 127. + #[cfg_attr(feature = "cargo-clippy", allow(never_loop, cast_ptr_alignment))] + #[inline(always)] + pub unsafe fn $name( + src: *const $src_unit, + dst: *mut $dst_unit, + len: usize, + ) -> Option<($src_unit, usize)> { + let mut offset = 0usize; + // This loop is only broken out of as a `goto` forward + loop { + // Safety: until_alignment becomes the number of bytes we need to munch until we are aligned to usize + let mut until_alignment = { + // Check if the other unit aligns if we move the narrower unit + // to alignment. 
+ // if ::core::mem::size_of::<$src_unit>() == ::core::mem::size_of::<$dst_unit>() { + // ascii_to_ascii + let src_alignment = (src as usize) & ALU_ALIGNMENT_MASK; + let dst_alignment = (dst as usize) & ALU_ALIGNMENT_MASK; + if src_alignment != dst_alignment { + // Safety: bails early and ends up in the naïve branch where usize-alignment doesn't matter + break; + } + (ALU_ALIGNMENT - src_alignment) & ALU_ALIGNMENT_MASK + // } else if ::core::mem::size_of::<$src_unit>() < ::core::mem::size_of::<$dst_unit>() { + // ascii_to_basic_latin + // let src_until_alignment = (ALIGNMENT - ((src as usize) & ALIGNMENT_MASK)) & ALIGNMENT_MASK; + // if (dst.add(src_until_alignment) as usize) & ALIGNMENT_MASK != 0 { + // break; + // } + // src_until_alignment + // } else { + // basic_latin_to_ascii + // let dst_until_alignment = (ALIGNMENT - ((dst as usize) & ALIGNMENT_MASK)) & ALIGNMENT_MASK; + // if (src.add(dst_until_alignment) as usize) & ALIGNMENT_MASK != 0 { + // break; + // } + // dst_until_alignment + // } + }; + if until_alignment + ALU_STRIDE_SIZE <= len { + // Moving pointers to alignment seems to be a pessimization on + // x86_64 for operations that have UTF-16 as the internal + // Unicode representation. However, since it seems to be a win + // on ARM (tested ARMv7 code running on ARMv8 [rpi3]), except + // mixed results when encoding from UTF-16 and since x86 and + // x86_64 should be using SSE2 in due course, keeping the move + // to alignment here. It would be good to test on more ARM CPUs + // and on real MIPS and POWER hardware. 
+ // + // Safety: This is the naïve code once again, for `until_alignment` bytes + while until_alignment != 0 { + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety: Upholds safety-usable invariant here + return Some((code_unit, offset)); + } + *(dst.add(offset)) = code_unit as $dst_unit; + // Safety: offset is the number of bytes copied so far + offset += 1; + until_alignment -= 1; + } + let len_minus_stride = len - ALU_STRIDE_SIZE; + loop { + // Safety: num_ascii is known to be a byte index of a non-ascii byte due to stride_fn's invariant + if let Some(num_ascii) = $stride_fn( + // Safety: These are known to be valid and aligned since we have at + // least ALU_STRIDE_SIZE data in these buffers, and offset is the + // number of elements copied so far, which according to the + // until_alignment calculation above will cause both src and dst to be + // aligned to usize after this add + src.add(offset) as *const usize, + dst.add(offset) as *mut usize, + ) { + offset += num_ascii; + // Safety: Upholds safety-usable invariant here by indexing into non-ascii byte + return Some((*(src.add(offset)), offset)); + } + // Safety: offset continues to be the number of bytes copied so far, and + // maintains usize alignment for the next loop iteration + offset += ALU_STRIDE_SIZE; + // Safety: This is `offset > len - stride. This loop will continue as long as + // `offset <= len - stride`, which means there are `stride` bytes to still be read. 
+ if offset > len_minus_stride { + break; + } + } + } + break; + } + + // Safety: This is the naïve code, same as ascii_naive, and has no requirements + // other than src/dst being valid for the the right lens + while offset < len { + // Safety: len invariant used here + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety: Upholds safety-usable invariant here + return Some((code_unit, offset)); + } + // Safety: len invariant used here + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! basic_latin_alu { + ($name:ident, + // safety invariant: use u8 for src/dest for ascii, and u16 for basic_latin + $src_unit:ty, + $dst_unit:ty, + // safety invariant: stride function must munch ALU_STRIDE_SIZE*size(src_unit) bytes off of src and + // write ALU_STRIDE_SIZE*size(dst_unit) bytes to dst + $stride_fn:ident) => { + /// Safety: src and dst must have len elements, src is valid for read, dst is valid for + /// write + /// Safety-usable invariant: will return Some() when it fails + /// to convert. The first value will be a u8 that is > 127. + #[cfg_attr( + feature = "cargo-clippy", + allow(never_loop, cast_ptr_alignment, cast_lossless) + )] + #[inline(always)] + pub unsafe fn $name( + src: *const $src_unit, + dst: *mut $dst_unit, + len: usize, + ) -> Option<($src_unit, usize)> { + let mut offset = 0usize; + // This loop is only broken out of as a `goto` forward + loop { + // Safety: until_alignment becomes the number of bytes we need to munch from src/dest until we are aligned to usize + // We ensure basic-latin has the same alignment as ascii, starting with ascii since it is smaller. + let mut until_alignment = { + // Check if the other unit aligns if we move the narrower unit + // to alignment. 
+ // if ::core::mem::size_of::<$src_unit>() == ::core::mem::size_of::<$dst_unit>() { + // ascii_to_ascii + // let src_alignment = (src as usize) & ALIGNMENT_MASK; + // let dst_alignment = (dst as usize) & ALIGNMENT_MASK; + // if src_alignment != dst_alignment { + // break; + // } + // (ALIGNMENT - src_alignment) & ALIGNMENT_MASK + // } else + if ::core::mem::size_of::<$src_unit>() < ::core::mem::size_of::<$dst_unit>() { + // ascii_to_basic_latin + let src_until_alignment = (ALU_ALIGNMENT + - ((src as usize) & ALU_ALIGNMENT_MASK)) + & ALU_ALIGNMENT_MASK; + if (dst.wrapping_add(src_until_alignment) as usize) & ALU_ALIGNMENT_MASK + != 0 + { + break; + } + src_until_alignment + } else { + // basic_latin_to_ascii + let dst_until_alignment = (ALU_ALIGNMENT + - ((dst as usize) & ALU_ALIGNMENT_MASK)) + & ALU_ALIGNMENT_MASK; + if (src.wrapping_add(dst_until_alignment) as usize) & ALU_ALIGNMENT_MASK + != 0 + { + break; + } + dst_until_alignment + } + }; + if until_alignment + ALU_STRIDE_SIZE <= len { + // Moving pointers to alignment seems to be a pessimization on + // x86_64 for operations that have UTF-16 as the internal + // Unicode representation. However, since it seems to be a win + // on ARM (tested ARMv7 code running on ARMv8 [rpi3]), except + // mixed results when encoding from UTF-16 and since x86 and + // x86_64 should be using SSE2 in due course, keeping the move + // to alignment here. It would be good to test on more ARM CPUs + // and on real MIPS and POWER hardware. 
+ // + // Safety: This is the naïve code once again, for `until_alignment` bytes + while until_alignment != 0 { + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety: Upholds safety-usable invariant here + return Some((code_unit, offset)); + } + *(dst.add(offset)) = code_unit as $dst_unit; + // Safety: offset is the number of bytes copied so far + offset += 1; + until_alignment -= 1; + } + let len_minus_stride = len - ALU_STRIDE_SIZE; + loop { + if !$stride_fn( + // Safety: These are known to be valid and aligned since we have at + // least ALU_STRIDE_SIZE data in these buffers, and offset is the + // number of elements copied so far, which according to the + // until_alignment calculation above will cause both src and dst to be + // aligned to usize after this add + src.add(offset) as *const usize, + dst.add(offset) as *mut usize, + ) { + break; + } + // Safety: offset continues to be the number of bytes copied so far, and + // maintains usize alignment for the next loop iteration + offset += ALU_STRIDE_SIZE; + // Safety: This is `offset > len - stride. This loop will continue as long as + // `offset <= len - stride`, which means there are `stride` bytes to still be read. + if offset > len_minus_stride { + break; + } + } + } + break; + } + // Safety: This is the naïve code once again, for leftover bytes + while offset < len { + // Safety: len invariant used here + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety: Upholds safety-usable invariant here + return Some((code_unit, offset)); + } + // Safety: len invariant used here + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! 
latin1_alu { + // safety invariant: stride function must munch ALU_STRIDE_SIZE*size(src_unit) bytes off of src and + // write ALU_STRIDE_SIZE*size(dst_unit) bytes to dst + ($name:ident, $src_unit:ty, $dst_unit:ty, $stride_fn:ident) => { + /// Safety: src and dst must have len elements, src is valid for read, dst is valid for + /// write + #[cfg_attr( + feature = "cargo-clippy", + allow(never_loop, cast_ptr_alignment, cast_lossless) + )] + #[inline(always)] + pub unsafe fn $name(src: *const $src_unit, dst: *mut $dst_unit, len: usize) { + let mut offset = 0usize; + // This loop is only broken out of as a `goto` forward + loop { + // Safety: until_alignment becomes the number of bytes we need to munch from src/dest until we are aligned to usize + // We ensure the UTF-16 side has the same alignment as the Latin-1 side, starting with Latin-1 since it is smaller. + let mut until_alignment = { + if ::core::mem::size_of::<$src_unit>() < ::core::mem::size_of::<$dst_unit>() { + // unpack + let src_until_alignment = (ALU_ALIGNMENT + - ((src as usize) & ALU_ALIGNMENT_MASK)) + & ALU_ALIGNMENT_MASK; + if (dst.wrapping_add(src_until_alignment) as usize) & ALU_ALIGNMENT_MASK + != 0 + { + break; + } + src_until_alignment + } else { + // pack + let dst_until_alignment = (ALU_ALIGNMENT + - ((dst as usize) & ALU_ALIGNMENT_MASK)) + & ALU_ALIGNMENT_MASK; + if (src.wrapping_add(dst_until_alignment) as usize) & ALU_ALIGNMENT_MASK + != 0 + { + break; + } + dst_until_alignment + } + }; + if until_alignment + ALU_STRIDE_SIZE <= len { + // Safety: This is the naïve code once again, for `until_alignment` bytes + while until_alignment != 0 { + let code_unit = *(src.add(offset)); + *(dst.add(offset)) = code_unit as $dst_unit; + // Safety: offset is the number of bytes copied so far + offset += 1; + until_alignment -= 1; + } + let len_minus_stride = len - ALU_STRIDE_SIZE; + loop { + $stride_fn( + // Safety: These are known to be valid and aligned since we have at + // least ALU_STRIDE_SIZE data 
in these buffers, and offset is the + // number of elements copied so far, which according to the + // until_alignment calculation above will cause both src and dst to be + // aligned to usize after this add + src.add(offset) as *const usize, + dst.add(offset) as *mut usize, + ); + // Safety: offset continues to be the number of bytes copied so far, and + // maintains usize alignment for the next loop iteration + offset += ALU_STRIDE_SIZE; + // Safety: This is `offset > len - stride. This loop will continue as long as + // `offset <= len - stride`, which means there are `stride` bytes to still be read. + if offset > len_minus_stride { + break; + } + } + } + break; + } + // Safety: This is the naïve code once again, for leftover bytes + while offset < len { + // Safety: len invariant used here + let code_unit = *(src.add(offset)); + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_simd_check_align { + ( + $name:ident, + $src_unit:ty, + $dst_unit:ty, + // Safety: This function must require aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_both_aligned:ident, + // Safety: This function must require aligned/unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_src_aligned:ident, + // Safety: This function must require unaligned/aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_dst_aligned:ident, + // Safety: This function must require unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_neither_aligned:ident + ) => { + /// Safety: src/dst must be valid for reads/writes of `len` elements of their units. 
+ /// + /// Safety-usable invariant: will return Some() when it encounters non-ASCII, with the first element in the Some being + /// guaranteed to be non-ASCII (> 127), and the second being the offset where it is found + #[inline(always)] + pub unsafe fn $name( + src: *const $src_unit, + dst: *mut $dst_unit, + len: usize, + ) -> Option<($src_unit, usize)> { + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `SIMD_STRIDE_SIZE` elements. + if SIMD_STRIDE_SIZE <= len { + let len_minus_stride = len - SIMD_STRIDE_SIZE; + // XXX Should we first process one stride unconditionally as unaligned to + // avoid the cost of the branchiness below if the first stride fails anyway? + // XXX Should we just use unaligned SSE2 access unconditionally? It seems that + // on Haswell, it would make sense to just use unaligned and not bother + // checking. Need to benchmark older architectures before deciding. + let dst_masked = (dst as usize) & SIMD_ALIGNMENT_MASK; + // Safety: checking whether src is aligned + if ((src as usize) & SIMD_ALIGNMENT_MASK) == 0 { + // Safety: Checking whether dst is aligned + if dst_masked == 0 { + loop { + // Safety: We're valid to read/write SIMD_STRIDE_SIZE elements and have the appropriate alignments + if !$stride_both_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE` which means we always have at least `SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride { + break; + } + } + } else { + loop { + // Safety: We're valid to read/write SIMD_STRIDE_SIZE elements and have the appropriate alignments + if !$stride_src_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE` which means we always have at least `SIMD_STRIDE_SIZE` elements to munch next time. 
+ if offset > len_minus_stride { + break; + } + } + } + } else { + if dst_masked == 0 { + loop { + // Safety: We're valid to read/write SIMD_STRIDE_SIZE elements and have the appropriate alignments + if !$stride_dst_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE` which means we always have at least `SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride { + break; + } + } + } else { + loop { + // Safety: We're valid to read/write SIMD_STRIDE_SIZE elements and have the appropriate alignments + if !$stride_neither_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE` which means we always have at least `SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride { + break; + } + } + } + } + } + while offset < len { + // Safety: uses len invariant here and below + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety: upholds safety-usable invariant + return Some((code_unit, offset)); + } + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! 
ascii_simd_check_align_unrolled {
+    (
+        $name:ident,
+        $src_unit:ty,
+        $dst_unit:ty,
+        // Safety: This function must require aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit
+        $stride_both_aligned:ident,
+        // Safety: This function must require aligned/unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit
+        $stride_src_aligned:ident,
+        // Safety: This function must require unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit
+        $stride_neither_aligned:ident,
+        // Safety: This function must require aligned src/dest that are valid for reading/writing 2*SIMD_STRIDE_SIZE src_unit/dst_unit
+        $double_stride_both_aligned:ident,
+        // Safety: This function must require aligned/unaligned src/dest that are valid for reading/writing 2*SIMD_STRIDE_SIZE src_unit/dst_unit
+        $double_stride_src_aligned:ident
+    ) => {
+        /// Safety: src/dst must be valid for reads/writes of `len` elements of their units.
+        ///
+        /// Safety-usable invariant: will return Some() when it encounters non-ASCII, with the first element in the Some being
+        /// guaranteed to be non-ASCII (> 127), and the second being the offset where it is found
+        #[inline(always)] pub unsafe fn $name(
+            src: *const $src_unit,
+            dst: *mut $dst_unit,
+            len: usize,
+        ) -> Option<($src_unit, usize)> {
+            let unit_size = ::core::mem::size_of::<$src_unit>();
+            let mut offset = 0usize;
+            // This loop is only broken out of as a goto forward without
+            // actually looping
+            'outer: loop {
+                // Safety: if this check succeeds we're valid for reading/writing at least `SIMD_STRIDE_SIZE` elements.
+                if SIMD_STRIDE_SIZE <= len {
+                    // First, process one unaligned
+                    // Safety: this is safe to call since we're valid for this read/write
+                    if !$stride_neither_aligned(src, dst) {
+                        break 'outer;
+                    }
+                    offset = SIMD_STRIDE_SIZE;
+
+                    // We have now seen 16 ASCII bytes.
Let's guess that + // there will be enough more to justify more expense + // in the case of non-ASCII. + // Use aligned reads for the sake of old microachitectures. + // + // Safety: this correctly calculates the number of src_units that need to be read before the remaining list is aligned. + // This is less that SIMD_ALIGNMENT, which is also SIMD_STRIDE_SIZE (as documented) + let until_alignment = ((SIMD_ALIGNMENT + - ((src.add(offset) as usize) & SIMD_ALIGNMENT_MASK)) + & SIMD_ALIGNMENT_MASK) + / unit_size; + // Safety: This addition won't overflow, because even in the 32-bit PAE case the + // address space holds enough code that the slice length can't be that + // close to address space size. + // offset now equals SIMD_STRIDE_SIZE, hence times 3 below. + // + // Safety: if this check succeeds we're valid for reading/writing at least `2 * SIMD_STRIDE_SIZE` elements plus `until_alignment`. + // The extra SIMD_STRIDE_SIZE in the condition is because `offset` is already `SIMD_STRIDE_SIZE`. + if until_alignment + (SIMD_STRIDE_SIZE * 3) <= len { + if until_alignment != 0 { + // Safety: this is safe to call since we're valid for this read/write (and more), and don't care about alignment + // This will copy over bytes that get decoded twice since it's not incrementing `offset` by SIMD_STRIDE_SIZE. This is fine. + if !$stride_neither_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += until_alignment; + } + // Safety: At this point we're valid for reading/writing 2*SIMD_STRIDE_SIZE elements + // Safety: Now `offset` is aligned for `src` + let len_minus_stride_times_two = len - (SIMD_STRIDE_SIZE * 2); + // Safety: This is whether dst is aligned + let dst_masked = (dst.add(offset) as usize) & SIMD_ALIGNMENT_MASK; + if dst_masked == 0 { + loop { + // Safety: both are aligned, we can call the aligned function. 
We're valid for reading/writing double stride from the initial condition + // and the loop break condition below + if let Some(advance) = + $double_stride_both_aligned(src.add(offset), dst.add(offset)) + { + offset += advance; + let code_unit = *(src.add(offset)); + // Safety: uses safety-usable invariant on ascii_to_ascii_simd_double_stride to return + // guaranteed non-ascii + return Some((code_unit, offset)); + } + offset += SIMD_STRIDE_SIZE * 2; + // Safety: This is `offset > len - 2 * SIMD_STRIDE_SIZE` which means we always have at least `2 * SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride_times_two { + break; + } + } + // Safety: We're valid for reading/writing one more, and can still assume alignment + if offset + SIMD_STRIDE_SIZE <= len { + if !$stride_both_aligned(src.add(offset), dst.add(offset)) { + break 'outer; + } + offset += SIMD_STRIDE_SIZE; + } + } else { + loop { + // Safety: only src is aligned here. We're valid for reading/writing double stride from the initial condition + // and the loop break condition below + if let Some(advance) = + $double_stride_src_aligned(src.add(offset), dst.add(offset)) + { + offset += advance; + let code_unit = *(src.add(offset)); + // Safety: uses safety-usable invariant on ascii_to_ascii_simd_double_stride to return + // guaranteed non-ascii + return Some((code_unit, offset)); + } + offset += SIMD_STRIDE_SIZE * 2; + // Safety: This is `offset > len - 2 * SIMD_STRIDE_SIZE` which means we always have at least `2 * SIMD_STRIDE_SIZE` elements to munch next time. 
+ + if offset > len_minus_stride_times_two { + break; + } + } + // Safety: We're valid for reading/writing one more, and can still assume alignment + if offset + SIMD_STRIDE_SIZE <= len { + if !$stride_src_aligned(src.add(offset), dst.add(offset)) { + break 'outer; + } + offset += SIMD_STRIDE_SIZE; + } + } + } else { + // At most two iterations, so unroll + if offset + SIMD_STRIDE_SIZE <= len { + // Safety: The check above ensures we're allowed to read/write this, and we don't use alignment + if !$stride_neither_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + if offset + SIMD_STRIDE_SIZE <= len { + // Safety: The check above ensures we're allowed to read/write this, and we don't use alignment + if !$stride_neither_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + } + } + } + } + break 'outer; + } + while offset < len { + // Safety: relies straightforwardly on the `len` invariant + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety-usable invariant upheld here + return Some((code_unit, offset)); + } + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! 
latin1_simd_check_align { + ( + $name:ident, + $src_unit:ty, + $dst_unit:ty, + // Safety: This function must require aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_both_aligned:ident, + // Safety: This function must require aligned/unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_src_aligned:ident, + // Safety: This function must require unaligned/aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_dst_aligned:ident, + // Safety: This function must require unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_neither_aligned:ident + + ) => { + /// Safety: src/dst must be valid for reads/writes of `len` elements of their units. + #[inline(always)] + pub unsafe fn $name(src: *const $src_unit, dst: *mut $dst_unit, len: usize) { + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `SIMD_STRIDE_SIZE` elements. + if SIMD_STRIDE_SIZE <= len { + let len_minus_stride = len - SIMD_STRIDE_SIZE; + // Whether dst is aligned + let dst_masked = (dst as usize) & SIMD_ALIGNMENT_MASK; + // Whether src is aligned + if ((src as usize) & SIMD_ALIGNMENT_MASK) == 0 { + if dst_masked == 0 { + loop { + // Safety: Both were aligned, we can use the aligned function + $stride_both_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE`, which means in the next iteration we're valid for + // reading/writing at least SIMD_STRIDE_SIZE elements. 
+ if offset > len_minus_stride { + break; + } + } + } else { + loop { + // Safety: src was aligned, dst was not + $stride_src_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE`, which means in the next iteration we're valid for + // reading/writing at least SIMD_STRIDE_SIZE elements. + if offset > len_minus_stride { + break; + } + } + } + } else { + if dst_masked == 0 { + loop { + // Safety: src was aligned, dst was not + $stride_dst_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE`, which means in the next iteration we're valid for + // reading/writing at least SIMD_STRIDE_SIZE elements. + if offset > len_minus_stride { + break; + } + } + } else { + loop { + // Safety: Neither were aligned + $stride_neither_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - SIMD_STRIDE_SIZE`, which means in the next iteration we're valid for + // reading/writing at least SIMD_STRIDE_SIZE elements. + if offset > len_minus_stride { + break; + } + } + } + } + } + while offset < len { + // Safety: relies straightforwardly on the `len` invariant + let code_unit = *(src.add(offset)); + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! 
latin1_simd_check_align_unrolled { + ( + $name:ident, + $src_unit:ty, + $dst_unit:ty, + // Safety: This function must require aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_both_aligned:ident, + // Safety: This function must require aligned/unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_src_aligned:ident, + // Safety: This function must require unaligned/aligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_dst_aligned:ident, + // Safety: This function must require unaligned src/dest that are valid for reading/writing SIMD_STRIDE_SIZE src_unit/dst_unit + $stride_neither_aligned:ident + ) => { + /// Safety: src/dst must be valid for reads/writes of `len` elements of their units. + #[inline(always)] + pub unsafe fn $name(src: *const $src_unit, dst: *mut $dst_unit, len: usize) { + let unit_size = ::core::mem::size_of::<$src_unit>(); + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `SIMD_STRIDE_SIZE` elements. + if SIMD_STRIDE_SIZE <= len { + // Safety: this correctly calculates the number of src_units that need to be read before the remaining list is aligned. + // This is by definition less than SIMD_STRIDE_SIZE. + let mut until_alignment = ((SIMD_STRIDE_SIZE + - ((src as usize) & SIMD_ALIGNMENT_MASK)) + & SIMD_ALIGNMENT_MASK) + / unit_size; + while until_alignment != 0 { + // Safety: This is a straightforward copy, since until_alignment is < SIMD_STRIDE_SIZE < len, this is in-bounds + *(dst.add(offset)) = *(src.add(offset)) as $dst_unit; + offset += 1; + until_alignment -= 1; + } + // Safety: here offset will be `until_alignment`, i.e. enough to align `src`. + let len_minus_stride = len - SIMD_STRIDE_SIZE; + // Safety: if this check succeeds we're valid for reading/writing at least `2 * SIMD_STRIDE_SIZE` elements. 
+ if offset + SIMD_STRIDE_SIZE * 2 <= len { + let len_minus_stride_times_two = len_minus_stride - SIMD_STRIDE_SIZE; + // Safety: at this point src is known to be aligned at offset, dst is not. + if (dst.add(offset) as usize) & SIMD_ALIGNMENT_MASK == 0 { + loop { + // Safety: We checked alignment of dst above, we can use the alignment functions. We're allowed to read/write 2*SIMD_STRIDE_SIZE elements, which we do. + $stride_both_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + $stride_both_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - 2 * SIMD_STRIDE_SIZE` which means we always have at least `2 * SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride_times_two { + break; + } + } + } else { + loop { + // Safety: we ensured alignment of src already. + $stride_src_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + $stride_src_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // Safety: This is `offset > len - 2 * SIMD_STRIDE_SIZE` which means we always have at least `2 * SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride_times_two { + break; + } + } + } + } + // Safety: This is `offset > len - SIMD_STRIDE_SIZE` which means we are valid to munch SIMD_STRIDE_SIZE more elements, which we do + if offset < len_minus_stride { + $stride_src_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + } + } + while offset < len { + // Safety: uses len invariant here and below + let code_unit = *(src.add(offset)); + // On x86_64, this loop autovectorizes but in the pack + // case there are instructions whose purpose is to make sure + // each u16 in the vector is truncated before packing. However, + // since we don't care about saturating behavior of SSE2 packing + // when the input isn't Latin1, those instructions are useless. 
+ // Unfortunately, using the `assume` intrinsic to lie to the + // optimizer doesn't make LLVM omit the trunctation that we + // don't need. Possibly this loop could be manually optimized + // to do the sort of thing that LLVM does but without the + // ANDing the read vectors of u16 with a constant that discards + // the high half of each u16. As far as I can tell, the + // optimization assumes that doing a SIMD read past the end of + // the array is OK. + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_simd_unalign { + // Safety: stride_neither_aligned must be a function that requires src/dest be valid for unaligned reads/writes for SIMD_STRIDE_SIZE elements of type src_unit/dest_unit + ($name:ident, $src_unit:ty, $dst_unit:ty, $stride_neither_aligned:ident) => { + /// Safety: src and dst must be valid for reads/writes of len elements of type src_unit/dst_unit + /// + /// Safety-usable invariant: will return Some() when it encounters non-ASCII, with the first element in the Some being + /// guaranteed to be non-ASCII (> 127), and the second being the offset where it is found + #[inline(always)] + pub unsafe fn $name( + src: *const $src_unit, + dst: *mut $dst_unit, + len: usize, + ) -> Option<($src_unit, usize)> { + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `stride` elements. + if SIMD_STRIDE_SIZE <= len { + let len_minus_stride = len - SIMD_STRIDE_SIZE; + loop { + // Safety: We know we're valid for `stride` reads/writes, so we can call this function. We don't need alignment. + if !$stride_neither_aligned(src.add(offset), dst.add(offset)) { + break; + } + offset += SIMD_STRIDE_SIZE; + // This is `offset > len - stride` which means we always have at least `stride` elements to munch next time. 
+ if offset > len_minus_stride { + break; + } + } + } + while offset < len { + // Safety: Uses len invariant here and below + let code_unit = *(src.add(offset)); + if code_unit > 127 { + // Safety-usable invariant upheld here + return Some((code_unit, offset)); + } + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! latin1_simd_unalign { + // Safety: stride_neither_aligned must be a function that requires src/dest be valid for unaligned reads/writes for SIMD_STRIDE_SIZE elements of type src_unit/dest_unit + ($name:ident, $src_unit:ty, $dst_unit:ty, $stride_neither_aligned:ident) => { + /// Safety: src and dst must be valid for unaligned reads/writes of len elements of type src_unit/dst_unit + #[inline(always)] + pub unsafe fn $name(src: *const $src_unit, dst: *mut $dst_unit, len: usize) { + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `stride` elements. + if SIMD_STRIDE_SIZE <= len { + let len_minus_stride = len - SIMD_STRIDE_SIZE; + loop { + // Safety: We know we're valid for `stride` reads/writes, so we can call this function. We don't need alignment. + $stride_neither_aligned(src.add(offset), dst.add(offset)); + offset += SIMD_STRIDE_SIZE; + // This is `offset > len - stride` which means we always have at least `stride` elements to munch next time. + if offset > len_minus_stride { + break; + } + } + } + while offset < len { + // Safety: Uses len invariant here + let code_unit = *(src.add(offset)); + *(dst.add(offset)) = code_unit as $dst_unit; + offset += 1; + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_to_ascii_simd_stride { + // Safety: load/store must be valid for 16 bytes of read/write, which may be unaligned. 
(candidates: `(load|store)(16|8)_(unaligned|aligned)` functions) + ($name:ident, $load:ident, $store:ident) => { + /// Safety: src and dst must be valid for 16 bytes of read/write according to + /// the $load/$store fn, which may allow for unaligned reads/writes or require + /// alignment to either 16x8 or u8x16. + #[inline(always)] + pub unsafe fn $name(src: *const u8, dst: *mut u8) -> bool { + let simd = $load(src); + if !simd_is_ascii(simd) { + return false; + } + $store(dst, simd); + true + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_to_ascii_simd_double_stride { + // Safety: store must be valid for 32 bytes of write, which may be unaligned (candidates: `store(8|16)_(aligned|unaligned)`) + ($name:ident, $store:ident) => { + /// Safety: src must be valid for 32 bytes of aligned u8x16 read + /// dst must be valid for 32 bytes of unaligned write according to + /// the $store fn, which may allow for unaligned writes or require + /// alignment to either 16x8 or u8x16. + /// + /// Safety-usable invariant: Returns Some(index) if the element at `index` is invalid ASCII + #[inline(always)] + pub unsafe fn $name(src: *const u8, dst: *mut u8) -> Option { + let first = load16_aligned(src); + let second = load16_aligned(src.add(SIMD_STRIDE_SIZE)); + $store(dst, first); + if unlikely(!simd_is_ascii(first | second)) { + // Safety: mask_ascii produces a mask of all the high bits. + let mask_first = mask_ascii(first); + if mask_first != 0 { + // Safety: on little endian systems this will be the number of ascii bytes + // before the first non-ascii, i.e. valid for indexing src + // TODO SAFETY: What about big-endian systems? + return Some(mask_first.trailing_zeros() as usize); + } + $store(dst.add(SIMD_STRIDE_SIZE), second); + let mask_second = mask_ascii(second); + // Safety: on little endian systems this will be the number of ascii bytes + // before the first non-ascii, i.e. 
valid for indexing src + return Some(SIMD_STRIDE_SIZE + mask_second.trailing_zeros() as usize); + } + $store(dst.add(SIMD_STRIDE_SIZE), second); + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_to_basic_latin_simd_stride { + // Safety: load/store must be valid for 16 bytes of read/write, which may be unaligned. (candidates: `(load|store)(16|8)_(unaligned|aligned)` functions) + ($name:ident, $load:ident, $store:ident) => { + /// Safety: src and dst must be valid for 16/32 bytes of read/write according to + /// the $load/$store fn, which may allow for unaligned reads/writes or require + /// alignment to either 16x8 or u8x16. + #[inline(always)] + pub unsafe fn $name(src: *const u8, dst: *mut u16) -> bool { + let simd = $load(src); + if !simd_is_ascii(simd) { + return false; + } + let (first, second) = simd_unpack(simd); + $store(dst, first); + $store(dst.add(8), second); + true + } + }; +} + +#[allow(unused_macros)] +macro_rules! ascii_to_basic_latin_simd_double_stride { + // Safety: store must be valid for 16 bytes of write, which may be unaligned + ($name:ident, $store:ident) => { + /// Safety: src must be valid for 2*SIMD_STRIDE_SIZE bytes of aligned reads, + /// aligned to either 16x8 or u8x16. 
+ /// dst must be valid for 2*SIMD_STRIDE_SIZE bytes of aligned or unaligned reads + #[inline(always)] + pub unsafe fn $name(src: *const u8, dst: *mut u16) -> Option { + let first = load16_aligned(src); + let second = load16_aligned(src.add(SIMD_STRIDE_SIZE)); + let (a, b) = simd_unpack(first); + $store(dst, a); + // Safety: divide by 2 since it's a u16 pointer + $store(dst.add(SIMD_STRIDE_SIZE / 2), b); + if unlikely(!simd_is_ascii(first | second)) { + let mask_first = mask_ascii(first); + if mask_first != 0 { + return Some(mask_first.trailing_zeros() as usize); + } + let (c, d) = simd_unpack(second); + $store(dst.add(SIMD_STRIDE_SIZE), c); + $store(dst.add(SIMD_STRIDE_SIZE + (SIMD_STRIDE_SIZE / 2)), d); + let mask_second = mask_ascii(second); + return Some(SIMD_STRIDE_SIZE + mask_second.trailing_zeros() as usize); + } + let (c, d) = simd_unpack(second); + $store(dst.add(SIMD_STRIDE_SIZE), c); + $store(dst.add(SIMD_STRIDE_SIZE + (SIMD_STRIDE_SIZE / 2)), d); + None + } + }; +} + +#[allow(unused_macros)] +macro_rules! unpack_simd_stride { + // Safety: load/store must be valid for 16 bytes of read/write, which may be unaligned. (candidates: `(load|store)(16|8)_(unaligned|aligned)` functions) + ($name:ident, $load:ident, $store:ident) => { + /// Safety: src and dst must be valid for 16 bytes of read/write according to + /// the $load/$store fn, which may allow for unaligned reads/writes or require + /// alignment to either 16x8 or u8x16. + #[inline(always)] + pub unsafe fn $name(src: *const u8, dst: *mut u16) { + let simd = $load(src); + let (first, second) = simd_unpack(simd); + $store(dst, first); + $store(dst.add(8), second); + } + }; +} + +#[allow(unused_macros)] +macro_rules! basic_latin_to_ascii_simd_stride { + // Safety: load/store must be valid for 16 bytes of read/write, which may be unaligned. 
(candidates: `(load|store)(16|8)_(unaligned|aligned)` functions) + ($name:ident, $load:ident, $store:ident) => { + /// Safety: src and dst must be valid for 32/16 bytes of read/write according to + /// the $load/$store fn, which may allow for unaligned reads/writes or require + /// alignment to either 16x8 or u8x16. + #[inline(always)] + pub unsafe fn $name(src: *const u16, dst: *mut u8) -> bool { + let first = $load(src); + let second = $load(src.add(8)); + if simd_is_basic_latin(first | second) { + $store(dst, simd_pack(first, second)); + true + } else { + false + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! pack_simd_stride { + // Safety: load/store must be valid for 16 bytes of read/write, which may be unaligned. (candidates: `(load|store)(16|8)_(unaligned|aligned)` functions) + ($name:ident, $load:ident, $store:ident) => { + /// Safety: src and dst must be valid for 32/16 bytes of read/write according to + /// the $load/$store fn, which may allow for unaligned reads/writes or require + /// alignment to either 16x8 or u8x16. + #[inline(always)] + pub unsafe fn $name(src: *const u16, dst: *mut u8) { + let first = $load(src); + let second = $load(src.add(8)); + $store(dst, simd_pack(first, second)); + } + }; +} + +cfg_if! { + if #[cfg(all(target_endian = "little", target_arch = "aarch64"))] { + // SIMD with the same instructions for aligned and unaligned loads and stores + + pub const SIMD_STRIDE_SIZE: usize = 16; + + pub const MAX_STRIDE_SIZE: usize = 16; + +// pub const ALIGNMENT: usize = 8; + + pub const ALU_STRIDE_SIZE: usize = 16; + + pub const ALU_ALIGNMENT: usize = 8; + + pub const ALU_ALIGNMENT_MASK: usize = 7; + + // Safety for stride macros: We stick to the load8_aligned/etc family of functions. We consistently produce + // neither_unaligned variants using only unaligned inputs. 
+ ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_neither_aligned, load16_unaligned, store16_unaligned); + + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_neither_aligned, load16_unaligned, store8_unaligned); + unpack_simd_stride!(unpack_stride_neither_aligned, load16_unaligned, store8_unaligned); + + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_neither_aligned, load8_unaligned, store16_unaligned); + pack_simd_stride!(pack_stride_neither_aligned, load8_unaligned, store16_unaligned); + + // Safety for conversion macros: We use the unalign macro with unalign functions above. All stride functions were produced + // by stride macros that universally munch a single SIMD_STRIDE_SIZE worth of elements. + ascii_simd_unalign!(ascii_to_ascii, u8, u8, ascii_to_ascii_stride_neither_aligned); + ascii_simd_unalign!(ascii_to_basic_latin, u8, u16, ascii_to_basic_latin_stride_neither_aligned); + ascii_simd_unalign!(basic_latin_to_ascii, u16, u8, basic_latin_to_ascii_stride_neither_aligned); + latin1_simd_unalign!(unpack_latin1, u8, u16, unpack_stride_neither_aligned); + latin1_simd_unalign!(pack_latin1, u16, u8, pack_stride_neither_aligned); + } else if #[cfg(all(target_endian = "little", target_feature = "neon"))] { + // SIMD with different instructions for aligned and unaligned loads and stores. + // + // Newer microarchitectures are not supposed to have a performance difference between + // aligned and unaligned SSE2 loads and stores when the address is actually aligned, + // but the benchmark results I see don't agree. + + pub const SIMD_STRIDE_SIZE: usize = 16; + + pub const MAX_STRIDE_SIZE: usize = 16; + + pub const SIMD_ALIGNMENT_MASK: usize = 15; + + // Safety for stride macros: We stick to the load8_aligned/etc family of functions. 
We consistently name + // aligned/unaligned functions according to src/dst being aligned/unaligned + + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_both_aligned, load16_aligned, store16_aligned); + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_src_aligned, load16_aligned, store16_unaligned); + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_dst_aligned, load16_unaligned, store16_aligned); + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_neither_aligned, load16_unaligned, store16_unaligned); + + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_both_aligned, load16_aligned, store8_aligned); + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_src_aligned, load16_aligned, store8_unaligned); + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_dst_aligned, load16_unaligned, store8_aligned); + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_neither_aligned, load16_unaligned, store8_unaligned); + + unpack_simd_stride!(unpack_stride_both_aligned, load16_aligned, store8_aligned); + unpack_simd_stride!(unpack_stride_src_aligned, load16_aligned, store8_unaligned); + unpack_simd_stride!(unpack_stride_dst_aligned, load16_unaligned, store8_aligned); + unpack_simd_stride!(unpack_stride_neither_aligned, load16_unaligned, store8_unaligned); + + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_both_aligned, load8_aligned, store16_aligned); + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_src_aligned, load8_aligned, store16_unaligned); + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_dst_aligned, load8_unaligned, store16_aligned); + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_neither_aligned, load8_unaligned, store16_unaligned); + + pack_simd_stride!(pack_stride_both_aligned, load8_aligned, store16_aligned); + pack_simd_stride!(pack_stride_src_aligned, load8_aligned, store16_unaligned); + pack_simd_stride!(pack_stride_dst_aligned, load8_unaligned, 
store16_aligned); + pack_simd_stride!(pack_stride_neither_aligned, load8_unaligned, store16_unaligned); + + // Safety for conversion macros: We use the correct pattern of both/src/dst/neither here. All stride functions were produced + // by stride macros that universally munch a single SIMD_STRIDE_SIZE worth of elements. + + ascii_simd_check_align!(ascii_to_ascii, u8, u8, ascii_to_ascii_stride_both_aligned, ascii_to_ascii_stride_src_aligned, ascii_to_ascii_stride_dst_aligned, ascii_to_ascii_stride_neither_aligned); + ascii_simd_check_align!(ascii_to_basic_latin, u8, u16, ascii_to_basic_latin_stride_both_aligned, ascii_to_basic_latin_stride_src_aligned, ascii_to_basic_latin_stride_dst_aligned, ascii_to_basic_latin_stride_neither_aligned); + ascii_simd_check_align!(basic_latin_to_ascii, u16, u8, basic_latin_to_ascii_stride_both_aligned, basic_latin_to_ascii_stride_src_aligned, basic_latin_to_ascii_stride_dst_aligned, basic_latin_to_ascii_stride_neither_aligned); + latin1_simd_check_align!(unpack_latin1, u8, u16, unpack_stride_both_aligned, unpack_stride_src_aligned, unpack_stride_dst_aligned, unpack_stride_neither_aligned); + latin1_simd_check_align!(pack_latin1, u16, u8, pack_stride_both_aligned, pack_stride_src_aligned, pack_stride_dst_aligned, pack_stride_neither_aligned); + } else if #[cfg(target_feature = "sse2")] { + // SIMD with different instructions for aligned and unaligned loads and stores. + // + // Newer microarchitectures are not supposed to have a performance difference between + // aligned and unaligned SSE2 loads and stores when the address is actually aligned, + // but the benchmark results I see don't agree. 
+ + pub const SIMD_STRIDE_SIZE: usize = 16; + + /// Safety-usable invariant: This should be identical to SIMD_STRIDE_SIZE (used by ascii_simd_check_align_unrolled) + pub const SIMD_ALIGNMENT: usize = 16; + + pub const MAX_STRIDE_SIZE: usize = 16; + + pub const SIMD_ALIGNMENT_MASK: usize = 15; + + // Safety for stride macros: We stick to the load8_aligned/etc family of functions. We consistently name + // aligned/unaligned functions according to src/dst being aligned/unaligned + + ascii_to_ascii_simd_double_stride!(ascii_to_ascii_simd_double_stride_both_aligned, store16_aligned); + ascii_to_ascii_simd_double_stride!(ascii_to_ascii_simd_double_stride_src_aligned, store16_unaligned); + + ascii_to_basic_latin_simd_double_stride!(ascii_to_basic_latin_simd_double_stride_both_aligned, store8_aligned); + ascii_to_basic_latin_simd_double_stride!(ascii_to_basic_latin_simd_double_stride_src_aligned, store8_unaligned); + + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_both_aligned, load16_aligned, store16_aligned); + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_src_aligned, load16_aligned, store16_unaligned); + ascii_to_ascii_simd_stride!(ascii_to_ascii_stride_neither_aligned, load16_unaligned, store16_unaligned); + + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_both_aligned, load16_aligned, store8_aligned); + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_src_aligned, load16_aligned, store8_unaligned); + ascii_to_basic_latin_simd_stride!(ascii_to_basic_latin_stride_neither_aligned, load16_unaligned, store8_unaligned); + + unpack_simd_stride!(unpack_stride_both_aligned, load16_aligned, store8_aligned); + unpack_simd_stride!(unpack_stride_src_aligned, load16_aligned, store8_unaligned); + + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_both_aligned, load8_aligned, store16_aligned); + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_src_aligned, load8_aligned, store16_unaligned); + 
basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_dst_aligned, load8_unaligned, store16_aligned); + basic_latin_to_ascii_simd_stride!(basic_latin_to_ascii_stride_neither_aligned, load8_unaligned, store16_unaligned); + + pack_simd_stride!(pack_stride_both_aligned, load8_aligned, store16_aligned); + pack_simd_stride!(pack_stride_src_aligned, load8_aligned, store16_unaligned); + + // Safety for conversion macros: We use the correct pattern of both/src/dst/neither/double_both/double_src here. All stride functions were produced + // by stride macros that universally munch a single SIMD_STRIDE_SIZE worth of elements. + + ascii_simd_check_align_unrolled!(ascii_to_ascii, u8, u8, ascii_to_ascii_stride_both_aligned, ascii_to_ascii_stride_src_aligned, ascii_to_ascii_stride_neither_aligned, ascii_to_ascii_simd_double_stride_both_aligned, ascii_to_ascii_simd_double_stride_src_aligned); + ascii_simd_check_align_unrolled!(ascii_to_basic_latin, u8, u16, ascii_to_basic_latin_stride_both_aligned, ascii_to_basic_latin_stride_src_aligned, ascii_to_basic_latin_stride_neither_aligned, ascii_to_basic_latin_simd_double_stride_both_aligned, ascii_to_basic_latin_simd_double_stride_src_aligned); + + ascii_simd_check_align!(basic_latin_to_ascii, u16, u8, basic_latin_to_ascii_stride_both_aligned, basic_latin_to_ascii_stride_src_aligned, basic_latin_to_ascii_stride_dst_aligned, basic_latin_to_ascii_stride_neither_aligned); + latin1_simd_check_align_unrolled!(unpack_latin1, u8, u16, unpack_stride_both_aligned, unpack_stride_src_aligned, unpack_stride_dst_aligned, unpack_stride_neither_aligned); + latin1_simd_check_align_unrolled!(pack_latin1, u16, u8, pack_stride_both_aligned, pack_stride_src_aligned, pack_stride_dst_aligned, pack_stride_neither_aligned); + } else if #[cfg(all(target_endian = "little", target_pointer_width = "64"))] { + // Aligned ALU word, little-endian, 64-bit + + /// Safety invariant: this is the amount of bytes consumed by + /// unpack_alu. 
This will be twice the pointer width, as it consumes two usizes. + /// This is also the number of bytes produced by pack_alu. + /// This is also the number of u16 code units produced/consumed by unpack_alu/pack_alu respectively. + pub const ALU_STRIDE_SIZE: usize = 16; + + pub const MAX_STRIDE_SIZE: usize = 16; + + // Safety invariant: this is the pointer width in bytes + pub const ALU_ALIGNMENT: usize = 8; + + // Safety invariant: this is a mask for getting the bits of a pointer not aligned to ALU_ALIGNMENT + pub const ALU_ALIGNMENT_MASK: usize = 7; + + /// Safety: dst must point to valid space for writing four `usize`s + #[inline(always)] + unsafe fn unpack_alu(word: usize, second_word: usize, dst: *mut usize) { + let first = ((0x0000_0000_FF00_0000usize & word) << 24) | + ((0x0000_0000_00FF_0000usize & word) << 16) | + ((0x0000_0000_0000_FF00usize & word) << 8) | + (0x0000_0000_0000_00FFusize & word); + let second = ((0xFF00_0000_0000_0000usize & word) >> 8) | + ((0x00FF_0000_0000_0000usize & word) >> 16) | + ((0x0000_FF00_0000_0000usize & word) >> 24) | + ((0x0000_00FF_0000_0000usize & word) >> 32); + let third = ((0x0000_0000_FF00_0000usize & second_word) << 24) | + ((0x0000_0000_00FF_0000usize & second_word) << 16) | + ((0x0000_0000_0000_FF00usize & second_word) << 8) | + (0x0000_0000_0000_00FFusize & second_word); + let fourth = ((0xFF00_0000_0000_0000usize & second_word) >> 8) | + ((0x00FF_0000_0000_0000usize & second_word) >> 16) | + ((0x0000_FF00_0000_0000usize & second_word) >> 24) | + ((0x0000_00FF_0000_0000usize & second_word) >> 32); + // Safety: fn invariant used here + *dst = first; + *(dst.add(1)) = second; + *(dst.add(2)) = third; + *(dst.add(3)) = fourth; + } + + /// Safety: dst must point to valid space for writing two `usize`s + #[inline(always)] + unsafe fn pack_alu(first: usize, second: usize, third: usize, fourth: usize, dst: *mut usize) { + let word = ((0x00FF_0000_0000_0000usize & second) << 8) | + ((0x0000_00FF_0000_0000usize & second) << 
16) | + ((0x0000_0000_00FF_0000usize & second) << 24) | + ((0x0000_0000_0000_00FFusize & second) << 32) | + ((0x00FF_0000_0000_0000usize & first) >> 24) | + ((0x0000_00FF_0000_0000usize & first) >> 16) | + ((0x0000_0000_00FF_0000usize & first) >> 8) | + (0x0000_0000_0000_00FFusize & first); + let second_word = ((0x00FF_0000_0000_0000usize & fourth) << 8) | + ((0x0000_00FF_0000_0000usize & fourth) << 16) | + ((0x0000_0000_00FF_0000usize & fourth) << 24) | + ((0x0000_0000_0000_00FFusize & fourth) << 32) | + ((0x00FF_0000_0000_0000usize & third) >> 24) | + ((0x0000_00FF_0000_0000usize & third) >> 16) | + ((0x0000_0000_00FF_0000usize & third) >> 8) | + (0x0000_0000_0000_00FFusize & third); + // Safety: fn invariant used here + *dst = word; + *(dst.add(1)) = second_word; + } + } else if #[cfg(all(target_endian = "little", target_pointer_width = "32"))] { + // Aligned ALU word, little-endian, 32-bit + + /// Safety invariant: this is the amount of bytes consumed by + /// unpack_alu. This will be twice the pointer width, as it consumes two usizes. + /// This is also the number of bytes produced by pack_alu. + /// This is also the number of u16 code units produced/consumed by unpack_alu/pack_alu respectively. 
+ pub const ALU_STRIDE_SIZE: usize = 8; + + pub const MAX_STRIDE_SIZE: usize = 8; + + // Safety invariant: this is the pointer width in bytes + pub const ALU_ALIGNMENT: usize = 4; + + // Safety invariant: this is a mask for getting the bits of a pointer not aligned to ALU_ALIGNMENT + pub const ALU_ALIGNMENT_MASK: usize = 3; + + /// Safety: dst must point to valid space for writing four `usize`s + #[inline(always)] + unsafe fn unpack_alu(word: usize, second_word: usize, dst: *mut usize) { + let first = ((0x0000_FF00usize & word) << 8) | + (0x0000_00FFusize & word); + let second = ((0xFF00_0000usize & word) >> 8) | + ((0x00FF_0000usize & word) >> 16); + let third = ((0x0000_FF00usize & second_word) << 8) | + (0x0000_00FFusize & second_word); + let fourth = ((0xFF00_0000usize & second_word) >> 8) | + ((0x00FF_0000usize & second_word) >> 16); + // Safety: fn invariant used here + *dst = first; + *(dst.add(1)) = second; + *(dst.add(2)) = third; + *(dst.add(3)) = fourth; + } + + /// Safety: dst must point to valid space for writing two `usize`s + #[inline(always)] + unsafe fn pack_alu(first: usize, second: usize, third: usize, fourth: usize, dst: *mut usize) { + let word = ((0x00FF_0000usize & second) << 8) | + ((0x0000_00FFusize & second) << 16) | + ((0x00FF_0000usize & first) >> 8) | + (0x0000_00FFusize & first); + let second_word = ((0x00FF_0000usize & fourth) << 8) | + ((0x0000_00FFusize & fourth) << 16) | + ((0x00FF_0000usize & third) >> 8) | + (0x0000_00FFusize & third); + // Safety: fn invariant used here + *dst = word; + *(dst.add(1)) = second_word; + } + } else if #[cfg(all(target_endian = "big", target_pointer_width = "64"))] { + // Aligned ALU word, big-endian, 64-bit + + /// Safety invariant: this is the amount of bytes consumed by + /// unpack_alu. This will be twice the pointer width, as it consumes two usizes. + /// This is also the number of bytes produced by pack_alu. 
+ /// This is also the number of u16 code units produced/consumed by unpack_alu/pack_alu respectively. + pub const ALU_STRIDE_SIZE: usize = 16; + + pub const MAX_STRIDE_SIZE: usize = 16; + + // Safety invariant: this is the pointer width in bytes + pub const ALU_ALIGNMENT: usize = 8; + + // Safety invariant: this is a mask for getting the bits of a pointer not aligned to ALU_ALIGNMENT + pub const ALU_ALIGNMENT_MASK: usize = 7; + + /// Safety: dst must point to valid space for writing four `usize`s + #[inline(always)] + unsafe fn unpack_alu(word: usize, second_word: usize, dst: *mut usize) { + let first = ((0xFF00_0000_0000_0000usize & word) >> 8) | + ((0x00FF_0000_0000_0000usize & word) >> 16) | + ((0x0000_FF00_0000_0000usize & word) >> 24) | + ((0x0000_00FF_0000_0000usize & word) >> 32); + let second = ((0x0000_0000_FF00_0000usize & word) << 24) | + ((0x0000_0000_00FF_0000usize & word) << 16) | + ((0x0000_0000_0000_FF00usize & word) << 8) | + (0x0000_0000_0000_00FFusize & word); + let third = ((0xFF00_0000_0000_0000usize & second_word) >> 8) | + ((0x00FF_0000_0000_0000usize & second_word) >> 16) | + ((0x0000_FF00_0000_0000usize & second_word) >> 24) | + ((0x0000_00FF_0000_0000usize & second_word) >> 32); + let fourth = ((0x0000_0000_FF00_0000usize & second_word) << 24) | + ((0x0000_0000_00FF_0000usize & second_word) << 16) | + ((0x0000_0000_0000_FF00usize & second_word) << 8) | + (0x0000_0000_0000_00FFusize & second_word); + // Safety: fn invariant used here + *dst = first; + *(dst.add(1)) = second; + *(dst.add(2)) = third; + *(dst.add(3)) = fourth; + } + + /// Safety: dst must point to valid space for writing two `usize`s + #[inline(always)] + unsafe fn pack_alu(first: usize, second: usize, third: usize, fourth: usize, dst: *mut usize) { + let word = ((0x00FF0000_00000000usize & first) << 8) | + ((0x000000FF_00000000usize & first) << 16) | + ((0x00000000_00FF0000usize & first) << 24) | + ((0x00000000_000000FFusize & first) << 32) | + ((0x00FF0000_00000000usize & 
second) >> 24) | + ((0x000000FF_00000000usize & second) >> 16) | + ((0x00000000_00FF0000usize & second) >> 8) | + (0x00000000_000000FFusize & second); + let second_word = ((0x00FF0000_00000000usize & third) << 8) | + ((0x000000FF_00000000usize & third) << 16) | + ((0x00000000_00FF0000usize & third) << 24) | + ((0x00000000_000000FFusize & third) << 32) | + ((0x00FF0000_00000000usize & fourth) >> 24) | + ((0x000000FF_00000000usize & fourth) >> 16) | + ((0x00000000_00FF0000usize & fourth) >> 8) | + (0x00000000_000000FFusize & fourth); + // Safety: fn invariant used here + *dst = word; + *(dst.add(1)) = second_word; + } + } else if #[cfg(all(target_endian = "big", target_pointer_width = "32"))] { + // Aligned ALU word, big-endian, 32-bit + + /// Safety invariant: this is the amount of bytes consumed by + /// unpack_alu. This will be twice the pointer width, as it consumes two usizes. + /// This is also the number of bytes produced by pack_alu. + /// This is also the number of u16 code units produced/consumed by unpack_alu/pack_alu respectively. 
+ pub const ALU_STRIDE_SIZE: usize = 8; + + pub const MAX_STRIDE_SIZE: usize = 8; + + // Safety invariant: this is the pointer width in bytes + pub const ALU_ALIGNMENT: usize = 4; + + // Safety invariant: this is a mask for getting the bits of a pointer not aligned to ALU_ALIGNMENT + pub const ALU_ALIGNMENT_MASK: usize = 3; + + /// Safety: dst must point to valid space for writing four `usize`s + #[inline(always)] + unsafe fn unpack_alu(word: usize, second_word: usize, dst: *mut usize) { + let first = ((0xFF00_0000usize & word) >> 8) | + ((0x00FF_0000usize & word) >> 16); + let second = ((0x0000_FF00usize & word) << 8) | + (0x0000_00FFusize & word); + let third = ((0xFF00_0000usize & second_word) >> 8) | + ((0x00FF_0000usize & second_word) >> 16); + let fourth = ((0x0000_FF00usize & second_word) << 8) | + (0x0000_00FFusize & second_word); + // Safety: fn invariant used here + *dst = first; + *(dst.add(1)) = second; + *(dst.add(2)) = third; + *(dst.add(3)) = fourth; + } + + /// Safety: dst must point to valid space for writing two `usize`s + #[inline(always)] + unsafe fn pack_alu(first: usize, second: usize, third: usize, fourth: usize, dst: *mut usize) { + let word = ((0x00FF_0000usize & first) << 8) | + ((0x0000_00FFusize & first) << 16) | + ((0x00FF_0000usize & second) >> 8) | + (0x0000_00FFusize & second); + let second_word = ((0x00FF_0000usize & third) << 8) | + ((0x0000_00FFusize & third) << 16) | + ((0x00FF_0000usize & fourth) >> 8) | + (0x0000_00FFusize & fourth); + // Safety: fn invariant used here + *dst = word; + *(dst.add(1)) = second_word; + } + } else { + ascii_naive!(ascii_to_ascii, u8, u8); + ascii_naive!(ascii_to_basic_latin, u8, u16); + ascii_naive!(basic_latin_to_ascii, u16, u8); + } +} + +cfg_if! 
{ + // Safety-usable invariant: this counts the zeroes from the "first byte" of utf-8 data packed into a usize + // with the target endianness + if #[cfg(target_endian = "little")] { + #[allow(dead_code)] + #[inline(always)] + fn count_zeros(word: usize) -> u32 { + word.trailing_zeros() + } + } else { + #[allow(dead_code)] + #[inline(always)] + fn count_zeros(word: usize) -> u32 { + word.leading_zeros() + } + } +} + +cfg_if! { + if #[cfg(all(target_endian = "little", target_arch = "aarch64"))] { + /// Safety-usable invariant: Will return the value and position of the first non-ASCII byte in the slice in a Some if found. + /// In other words, the first element of the Some is always `> 127` + #[inline(always)] + pub fn validate_ascii(slice: &[u8]) -> Option<(u8, usize)> { + let src = slice.as_ptr(); + let len = slice.len(); + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading/writing at least `stride` elements. + if SIMD_STRIDE_SIZE <= len { + let len_minus_stride = len - SIMD_STRIDE_SIZE; + loop { + // Safety: src at offset is valid for a `SIMD_STRIDE_SIZE` read + let simd = unsafe { load16_unaligned(src.add(offset)) }; + if !simd_is_ascii(simd) { + break; + } + offset += SIMD_STRIDE_SIZE; + // This is `offset > len - SIMD_STRIDE_SIZE` which means we always have at least `SIMD_STRIDE_SIZE` elements to munch next time. 
+ if offset > len_minus_stride { + break; + } + } + } + while offset < len { + let code_unit = slice[offset]; + if code_unit > 127 { + // Safety: Safety-usable invariant upheld here + return Some((code_unit, offset)); + } + offset += 1; + } + None + } + } else if #[cfg(target_feature = "sse2")] { + /// Safety-usable invariant: will return Some() when it encounters non-ASCII, with the first element in the Some being + /// guaranteed to be non-ASCII (> 127), and the second being the offset where it is found + #[inline(always)] + pub fn validate_ascii(slice: &[u8]) -> Option<(u8, usize)> { + let src = slice.as_ptr(); + let len = slice.len(); + let mut offset = 0usize; + // Safety: if this check succeeds we're valid for reading at least `stride` elements. + if SIMD_STRIDE_SIZE <= len { + // First, process one unaligned vector + // Safety: src is valid for a `SIMD_STRIDE_SIZE` read + let simd = unsafe { load16_unaligned(src) }; + let mask = mask_ascii(simd); + if mask != 0 { + offset = mask.trailing_zeros() as usize; + let non_ascii = unsafe { *src.add(offset) }; + return Some((non_ascii, offset)); + } + offset = SIMD_STRIDE_SIZE; + // Safety: Now that offset has changed we don't yet know how much it is valid for + + // We have now seen 16 ASCII bytes. Let's guess that + // there will be enough more to justify more expense + // in the case of non-ASCII. + // Use aligned reads for the sake of old microachitectures. + // Safety: this correctly calculates the number of src_units that need to be read before the remaining list is aligned. + // This is by definition less than SIMD_ALIGNMENT, which is defined to be equal to SIMD_STRIDE_SIZE. + let until_alignment = unsafe { (SIMD_ALIGNMENT - ((src.add(offset) as usize) & SIMD_ALIGNMENT_MASK)) & SIMD_ALIGNMENT_MASK }; + // This addition won't overflow, because even in the 32-bit PAE case the + // address space holds enough code that the slice length can't be that + // close to address space size. 
+ // offset now equals SIMD_STRIDE_SIZE, hence times 3 below. + // + // Safety: if this check succeeds we're valid for reading at least `2 * SIMD_STRIDE_SIZE` elements plus `until_alignment`. + // The extra SIMD_STRIDE_SIZE in the condition is because `offset` is already `SIMD_STRIDE_SIZE`. + if until_alignment + (SIMD_STRIDE_SIZE * 3) <= len { + if until_alignment != 0 { + // Safety: this is safe to call since we're valid for this read (and more), and don't care about alignment + // This will copy over bytes that get decoded twice since it's not incrementing `offset` by SIMD_STRIDE_SIZE. This is fine. + let simd = unsafe { load16_unaligned(src.add(offset)) }; + let mask = mask_ascii(simd); + if mask != 0 { + offset += mask.trailing_zeros() as usize; + let non_ascii = unsafe { *src.add(offset) }; + return Some((non_ascii, offset)); + } + offset += until_alignment; + } + // Safety: At this point we're valid for reading 2*SIMD_STRIDE_SIZE elements + // Safety: Now `offset` is aligned for `src` + let len_minus_stride_times_two = len - (SIMD_STRIDE_SIZE * 2); + loop { + // Safety: We were valid for this read, and were aligned. + let first = unsafe { load16_aligned(src.add(offset)) }; + let second = unsafe { load16_aligned(src.add(offset + SIMD_STRIDE_SIZE)) }; + if !simd_is_ascii(first | second) { + // Safety: mask_ascii produces a mask of all the high bits. + let mask_first = mask_ascii(first); + if mask_first != 0 { + // Safety: on little endian systems this will be the number of ascii bytes + // before the first non-ascii, i.e. valid for indexing src + // TODO SAFETY: What about big-endian systems? + offset += mask_first.trailing_zeros() as usize; + } else { + let mask_second = mask_ascii(second); + // Safety: on little endian systems this will be the number of ascii bytes + // before the first non-ascii, i.e. 
valid for indexing src + offset += SIMD_STRIDE_SIZE + mask_second.trailing_zeros() as usize; + } + // Safety: We know this is non-ASCII, and can uphold the safety-usable invariant here + let non_ascii = unsafe { *src.add(offset) }; + + return Some((non_ascii, offset)); + } + offset += SIMD_STRIDE_SIZE * 2; + // Safety: This is `offset > len - 2 * SIMD_STRIDE_SIZE` which means we always have at least `2 * SIMD_STRIDE_SIZE` elements to munch next time. + if offset > len_minus_stride_times_two { + break; + } + } + // Safety: if this check succeeds we're valid for reading at least `SIMD_STRIDE_SIZE` + if offset + SIMD_STRIDE_SIZE <= len { + // Safety: We were valid for this read, and were aligned. + let simd = unsafe { load16_aligned(src.add(offset)) }; + // Safety: mask_ascii produces a mask of all the high bits. + let mask = mask_ascii(simd); + if mask != 0 { + // Safety: on little endian systems this will be the number of ascii bytes + // before the first non-ascii, i.e. valid for indexing src + offset += mask.trailing_zeros() as usize; + let non_ascii = unsafe { *src.add(offset) }; + // Safety: We know this is non-ASCII, and can uphold the safety-usable invariant here + return Some((non_ascii, offset)); + } + offset += SIMD_STRIDE_SIZE; + } + } else { + // Safety: this is the unaligned branch + // At most two iterations, so unroll + // Safety: if this check succeeds we're valid for reading at least `SIMD_STRIDE_SIZE` + if offset + SIMD_STRIDE_SIZE <= len { + // Safety: We're valid for this read but must use an unaligned read + let simd = unsafe { load16_unaligned(src.add(offset)) }; + let mask = mask_ascii(simd); + if mask != 0 { + offset += mask.trailing_zeros() as usize; + let non_ascii = unsafe { *src.add(offset) }; + // Safety-usable invariant upheld here (same as above) + return Some((non_ascii, offset)); + } + offset += SIMD_STRIDE_SIZE; + // Safety: if this check succeeds we're valid for reading at least `SIMD_STRIDE_SIZE` + if offset + SIMD_STRIDE_SIZE <= 
len { + // Safety: We're valid for this read but must use an unaligned read + let simd = unsafe { load16_unaligned(src.add(offset)) }; + let mask = mask_ascii(simd); + if mask != 0 { + offset += mask.trailing_zeros() as usize; + let non_ascii = unsafe { *src.add(offset) }; + // Safety-usable invariant upheld here (same as above) + return Some((non_ascii, offset)); + } + offset += SIMD_STRIDE_SIZE; + } + } + } + } + while offset < len { + // Safety: relies straightforwardly on the `len` invariant + let code_unit = unsafe { *(src.add(offset)) }; + if code_unit > 127 { + // Safety-usable invariant upheld here + return Some((code_unit, offset)); + } + offset += 1; + } + None + } + } else { + // Safety-usable invariant: returns byte index of first non-ascii byte + #[inline(always)] + fn find_non_ascii(word: usize, second_word: usize) -> Option { + let word_masked = word & ASCII_MASK; + let second_masked = second_word & ASCII_MASK; + if (word_masked | second_masked) == 0 { + // Both are ascii, invariant upheld + return None; + } + if word_masked != 0 { + let zeros = count_zeros(word_masked); + // `zeros` now contains 0 to 7 (for the seven bits of masked ASCII in little endian, + // or up to 7 bits of non-ASCII in big endian if the first byte is non-ASCII) + // plus 8 times the number of ASCII in text order before the + // non-ASCII byte in the little-endian case or 8 times the number of ASCII in + // text order before the non-ASCII byte in the big-endian case. 
+ let num_ascii = (zeros >> 3) as usize;
+ // Safety-usable invariant upheld here
+ return Some(num_ascii);
+ }
+ let zeros = count_zeros(second_masked);
+ // `zeros` now contains 0 to 7 (for the seven bits of masked ASCII in little endian,
+ // or up to 7 bits of non-ASCII in big endian if the first byte is non-ASCII)
+ // plus 8 times the number of ASCII in text order before the
+ // non-ASCII byte in the little-endian case or 8 times the number of ASCII in
+ // text order before the non-ASCII byte in the big-endian case.
+ let num_ascii = (zeros >> 3) as usize;
+ // Safety-usable invariant upheld here
+ Some(ALU_ALIGNMENT + num_ascii)
+ }
+
+ /// Safety: `src` must be valid for the reads of two `usize`s
+ ///
+ /// Safety-usable invariant: will return byte index of first non-ascii byte
+ #[inline(always)]
+ unsafe fn validate_ascii_stride(src: *const usize) -> Option<usize> {
+ let word = *src;
+ let second_word = *(src.add(1));
+ find_non_ascii(word, second_word)
+ }
+
+ /// Safety-usable invariant: will return Some() when it encounters non-ASCII, with the first element in the Some being
+ /// guaranteed to be non-ASCII (> 127), and the second being the offset where it is found
+ #[cfg_attr(feature = "cargo-clippy", allow(cast_ptr_alignment))]
+ #[inline(always)]
+ pub fn validate_ascii(slice: &[u8]) -> Option<(u8, usize)> {
+ let src = slice.as_ptr();
+ let len = slice.len();
+ let mut offset = 0usize;
+ let mut until_alignment = (ALU_ALIGNMENT - ((src as usize) & ALU_ALIGNMENT_MASK)) & ALU_ALIGNMENT_MASK;
+ // Safety: If this check succeeds we're valid to read `until_alignment + ALU_STRIDE_SIZE` elements
+ if until_alignment + ALU_STRIDE_SIZE <= len {
+ while until_alignment != 0 {
+ let code_unit = slice[offset];
+ if code_unit > 127 {
+ // Safety-usable invariant upheld here
+ return Some((code_unit, offset));
+ }
+ offset += 1;
+ until_alignment -= 1;
+ }
+ // Safety: At this point we have read until_alignment elements and
+ // are valid for `ALU_STRIDE_SIZE` more.
+ let len_minus_stride = len - ALU_STRIDE_SIZE;
+ loop {
+ // Safety: we were valid for this read
+ let ptr = unsafe { src.add(offset) as *const usize };
+ if let Some(num_ascii) = unsafe { validate_ascii_stride(ptr) } {
+ offset += num_ascii;
+ // Safety-usable invariant upheld here using the invariant from validate_ascii_stride()
+ return Some((unsafe { *(src.add(offset)) }, offset));
+ }
+ offset += ALU_STRIDE_SIZE;
+ // Safety: This is `offset > len - ALU_STRIDE_SIZE` which means we always have at least `ALU_STRIDE_SIZE` elements to munch next time.
+ if offset > len_minus_stride {
+ break;
+ }
+ }
+ }
+ while offset < len {
+ let code_unit = slice[offset];
+ if code_unit > 127 {
+ // Safety-usable invariant upheld here
+ return Some((code_unit, offset));
+ }
+ offset += 1;
+ }
+ None
+ }
+
+ }
+}
+
+cfg_if! {
+ if #[cfg(any(target_feature = "sse2", all(target_endian = "little", target_arch = "aarch64")))] {
+
+ } else if #[cfg(all(target_endian = "little", target_feature = "neon"))] {
+ // Even with NEON enabled, we use the ALU path for ASCII validation, because testing
+ // on Exynos 5 indicated that using NEON isn't worthwhile where there are only
+ // vector reads without vector writes.
+ + pub const ALU_STRIDE_SIZE: usize = 8; + + pub const ALU_ALIGNMENT: usize = 4; + + pub const ALU_ALIGNMENT_MASK: usize = 3; + } else { + // Safety: src points to two valid `usize`s, dst points to four valid `usize`s + #[inline(always)] + unsafe fn unpack_latin1_stride_alu(src: *const usize, dst: *mut usize) { + // Safety: src safety invariant used here + let word = *src; + let second_word = *(src.add(1)); + // Safety: dst safety invariant passed down + unpack_alu(word, second_word, dst); + } + + // Safety: src points to four valid `usize`s, dst points to two valid `usize`s + #[inline(always)] + unsafe fn pack_latin1_stride_alu(src: *const usize, dst: *mut usize) { + // Safety: src safety invariant used here + let first = *src; + let second = *(src.add(1)); + let third = *(src.add(2)); + let fourth = *(src.add(3)); + // Safety: dst safety invariant passed down + pack_alu(first, second, third, fourth, dst); + } + + // Safety: src points to two valid `usize`s, dst points to four valid `usize`s + #[inline(always)] + unsafe fn ascii_to_basic_latin_stride_alu(src: *const usize, dst: *mut usize) -> bool { + // Safety: src safety invariant used here + let word = *src; + let second_word = *(src.add(1)); + // Check if the words contains non-ASCII + if (word & ASCII_MASK) | (second_word & ASCII_MASK) != 0 { + return false; + } + // Safety: dst safety invariant passed down + unpack_alu(word, second_word, dst); + true + } + + // Safety: src points four valid `usize`s, dst points to two valid `usize`s + #[inline(always)] + unsafe fn basic_latin_to_ascii_stride_alu(src: *const usize, dst: *mut usize) -> bool { + // Safety: src safety invariant used here + let first = *src; + let second = *(src.add(1)); + let third = *(src.add(2)); + let fourth = *(src.add(3)); + if (first & BASIC_LATIN_MASK) | (second & BASIC_LATIN_MASK) | (third & BASIC_LATIN_MASK) | (fourth & BASIC_LATIN_MASK) != 0 { + return false; + } + // Safety: dst safety invariant passed down + pack_alu(first, second, 
third, fourth, dst);
+ true
+ }
+
+ // Safety: src, dst both point to two valid `usize`s each
+ // Safety-usable invariant: Will return byte index of first non-ascii byte.
+ #[inline(always)]
+ unsafe fn ascii_to_ascii_stride(src: *const usize, dst: *mut usize) -> Option<usize> {
+ // Safety: src safety invariant used here
+ let word = *src;
+ let second_word = *(src.add(1));
+ // Safety: dst safety invariant used here
+ *dst = word;
+ *(dst.add(1)) = second_word;
+ // Relies on safety-usable invariant here
+ find_non_ascii(word, second_word)
+ }
+
+ basic_latin_alu!(ascii_to_basic_latin, u8, u16, ascii_to_basic_latin_stride_alu);
+ basic_latin_alu!(basic_latin_to_ascii, u16, u8, basic_latin_to_ascii_stride_alu);
+ latin1_alu!(unpack_latin1, u8, u16, unpack_latin1_stride_alu);
+ latin1_alu!(pack_latin1, u16, u8, pack_latin1_stride_alu);
+ // Safety invariant upheld: ascii_to_ascii_stride will return byte index of first non-ascii if found
+ ascii_alu!(ascii_to_ascii, u8, u8, ascii_to_ascii_stride);
+ }
+}
diff --git a/patch/prost-0.14.1/src/encoding/utf8/simd_funcs.rs b/patch/prost-0.14.1/src/encoding/utf8/simd_funcs.rs
new file mode 100644
index 0000000..e97415c
--- /dev/null
+++ b/patch/prost-0.14.1/src/encoding/utf8/simd_funcs.rs
@@ -0,0 +1,347 @@
+// Copyright Mozilla Foundation. See the COPYRIGHT
+// file at the top-level directory of this distribution.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+ +use any_all_workaround::all_mask16x8; +use any_all_workaround::all_mask8x16; +use any_all_workaround::any_mask16x8; +use any_all_workaround::any_mask8x16; +use core::simd::cmp::SimdPartialEq; +use core::simd::cmp::SimdPartialOrd; +use core::simd::simd_swizzle; +use core::simd::u16x8; +use core::simd::u8x16; +use core::simd::ToBytes; + +use cfg_if::cfg_if; + +// TODO: Migrate unaligned access to stdlib code if/when the RFC +// https://github.com/rust-lang/rfcs/pull/1725 is implemented. + +/// Safety invariant: ptr must be valid for an unaligned read of 16 bytes +#[inline(always)] +pub unsafe fn load16_unaligned(ptr: *const u8) -> u8x16 { + let mut simd = ::core::mem::MaybeUninit::::uninit(); + ::core::ptr::copy_nonoverlapping(ptr, simd.as_mut_ptr() as *mut u8, 16); + // Safety: copied 16 bytes of initialized memory into this, it is now initialized + simd.assume_init() +} + +/// Safety invariant: ptr must be valid for an aligned-for-u8x16 read of 16 bytes +#[allow(dead_code)] +#[inline(always)] +pub unsafe fn load16_aligned(ptr: *const u8) -> u8x16 { + *(ptr as *const u8x16) +} + +/// Safety invariant: ptr must be valid for an unaligned store of 16 bytes +#[inline(always)] +pub unsafe fn store16_unaligned(ptr: *mut u8, s: u8x16) { + ::core::ptr::copy_nonoverlapping(&s as *const u8x16 as *const u8, ptr, 16); +} + +/// Safety invariant: ptr must be valid for an aligned-for-u8x16 store of 16 bytes +#[allow(dead_code)] +#[inline(always)] +pub unsafe fn store16_aligned(ptr: *mut u8, s: u8x16) { + *(ptr as *mut u8x16) = s; +} + +/// Safety invariant: ptr must be valid for an unaligned read of 16 bytes +#[inline(always)] +pub unsafe fn load8_unaligned(ptr: *const u16) -> u16x8 { + let mut simd = ::core::mem::MaybeUninit::::uninit(); + ::core::ptr::copy_nonoverlapping(ptr as *const u8, simd.as_mut_ptr() as *mut u8, 16); + // Safety: copied 16 bytes of initialized memory into this, it is now initialized + simd.assume_init() +} + +/// Safety invariant: ptr must be valid for 
an aligned-for-u16x8 read of 16 bytes +#[allow(dead_code)] +#[inline(always)] +pub unsafe fn load8_aligned(ptr: *const u16) -> u16x8 { + *(ptr as *const u16x8) +} + +/// Safety invariant: ptr must be valid for an unaligned store of 16 bytes +#[inline(always)] +pub unsafe fn store8_unaligned(ptr: *mut u16, s: u16x8) { + ::core::ptr::copy_nonoverlapping(&s as *const u16x8 as *const u8, ptr as *mut u8, 16); +} + +/// Safety invariant: ptr must be valid for an aligned-for-u16x8 store of 16 bytes +#[allow(dead_code)] +#[inline(always)] +pub unsafe fn store8_aligned(ptr: *mut u16, s: u16x8) { + *(ptr as *mut u16x8) = s; +} + +cfg_if! { + if #[cfg(all(target_feature = "sse2", target_arch = "x86_64"))] { + use core::arch::x86_64::_mm_movemask_epi8; + use core::arch::x86_64::_mm_packus_epi16; + } else if #[cfg(all(target_feature = "sse2", target_arch = "x86"))] { + use core::arch::x86::_mm_movemask_epi8; + use core::arch::x86::_mm_packus_epi16; + } else if #[cfg(target_arch = "aarch64")]{ + use core::arch::aarch64::vmaxvq_u8; + use core::arch::aarch64::vmaxvq_u16; + } else { + + } +} + +// #[inline(always)] +// fn simd_byte_swap_u8(s: u8x16) -> u8x16 { +// unsafe { +// shuffle!(s, s, [1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14]) +// } +// } + +// #[inline(always)] +// pub fn simd_byte_swap(s: u16x8) -> u16x8 { +// to_u16_lanes(simd_byte_swap_u8(to_u8_lanes(s))) +// } + +#[inline(always)] +pub fn simd_byte_swap(s: u16x8) -> u16x8 { + let left = s << 8; + let right = s >> 8; + left | right +} + +#[inline(always)] +pub fn to_u16_lanes(s: u8x16) -> u16x8 { + u16x8::from_ne_bytes(s) +} + +cfg_if! { + if #[cfg(target_feature = "sse2")] { + + // Expose low-level mask instead of higher-level conclusion, + // because the non-ASCII case would perform less well otherwise. 
+ // Safety-usable invariant: This returned value is whether each high bit is set + #[inline(always)] + pub fn mask_ascii(s: u8x16) -> i32 { + unsafe { + _mm_movemask_epi8(s.into()) + } + } + + } else { + + } +} + +cfg_if! { + if #[cfg(target_feature = "sse2")] { + #[inline(always)] + pub fn simd_is_ascii(s: u8x16) -> bool { + unsafe { + // Safety: We have cfg()d the correct platform + _mm_movemask_epi8(s.into()) == 0 + } + } + } else if #[cfg(target_arch = "aarch64")]{ + #[inline(always)] + pub fn simd_is_ascii(s: u8x16) -> bool { + unsafe { + // Safety: We have cfg()d the correct platform + vmaxvq_u8(s.into()) < 0x80 + } + } + } else { + #[inline(always)] + pub fn simd_is_ascii(s: u8x16) -> bool { + // This optimizes better on ARM than + // the lt formulation. + let highest_ascii = u8x16::splat(0x7F); + !any_mask8x16(s.simd_gt(highest_ascii)) + } + } +} + +cfg_if! { + if #[cfg(target_feature = "sse2")] { + #[inline(always)] + pub fn simd_is_str_latin1(s: u8x16) -> bool { + if simd_is_ascii(s) { + return true; + } + let above_str_latin1 = u8x16::splat(0xC4); + s.simd_lt(above_str_latin1).all() + } + } else if #[cfg(target_arch = "aarch64")]{ + #[inline(always)] + pub fn simd_is_str_latin1(s: u8x16) -> bool { + unsafe { + // Safety: We have cfg()d the correct platform + vmaxvq_u8(s.into()) < 0xC4 + } + } + } else { + #[inline(always)] + pub fn simd_is_str_latin1(s: u8x16) -> bool { + let above_str_latin1 = u8x16::splat(0xC4); + all_mask8x16(s.simd_lt(above_str_latin1)) + } + } +} + +cfg_if! 
{ + if #[cfg(target_arch = "aarch64")]{ + #[inline(always)] + pub fn simd_is_basic_latin(s: u16x8) -> bool { + unsafe { + // Safety: We have cfg()d the correct platform + vmaxvq_u16(s.into()) < 0x80 + } + } + + #[inline(always)] + pub fn simd_is_latin1(s: u16x8) -> bool { + unsafe { + // Safety: We have cfg()d the correct platform + vmaxvq_u16(s.into()) < 0x100 + } + } + } else { + #[inline(always)] + pub fn simd_is_basic_latin(s: u16x8) -> bool { + let above_ascii = u16x8::splat(0x80); + all_mask16x8(s.simd_lt(above_ascii)) + } + + #[inline(always)] + pub fn simd_is_latin1(s: u16x8) -> bool { + // For some reason, on SSE2 this formulation + // seems faster in this case while the above + // function is better the other way round... + let highest_latin1 = u16x8::splat(0xFF); + !any_mask16x8(s.simd_gt(highest_latin1)) + } + } +} + +#[inline(always)] +pub fn contains_surrogates(s: u16x8) -> bool { + let mask = u16x8::splat(0xF800); + let surrogate_bits = u16x8::splat(0xD800); + any_mask16x8((s & mask).simd_eq(surrogate_bits)) +} + +cfg_if! { + if #[cfg(target_arch = "aarch64")]{ + macro_rules! aarch64_return_false_if_below_hebrew { + ($s:ident) => ({ + unsafe { + // Safety: We have cfg()d the correct platform + if vmaxvq_u16($s.into()) < 0x0590 { + return false; + } + } + }) + } + + macro_rules! non_aarch64_return_false_if_all { + ($s:ident) => () + } + } else { + macro_rules! aarch64_return_false_if_below_hebrew { + ($s:ident) => () + } + + macro_rules! non_aarch64_return_false_if_all { + ($s:ident) => ({ + if all_mask16x8($s) { + return false; + } + }) + } + } +} + +macro_rules! in_range16x8 { + ($s:ident, $start:expr, $end:expr) => {{ + // SIMD sub is wrapping + ($s - u16x8::splat($start)).simd_lt(u16x8::splat($end - $start)) + }}; +} + +#[inline(always)] +pub fn is_u16x8_bidi(s: u16x8) -> bool { + // We try to first quickly refute the RTLness of the vector. 
If that + // fails, we do the real RTL check, so in that case we end up wasting + // the work for the up-front quick checks. Even the quick-check is + // two-fold in order to return `false` ASAP if everything is below + // Hebrew. + + aarch64_return_false_if_below_hebrew!(s); + + let below_hebrew = s.simd_lt(u16x8::splat(0x0590)); + + non_aarch64_return_false_if_all!(below_hebrew); + + if all_mask16x8( + below_hebrew | in_range16x8!(s, 0x0900, 0x200F) | in_range16x8!(s, 0x2068, 0xD802), + ) { + return false; + } + + // Quick refutation failed. Let's do the full check. + + any_mask16x8( + (in_range16x8!(s, 0x0590, 0x0900) + | in_range16x8!(s, 0xFB1D, 0xFE00) + | in_range16x8!(s, 0xFE70, 0xFEFF) + | in_range16x8!(s, 0xD802, 0xD804) + | in_range16x8!(s, 0xD83A, 0xD83C) + | s.simd_eq(u16x8::splat(0x200F)) + | s.simd_eq(u16x8::splat(0x202B)) + | s.simd_eq(u16x8::splat(0x202E)) + | s.simd_eq(u16x8::splat(0x2067))), + ) +} + +#[inline(always)] +pub fn simd_unpack(s: u8x16) -> (u16x8, u16x8) { + let first: u8x16 = simd_swizzle!( + s, + u8x16::splat(0), + [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23] + ); + let second: u8x16 = simd_swizzle!( + s, + u8x16::splat(0), + [8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] + ); + (u16x8::from_ne_bytes(first), u16x8::from_ne_bytes(second)) +} + +cfg_if! 
{ + if #[cfg(target_feature = "sse2")] { + #[inline(always)] + pub fn simd_pack(a: u16x8, b: u16x8) -> u8x16 { + unsafe { + // Safety: We have cfg()d the correct platform + _mm_packus_epi16(a.into(), b.into()).into() + } + } + } else { + #[inline(always)] + pub fn simd_pack(a: u16x8, b: u16x8) -> u8x16 { + let first: u8x16 = a.to_ne_bytes(); + let second: u8x16 = b.to_ne_bytes(); + simd_swizzle!( + first, + second, + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] + ) + } + } +} diff --git a/patch/prost-0.14.1/src/encoding/varint.rs b/patch/prost-0.14.1/src/encoding/varint.rs new file mode 100644 index 0000000..6f9f0f9 --- /dev/null +++ b/patch/prost-0.14.1/src/encoding/varint.rs @@ -0,0 +1,667 @@ +#![allow(unsafe_op_in_unsafe_fn)] + +use ::bytes::{Buf, BufMut}; +use ::core::intrinsics::{assume, likely, unchecked_shl, unchecked_shr, unlikely}; + +use crate::error::DecodeError; + +/// ZigZag 编码 32 位整数 +#[inline(always)] +pub const fn encode_zigzag32(value: i32) -> u32 { + unsafe { (unchecked_shl(value, 1u8) ^ unchecked_shr(value, 31u8)) as u32 } +} + +/// ZigZag 解码 32 位整数 +#[inline(always)] +pub const fn decode_zigzag32(value: u32) -> i32 { + unsafe { (unchecked_shr(value, 1u8) as i32) ^ (-((value & 1) as i32)) } +} + +/// ZigZag 编码 64 位整数 +#[inline(always)] +pub const fn encode_zigzag64(value: i64) -> u64 { + unsafe { (unchecked_shl(value, 1u8) ^ unchecked_shr(value, 63u8)) as u64 } +} + +/// ZigZag 解码 64 位整数 +#[inline(always)] +pub const fn decode_zigzag64(value: u64) -> i64 { + unsafe { (unchecked_shr(value, 1u8) as i64) ^ (-((value & 1) as i64)) } +} + +/// The maximum number of bytes a Protobuf Varint can occupy. +const VARINT64_MAX_LEN: usize = 10; + +/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. +/// +/// Dispatches to a fast path if the buffer has enough contiguous space, +/// otherwise falls back to a slower, byte-by-byte write. 
+#[inline] +pub fn encode_varint64(value: u64, buf: &mut impl BufMut) -> usize { + let len = encoded_len_varint64(value); + + // If there is enough contiguous space, use the optimized path. + if likely(buf.chunk_mut().len() >= len) { + // Safety: The check above guarantees `buf.chunk_mut()` has at least `len` bytes. + unsafe { encode_varint64_fast(value, len, buf) }; + } else { + encode_varint64_slow(value, len, buf); + } + + len +} + +/// Fast-path for encoding to a contiguous buffer slice. +/// +/// ## Safety +/// +/// The caller must ensure `buf.chunk_mut().len() >= len`. +#[inline(always)] +unsafe fn encode_varint64_fast(mut value: u64, len: usize, buf: &mut impl BufMut) { + let ptr = buf.chunk_mut().as_mut_ptr(); + + for i in 0..(len - 1) { + *ptr.add(i) = (value & 0x7F) as u8 | 0x80; + value >>= 7; + } + + // After the loop, `value` holds the last byte, which must not have the continuation bit. + // The `encoded_len_varint` logic guarantees this. + assume(value < 0x80); + *ptr.add(len - 1) = value as u8; + + // Notify the buffer that `len` bytes have been written. + buf.advance_mut(len); +} + +/// Slow-path encoding for buffers that may not be contiguous. +#[cold] +#[inline(never)] +fn encode_varint64_slow(mut value: u64, len: usize, buf: &mut impl BufMut) { + for _ in 0..(len - 1) { + buf.put_u8((value & 0x7F) as u8 | 0x80); + value >>= 7; + } + // After the loop, `value` holds the last byte, which must not have the continuation bit. + // The `encoded_len_varint` logic guarantees this. + unsafe { assume(value < 0x80) }; + + buf.put_u8(value as u8); +} + +/// Returns the encoded length of the value in LEB128 variable length format. +/// The returned value will be between 1 and 10, inclusive. 
+#[inline] +pub const fn encoded_len_varint64(value: u64) -> usize { + unsafe { + let value = value + .bit_width() + .unchecked_mul(9) + .unbounded_shr(6) + .unchecked_add(1); + assume(value >= 1 && value <= VARINT64_MAX_LEN as u32); + value as usize + } +} + +/// Decodes a LEB128-encoded variable length integer from the buffer. +#[inline] +pub fn decode_varint64(buf: &mut impl Buf) -> Result { + fn inner(buf: &mut impl Buf) -> Option { + let bytes = buf.chunk(); + let len = bytes.len(); + if unlikely(len == 0) { + return None; + } + + // Fast path for single-byte varints. + let first = unsafe { *bytes.get_unchecked(0) }; + if likely(first < 0x80) { + buf.advance(1); + return Some(first as _); + } + + // If the chunk is large enough or the varint is known to terminate within it, + // use the fast path which operates on a slice. + if likely(len >= VARINT64_MAX_LEN || bytes[len - 1] < 0x80) { + return decode_varint64_fast(bytes).map(|(value, advance)| { + buf.advance(advance); + value + }); + } + + // Fallback for varints that cross chunk boundaries. + decode_varint64_slow(buf) + } + inner(buf).ok_or(DecodeError::new("invalid varint64")) +} + +/// Fast-path decoding of a varint from a contiguous memory slice. +/// +/// ## Safety +/// +/// Assumes `bytes` contains a complete varint or is at least `VARINT64_MAX_LEN` bytes long. +#[inline(always)] +fn decode_varint64_fast(bytes: &[u8]) -> Option<(u64, usize)> { + let ptr = bytes.as_ptr(); + let mut value = 0u64; + + for i in 0..VARINT64_MAX_LEN { + let byte = unsafe { *ptr.add(i) }; + value |= ((byte & 0x7F) as u64) << (i * 7); + + if byte < 0x80 { + // Check for overlong encoding on the 10th byte. + if unlikely(i == 9 && byte > 1) { + return None; + } + return Some((value, i + 1)); + } + } + + // A varint must not be longer than 10 bytes. + None +} + +/// Slow-path decoding for varints that may cross `Buf` chunk boundaries. 
+#[cold] +#[inline(never)] +fn decode_varint64_slow(buf: &mut impl Buf) -> Option { + // Safety: The dispatcher `decode_varint` only calls this function if `bytes[0] >= 0x80`. + // This hint allows the compiler to optimize the first loop iteration. + unsafe { assume(buf.chunk().len() > 0 && buf.chunk()[0] >= 0x80) }; + + let mut value = 0u64; + for i in 0..VARINT64_MAX_LEN { + if unlikely(!buf.has_remaining()) { + return None; // Unexpected end of buffer. + } + let byte = buf.get_u8(); + value |= ((byte & 0x7F) as u64) << (i * 7); + + if byte < 0x80 { + // Check for overlong encoding on the 10th byte. + if unlikely(i == 9 && byte > 1) { + return None; + } + return Some(value); + } + } + + // A varint must not be longer than 10 bytes. + None +} + +/// The maximum number of bytes a Protobuf Varint can occupy. +const VARINT32_MAX_LEN: usize = 5; + +/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. +/// +/// Dispatches to a fast path if the buffer has enough contiguous space, +/// otherwise falls back to a slower, byte-by-byte write. +#[inline] +pub fn encode_varint32(value: u32, buf: &mut impl BufMut) -> usize { + let len = encoded_len_varint32(value); + + // If there is enough contiguous space, use the optimized path. + if likely(buf.chunk_mut().len() >= len) { + // Safety: The check above guarantees `buf.chunk_mut()` has at least `len` bytes. + unsafe { encode_varint32_fast(value, len, buf) }; + } else { + encode_varint32_slow(value, len, buf); + } + + len +} + +/// Fast-path for encoding to a contiguous buffer slice. +/// +/// ## Safety +/// +/// The caller must ensure `buf.chunk_mut().len() >= len`. 
+#[inline(always)] +unsafe fn encode_varint32_fast(mut value: u32, len: usize, buf: &mut impl BufMut) { + let ptr = buf.chunk_mut().as_mut_ptr(); + + for i in 0..(len - 1) { + *ptr.add(i) = (value & 0x7F) as u8 | 0x80; + value >>= 7; + } + + // After the loop, `value` holds the last byte, which must not have the continuation bit. + // The `encoded_len_varint` logic guarantees this. + assume(value < 0x80); + *ptr.add(len - 1) = value as u8; + + // Notify the buffer that `len` bytes have been written. + buf.advance_mut(len); +} + +/// Slow-path encoding for buffers that may not be contiguous. +#[cold] +#[inline(never)] +fn encode_varint32_slow(mut value: u32, len: usize, buf: &mut impl BufMut) { + for _ in 0..(len - 1) { + buf.put_u8((value & 0x7F) as u8 | 0x80); + value >>= 7; + } + // After the loop, `value` holds the last byte, which must not have the continuation bit. + // The `encoded_len_varint` logic guarantees this. + unsafe { assume(value < 0x80) }; + + buf.put_u8(value as u8); +} + +/// Returns the encoded length of the value in LEB128 variable length format. +/// The returned value will be between 1 and 5, inclusive. +#[inline] +pub const fn encoded_len_varint32(value: u32) -> usize { + unsafe { + let value = value + .bit_width() + .unchecked_mul(9) + .unbounded_shr(6) + .unchecked_add(1); + assume(value >= 1 && value <= VARINT32_MAX_LEN as u32); + value as usize + } +} + +/// Decodes a LEB128-encoded variable length integer from the buffer. +#[inline] +pub fn decode_varint32(buf: &mut impl Buf) -> Result { + #[inline(always)] + fn inner(buf: &mut impl Buf) -> Option { + let bytes = buf.chunk(); + let len = bytes.len(); + if unlikely(len == 0) { + return None; + } + + // Fast path for single-byte varints. 
+ let first = unsafe { *bytes.get_unchecked(0) }; + if likely(first < 0x80) { + buf.advance(1); + return Some(first as _); + } + + // If the chunk is large enough or the varint is known to terminate within it, + // use the fast path which operates on a slice. + if likely(len >= VARINT32_MAX_LEN || bytes[len - 1] < 0x80) { + return decode_varint32_fast(bytes).map(|(value, advance)| { + buf.advance(advance); + value + }); + } + + // Fallback for varints that cross chunk boundaries. + decode_varint32_slow(buf) + } + inner(buf).ok_or(DecodeError::new("invalid varint32")) +} + +/// Fast-path decoding of a varint from a contiguous memory slice. +/// +/// ## Safety +/// +/// Assumes `bytes` contains a complete varint or is at least `VARINT32_MAX_LEN` bytes long. +#[inline(always)] +fn decode_varint32_fast(bytes: &[u8]) -> Option<(u32, usize)> { + let ptr = bytes.as_ptr(); + let mut value = 0u32; + + for i in 0..VARINT32_MAX_LEN { + let byte = unsafe { *ptr.add(i) }; + value |= ((byte & 0x7F) as u32) << (i * 7); + + if byte < 0x80 { + // Check for overlong encoding on the 5th byte. + if unlikely(i == 4 && byte > 4) { + return None; + } + return Some((value, i + 1)); + } + } + + // A varint must not be longer than 5 bytes. + None +} + +/// Slow-path decoding for varints that may cross `Buf` chunk boundaries. +#[cold] +#[inline(never)] +fn decode_varint32_slow(buf: &mut impl Buf) -> Option { + // Safety: The dispatcher `decode_varint` only calls this function if `bytes[0] >= 0x80`. + // This hint allows the compiler to optimize the first loop iteration. + unsafe { assume(buf.chunk().len() > 0 && buf.chunk()[0] >= 0x80) }; + + let mut value = 0u32; + for i in 0..VARINT32_MAX_LEN { + if unlikely(!buf.has_remaining()) { + return None; // Unexpected end of buffer. + } + let byte = buf.get_u8(); + value |= ((byte & 0x7F) as u32) << (i * 7); + + if byte < 0x80 { + // Check for overlong encoding on the 5th byte. 
+ if unlikely(i == 4 && byte > 4) { + return None; + } + return Some(value); + } + } + + // A varint must not be longer than 5 bytes. + None +} + +pub mod usize { + use super::*; + + #[cfg(target_pointer_width = "32")] + pub(super) use super::VARINT32_MAX_LEN as VARINT_MAX_LEN; + #[cfg(target_pointer_width = "64")] + pub(super) use super::VARINT64_MAX_LEN as VARINT_MAX_LEN; + + #[inline(always)] + pub fn encode_varint(value: usize, buf: &mut impl BufMut) -> usize { + #[cfg(target_pointer_width = "32")] + return encode_varint32(value as u32, buf); + #[cfg(target_pointer_width = "64")] + return encode_varint64(value as u64, buf); + } + + #[inline(always)] + pub const fn encoded_len_varint(value: usize) -> usize { + #[cfg(target_pointer_width = "32")] + return encoded_len_varint32(value as u32); + #[cfg(target_pointer_width = "64")] + return encoded_len_varint64(value as u64); + } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result { + #[cfg(target_pointer_width = "32")] + return transmute_unchecked!(decode_varint32(buf)); + #[cfg(target_pointer_width = "64")] + return transmute_unchecked!(decode_varint64(buf)); + } +} + +pub mod bool { + use super::*; + + #[inline(always)] + pub fn encode_varint(value: bool, buf: &mut impl BufMut) -> usize { + buf.put_u8(value as _); + 1 + } + + #[inline(always)] + pub const fn encoded_len_varint(_value: bool) -> usize { 1 } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result { + fn inner(buf: &mut impl Buf) -> Option { + if unlikely(buf.remaining() == 0) { + return None; + } + let byte = buf.get_u8(); + if byte <= 1 { Some(byte != 0) } else { None } + } + inner(buf).ok_or(DecodeError::new("invalid bool")) + } + + #[inline] + pub(in super::super) fn encode_packed_fast(values: &[bool], buf: &mut B) { + let start_ptr = buf.as_mut().as_mut_ptr(); + buf.reserve(usize::VARINT_MAX_LEN); + unsafe { + buf.set_len(buf.len() + usize::VARINT_MAX_LEN); + } + + let mut length = 0; + for &value in 
values { + length += encode_varint(value, buf); + } + let mut length_slice = unsafe { + &mut *(start_ptr as *mut [::core::mem::MaybeUninit; usize::VARINT_MAX_LEN]) + as &mut [::core::mem::MaybeUninit] + }; + let len = usize::encode_varint(length, &mut length_slice); + + unsafe { + let dst = start_ptr.add(len); + let src = start_ptr.add(usize::VARINT_MAX_LEN); + ::core::ptr::copy(src, dst, length); + buf.set_len( + buf.len() + .unchecked_sub(usize::VARINT_MAX_LEN) + .unchecked_add(len), + ); + } + } +} + +macro_rules! varint { + ($ty:ty, $proto_ty:ident,32) => { + pub mod $proto_ty { + use super::*; + + #[inline(always)] + pub fn encode_varint(value: $ty, buf: &mut impl BufMut) -> usize { encode_varint32(value as u32, buf) } + + #[inline(always)] + pub const fn encoded_len_varint(value: $ty) -> usize { encoded_len_varint32(value as u32) } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result<$ty, DecodeError> { + transmute_unchecked!(decode_varint32(buf)) + } + + #[inline] + pub(in super::super) fn encode_packed_fast(values: &[$ty], buf: &mut impl ReservableBuf) { + let start_ptr = buf.as_mut().as_mut_ptr(); + buf.reserve(usize::VARINT_MAX_LEN); + unsafe { + buf.set_len(buf.len() + usize::VARINT_MAX_LEN); + } + + let mut length = 0; + for &value in values { + length += encode_varint(value, buf); + } + let mut length_slice = unsafe { + &mut *(start_ptr as *mut [::core::mem::MaybeUninit; usize::VARINT_MAX_LEN]) + as &mut [::core::mem::MaybeUninit] + }; + let len = usize::encode_varint(length, &mut length_slice); + + unsafe { + let dst = start_ptr.add(len); + let src = start_ptr.add(usize::VARINT_MAX_LEN); + ::core::ptr::copy(src, dst, length); + buf.set_len( + buf.len() + .unchecked_sub(usize::VARINT_MAX_LEN) + .unchecked_add(len), + ); + } + } + } + }; + ($ty:ty, $proto_ty:ident,64) => { + pub mod $proto_ty { + use super::*; + + #[inline(always)] + pub fn encode_varint(value: $ty, buf: &mut impl BufMut) -> usize { encode_varint64(value as u64, 
buf) } + + #[inline(always)] + pub const fn encoded_len_varint(value: $ty) -> usize { encoded_len_varint64(value as u64) } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result<$ty, DecodeError> { + transmute_unchecked!(decode_varint64(buf)) + } + + #[inline] + pub(in super::super) fn encode_packed_fast(values: &[$ty], buf: &mut impl ReservableBuf) { + let start_ptr = buf.as_mut().as_mut_ptr(); + buf.reserve(usize::VARINT_MAX_LEN); + unsafe { + buf.set_len(buf.len() + usize::VARINT_MAX_LEN); + } + + let mut length = 0; + for &value in values { + length += encode_varint(value, buf); + } + let mut length_slice = unsafe { + &mut *(start_ptr as *mut [::core::mem::MaybeUninit; usize::VARINT_MAX_LEN]) + as &mut [::core::mem::MaybeUninit] + }; + let len = usize::encode_varint(length, &mut length_slice); + + unsafe { + let dst = start_ptr.add(len); + let src = start_ptr.add(usize::VARINT_MAX_LEN); + ::core::ptr::copy(src, dst, length); + buf.set_len( + buf.len() + .unchecked_sub(usize::VARINT_MAX_LEN) + .unchecked_add(len), + ); + } + } + } + }; + ($ty:ty, $proto_ty:ident,32, $encode_fn:ident, $decode_fn:ident) => { + pub mod $proto_ty { + use super::*; + + #[inline(always)] + pub fn encode_varint(value: $ty, buf: &mut impl BufMut) -> usize { + encode_varint32($encode_fn(value), buf) + } + + #[inline(always)] + pub const fn encoded_len_varint(value: $ty) -> usize { encoded_len_varint32($encode_fn(value)) } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result<$ty, DecodeError> { + decode_varint32(buf).map($decode_fn) + } + + #[inline] + pub(in super::super) fn encode_packed_fast(values: &[$ty], buf: &mut impl ReservableBuf) { + let start_ptr = buf.as_mut().as_mut_ptr(); + buf.reserve(usize::VARINT_MAX_LEN); + unsafe { + buf.set_len(buf.len() + usize::VARINT_MAX_LEN); + } + + let mut length = 0; + for &value in values { + length += encode_varint(value, buf); + } + let mut length_slice = unsafe { + &mut *(start_ptr as *mut 
[::core::mem::MaybeUninit; usize::VARINT_MAX_LEN]) + as &mut [::core::mem::MaybeUninit] + }; + let len = usize::encode_varint(length, &mut length_slice); + + unsafe { + let dst = start_ptr.add(len); + let src = start_ptr.add(usize::VARINT_MAX_LEN); + ::core::ptr::copy(src, dst, length); + buf.set_len( + buf.len() + .unchecked_sub(usize::VARINT_MAX_LEN) + .unchecked_add(len), + ); + } + } + } + }; + ($ty:ty, $proto_ty:ident,64, $encode_fn:ident, $decode_fn:ident) => { + pub mod $proto_ty { + use super::*; + + #[inline(always)] + pub fn encode_varint(value: $ty, buf: &mut impl BufMut) -> usize { + encode_varint64($encode_fn(value), buf) + } + + #[inline(always)] + pub const fn encoded_len_varint(value: $ty) -> usize { encoded_len_varint64($encode_fn(value)) } + + #[inline(always)] + pub fn decode_varint(buf: &mut impl Buf) -> Result<$ty, DecodeError> { + decode_varint64(buf).map($decode_fn) + } + + #[inline] + pub(in super::super) fn encode_packed_fast(values: &[$ty], buf: &mut impl ReservableBuf) { + let start_ptr = buf.as_mut().as_mut_ptr(); + buf.reserve(usize::VARINT_MAX_LEN); + unsafe { + buf.set_len(buf.len() + usize::VARINT_MAX_LEN); + } + + let mut length = 0; + for &value in values { + length += encode_varint(value, buf); + } + let mut length_slice = unsafe { + &mut *(start_ptr as *mut [::core::mem::MaybeUninit; usize::VARINT_MAX_LEN]) + as &mut [::core::mem::MaybeUninit] + }; + let len = usize::encode_varint(length, &mut length_slice); + + unsafe { + let dst = start_ptr.add(len); + let src = start_ptr.add(usize::VARINT_MAX_LEN); + ::core::ptr::copy(src, dst, length); + buf.set_len( + buf.len() + .unchecked_sub(usize::VARINT_MAX_LEN) + .unchecked_add(len), + ); + } + } + } + }; +} + +varint!(i32, int32, 32); +varint!(i64, int64, 64); +varint!(u32, uint32, 32); +varint!(u64, uint64, 64); +varint!(i32, sint32, 32, encode_zigzag32, decode_zigzag32); +varint!(i64, sint64, 64, encode_zigzag64, decode_zigzag64); + +pub(super) trait ReservableBuf: Sized + BufMut + 
AsMut<[u8]> { + fn reserve(&mut self, additional: usize); + fn len(&self) -> usize; + unsafe fn set_len(&mut self, len: usize); +} + +impl ReservableBuf for ::bytes::BytesMut { + #[inline(always)] + fn reserve(&mut self, additional: usize) { Self::reserve(self, additional); } + #[inline(always)] + fn len(&self) -> usize { Self::len(self) } + #[inline(always)] + unsafe fn set_len(&mut self, len: usize) { Self::set_len(self, len); } +} + +impl ReservableBuf for ::alloc::vec::Vec { + #[inline(always)] + fn reserve(&mut self, additional: usize) { Self::reserve(self, additional); } + #[inline(always)] + fn len(&self) -> usize { Self::len(self) } + #[inline(always)] + unsafe fn set_len(&mut self, len: usize) { Self::set_len(self, len); } +} diff --git a/patch/prost-0.14.1/src/encoding/wire_type.rs b/patch/prost-0.14.1/src/encoding/wire_type.rs new file mode 100644 index 0000000..291c2fe --- /dev/null +++ b/patch/prost-0.14.1/src/encoding/wire_type.rs @@ -0,0 +1,70 @@ +use alloc::format; + +use crate::DecodeError; + +/// Represents the wire type for protobuf encoding. +/// +/// The integer value is equivalent to the encoded value.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum WireType { + Varint = 0, + SixtyFourBit = 1, + LengthDelimited = 2, + StartGroup = 3, + EndGroup = 4, + ThirtyTwoBit = 5, +} + +impl WireType { + #[inline] + const fn try_from(value: u8) -> Option { + match value { + 0 => Some(WireType::Varint), + 1 => Some(WireType::SixtyFourBit), + 2 => Some(WireType::LengthDelimited), + 3 => Some(WireType::StartGroup), + 4 => Some(WireType::EndGroup), + 5 => Some(WireType::ThirtyTwoBit), + _ => None, + } + } + + #[inline] + pub fn try_from_tag(tag: u32) -> Result<(Self, u32), DecodeError> { + let value = (tag & super::WireTypeMask) as u8; + match Self::try_from(value) { + Some(wire_type) => Ok((wire_type, tag >> super::WireTypeBits)), + None => Err(DecodeError::new(format!("invalid wire type value: {value}"))), + } + } +} + +impl TryFrom for WireType { + type Error = DecodeError; + + #[inline] + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(WireType::Varint), + 1 => Ok(WireType::SixtyFourBit), + 2 => Ok(WireType::LengthDelimited), + 3 => Ok(WireType::StartGroup), + 4 => Ok(WireType::EndGroup), + 5 => Ok(WireType::ThirtyTwoBit), + _ => Err(DecodeError::new(format!("invalid wire type value: {value}"))), + } + } +} + +/// Checks that the expected wire type matches the actual wire type, +/// or returns an error result. +#[inline] +pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> { + if expected != actual { + return Err(DecodeError::new(format!( + "invalid wire type: {actual:?} (expected {expected:?})", + ))); + } + Ok(()) +} diff --git a/patch/prost-0.14.1/src/error.rs b/patch/prost-0.14.1/src/error.rs new file mode 100644 index 0000000..b461785 --- /dev/null +++ b/patch/prost-0.14.1/src/error.rs @@ -0,0 +1,180 @@ +//! Protobuf encoding and decoding errors. 
+ +use alloc::borrow::Cow; +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +use core::fmt; + +/// A Protobuf message decoding error. +/// +/// `DecodeError` indicates that the input buffer does not contain a valid +/// Protobuf message. The error details should be considered 'best effort': in +/// general it is not possible to exactly pinpoint why data is malformed. +#[derive(Clone, PartialEq, Eq)] +pub struct DecodeError { + inner: Box, +} + +#[derive(Clone, PartialEq, Eq)] +struct Inner { + /// A 'best effort' root cause description. + description: Cow<'static, str>, + /// A stack of (message, field) name pairs, which identify the specific + /// message type and field where decoding failed. The stack contains an + /// entry per level of nesting. + stack: Vec<(&'static str, &'static str)>, +} + +impl DecodeError { + /// Creates a new `DecodeError` with a 'best effort' root cause description. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + #[cold] + pub fn new(description: impl Into>) -> DecodeError { + DecodeError { + inner: Box::new(Inner { + description: description.into(), + stack: Vec::new(), + }), + } + } + + /// Pushes a (message, field) name location pair on to the location stack. + /// + /// Meant to be used only by `Message` implementations. 
+ #[doc(hidden)] + pub fn push(&mut self, message: &'static str, field: &'static str) { + self.inner.stack.push((message, field)); + } +} + +impl fmt::Debug for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DecodeError") + .field("description", &self.inner.description) + .field("stack", &self.inner.stack) + .finish() + } +} + +impl fmt::Display for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("failed to decode Protobuf message: ")?; + for &(message, field) in &self.inner.stack { + write!(f, "{}.{}: ", message, field)?; + } + f.write_str(&self.inner.description) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for DecodeError {} + +#[cfg(feature = "std")] +impl From for std::io::Error { + fn from(error: DecodeError) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidData, error) + } +} + +/// A Protobuf message encoding error. +/// +/// `EncodeError` always indicates that a message failed to encode because the +/// provided buffer had insufficient capacity. Message encoding is otherwise +/// infallible. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct EncodeError { + required: usize, + remaining: usize, +} + +impl EncodeError { + /// Creates a new `EncodeError`. + pub(crate) fn new(required: usize, remaining: usize) -> EncodeError { + EncodeError { + required, + remaining, + } + } + + /// Returns the required buffer capacity to encode the message. + pub fn required_capacity(&self) -> usize { + self.required + } + + /// Returns the remaining length in the provided buffer at the time of encoding. 
+ pub fn remaining(&self) -> usize { + self.remaining + } +} + +impl fmt::Display for EncodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "failed to encode Protobuf message; insufficient buffer capacity (required: {}, remaining: {})", + self.required, self.remaining + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for EncodeError {} + +#[cfg(feature = "std")] +impl From for std::io::Error { + fn from(error: EncodeError) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidInput, error) + } +} + +/// An error indicating that an unknown enumeration value was encountered. +/// +/// The Protobuf spec mandates that enumeration value sets are ‘open’, so this +/// error's value represents an integer value unrecognized by the +/// presently used enum definition. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct UnknownEnumValue(pub i32); + +impl fmt::Display for UnknownEnumValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "unknown enumeration value {}", self.0) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for UnknownEnumValue {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_push() { + let mut decode_error = DecodeError::new("something failed"); + decode_error.push("Foo bad", "bar.foo"); + decode_error.push("Baz bad", "bar.baz"); + + assert_eq!( + decode_error.to_string(), + "failed to decode Protobuf message: Foo bad.bar.foo: Baz bad.bar.baz: something failed" + ); + } + + #[cfg(feature = "std")] + #[test] + fn test_into_std_io_error() { + let decode_error = DecodeError::new("something failed"); + let std_io_error = std::io::Error::from(decode_error); + + assert_eq!(std_io_error.kind(), std::io::ErrorKind::InvalidData); + assert_eq!( + std_io_error.to_string(), + "failed to decode Protobuf message: something failed" + ); + } +} diff --git a/patch/prost-0.14.1/src/lib.rs b/patch/prost-0.14.1/src/lib.rs new file mode 100644 index 
0000000..72d97d9 --- /dev/null +++ b/patch/prost-0.14.1/src/lib.rs @@ -0,0 +1,54 @@ +#![allow(internal_features, unsafe_op_in_unsafe_fn)] +#![feature(core_intrinsics, uint_bit_width, portable_simd, pattern, char_internals)] +#![doc(html_root_url = "https://docs.rs/prost/0.14.1")] +#![cfg_attr(not(feature = "std"), no_std)] +#![doc = include_str!("../README.md")] + +// Re-export the alloc crate for use within derived code. +#[doc(hidden)] +pub extern crate alloc; + +// Re-export the bytes crate for use within derived code. +pub use bytes; + +// Re-export the alloc crate for use within derived code. +#[cfg(feature = "indexmap")] +#[doc(hidden)] +pub use indexmap; + +mod error; +mod message; +// mod name; +mod types; +mod byte_str; + +#[doc(hidden)] +pub mod encoding; + +pub use crate::encoding::length_delimiter::{ + decode_length_delimiter, encode_length_delimiter, length_delimiter_len, +}; +pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue}; +pub use crate::message::Message; +// pub use crate::name::Name; +pub use crate::byte_str::ByteStr; + +// See `encoding::DecodeContext` for more info. +// 100 is the default recursion limit in the C++ implementation. +#[cfg(not(feature = "no-recursion-limit"))] +const RECURSION_LIMIT: u32 = 100; + +// Re-export #[derive(Message, Enumeration, Oneof)]. +// Based on serde's equivalent re-export [1], but enabled by default. 
+// +// [1]: https://github.com/serde-rs/serde/blob/v1.0.89/serde/src/lib.rs#L245-L256 +#[cfg(feature = "derive")] +#[allow(unused_imports)] +#[macro_use] +extern crate prost_derive; +#[cfg(feature = "derive")] +#[doc(hidden)] +pub use prost_derive::*; + +#[macro_use] +extern crate macros; diff --git a/patch/prost-0.14.1/src/message.rs b/patch/prost-0.14.1/src/message.rs new file mode 100644 index 0000000..ee364d8 --- /dev/null +++ b/patch/prost-0.14.1/src/message.rs @@ -0,0 +1,184 @@ +use core::num::NonZeroU32; + +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +use bytes::{Buf, BufMut}; + +use crate::{ + DecodeError, EncodeError, + encoding::{ + DecodeContext, decode_tag, message, + varint::usize::{encode_varint, encoded_len_varint}, + wire_type::WireType, + }, +}; + +/// A Protocol Buffers message. +pub trait Message: Send + Sync { + /// Encodes the message to a buffer. + /// + /// This method will panic if the buffer has insufficient capacity. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + fn encode_raw(&self, buf: &mut impl BufMut) + where + Self: Sized; + + /// Decodes a field from a buffer, and merges it into `self`. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + Self: Sized; + + /// Returns the encoded length of the message without a length delimiter. + fn encoded_len(&self) -> usize; + + /// Encodes the message to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. 
+ fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> + where + Self: Sized, + { + let required = self.encoded_len(); + let remaining = buf.remaining_mut(); + if required > remaining { + return Err(EncodeError::new(required, remaining)); + } + + self.encode_raw(buf); + Ok(()) + } + + /// Encodes the message to a newly allocated buffer. + fn encode_to_vec(&self) -> Vec + where + Self: Sized, + { + let mut buf = Vec::with_capacity(self.encoded_len()); + + self.encode_raw(&mut buf); + buf + } + + /// Encodes the message with a length-delimiter to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. + fn encode_length_delimited(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> + where + Self: Sized, + { + let len = self.encoded_len(); + let required = len + encoded_len_varint(len); + let remaining = buf.remaining_mut(); + if required > remaining { + return Err(EncodeError::new(required, remaining)); + } + encode_varint(len, buf); + self.encode_raw(buf); + Ok(()) + } + + /// Encodes the message with a length-delimiter to a newly allocated buffer. + fn encode_length_delimited_to_vec(&self) -> Vec + where + Self: Sized, + { + let len = self.encoded_len(); + let mut buf = Vec::with_capacity(len + encoded_len_varint(len)); + + encode_varint(len, &mut buf); + self.encode_raw(&mut buf); + buf + } + + /// Decodes an instance of the message from a buffer. + /// + /// The entire buffer will be consumed. + fn decode(mut buf: impl Buf) -> Result + where + Self: Default, + { + let mut message = Self::default(); + Self::merge(&mut message, &mut buf).map(|_| message) + } + + /// Decodes a length-delimited instance of the message from the buffer. + fn decode_length_delimited(buf: impl Buf) -> Result + where + Self: Default, + { + let mut message = Self::default(); + message.merge_length_delimited(buf)?; + Ok(message) + } + + /// Decodes an instance of the message from a buffer, and merges it into `self`. 
+ /// + /// The entire buffer will be consumed. + fn merge(&mut self, mut buf: impl Buf) -> Result<(), DecodeError> + where + Self: Sized, + { + let ctx = DecodeContext::default(); + while buf.has_remaining() { + let (number, wire_type) = decode_tag(&mut buf)?; + self.merge_field(number, wire_type, &mut buf, ctx.clone())?; + } + Ok(()) + } + + /// Decodes a length-delimited instance of the message from buffer, and + /// merges it into `self`. + fn merge_length_delimited(&mut self, mut buf: impl Buf) -> Result<(), DecodeError> + where + Self: Sized, + { + message::merge( + WireType::LengthDelimited, + self, + &mut buf, + DecodeContext::default(), + ) + } + + /// Clears the message, resetting all fields to their default. + fn clear(&mut self); +} + +impl Message for Box +where + M: Message, +{ + fn encode_raw(&self, buf: &mut impl BufMut) { (**self).encode_raw(buf) } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + (**self).merge_field(number, wire_type, buf, ctx) + } + fn encoded_len(&self) -> usize { (**self).encoded_len() } + fn clear(&mut self) { (**self).clear() } +} + +#[cfg(test)] +mod tests { + use super::*; + + const _MESSAGE_IS_OBJECT_SAFE: Option<&dyn Message> = None; +} diff --git a/patch/prost-0.14.1/src/name.rs b/patch/prost-0.14.1/src/name.rs new file mode 100644 index 0000000..1b5b4a2 --- /dev/null +++ b/patch/prost-0.14.1/src/name.rs @@ -0,0 +1,34 @@ +//! Support for associating type name information with a [`Message`]. + +use crate::Message; + +#[cfg(not(feature = "std"))] +use alloc::{format, string::String}; + +/// Associate a type name with a [`Message`] type. +pub trait Name: Message { + /// Simple name for this [`Message`]. + /// This name is the same as it appears in the source .proto file, e.g. `FooBar`. + const NAME: &'static str; + + /// Package name this message type is contained in. 
They are domain-like + /// and delimited by `.`, e.g. `google.protobuf`. + const PACKAGE: &'static str; + + /// Fully-qualified unique name for this [`Message`]. + /// It's prefixed with the package name and names of any parent messages, + /// e.g. `google.rpc.BadRequest.FieldViolation`. + /// By default, this is the package name followed by the message name. + /// Fully-qualified names must be unique within a domain of Type URLs. + fn full_name() -> String { + format!("{}.{}", Self::PACKAGE, Self::NAME) + } + + /// Type URL for this [`Message`], which by default is the full name with a + /// leading slash, but may also include a leading domain name, e.g. + /// `type.googleapis.com/google.profile.Person`. + /// This can be used when serializing into the `google.protobuf.Any` type. + fn type_url() -> String { + format!("/{}", Self::full_name()) + } +} diff --git a/patch/prost-0.14.1/src/types.rs b/patch/prost-0.14.1/src/types.rs new file mode 100644 index 0000000..b7edc52 --- /dev/null +++ b/patch/prost-0.14.1/src/types.rs @@ -0,0 +1,573 @@ +//! Protocol Buffers well-known wrapper types. +//! +//! This module provides implementations of `Message` for Rust standard library types which +//! correspond to a Protobuf well-known wrapper type. The remaining well-known types are defined in +//! the `prost-types` crate in order to avoid a cyclic dependency between `prost` and +//! `prost-build`. 
+ +use core::num::NonZeroU32; + +// use alloc::format; +use alloc::string::String; +use alloc::vec::Vec; + +use ::bytes::{Buf, BufMut, Bytes}; + +use crate::encoding::wire_type::WireType; +use crate::encoding::FieldNumber1; +use crate::{ + encoding::{ + bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64, DecodeContext, + }, + DecodeError, Message, +}; + +/// `google.protobuf.BoolValue` +impl Message for bool { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self { + bool::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + bool::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self { + 2 + } else { + 0 + } + } + fn clear(&mut self) { + *self = false; + } +} + +// /// `google.protobuf.BoolValue` +// impl Name for bool { +// const NAME: &'static str = "BoolValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.UInt32Value` +impl Message for u32 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0 { + uint32::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + uint32::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + uint32::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +// /// `google.protobuf.UInt32Value` +// impl Name for u32 { +// const NAME: &'static str = "UInt32Value"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn 
type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.UInt64Value` +impl Message for u64 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0 { + uint64::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + uint64::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + uint64::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +// /// `google.protobuf.UInt64Value` +// impl Name for u64 { +// const NAME: &'static str = "UInt64Value"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.Int32Value` +impl Message for i32 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0 { + int32::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + int32::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + int32::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +// /// `google.protobuf.Int32Value` +// impl Name for i32 { +// const NAME: &'static str = "Int32Value"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.Int64Value` +impl Message for i64 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0 { + int64::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + 
number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + int64::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + int64::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +// /// `google.protobuf.Int64Value` +// impl Name for i64 { +// const NAME: &'static str = "Int64Value"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.FloatValue` +impl Message for f32 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0.0 { + float::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + float::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0.0 { + float::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0.0; + } +} + +// /// `google.protobuf.FloatValue` +// impl Name for f32 { +// const NAME: &'static str = "FloatValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.DoubleValue` +impl Message for f64 { + fn encode_raw(&self, buf: &mut impl BufMut) { + if *self != 0.0 { + double::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + double::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn 
encoded_len(&self) -> usize { + if *self != 0.0 { + double::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0.0; + } +} + +// /// `google.protobuf.DoubleValue` +// impl Name for f64 { +// const NAME: &'static str = "DoubleValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.StringValue` +impl Message for String { + fn encode_raw(&self, buf: &mut impl BufMut) { + if !self.is_empty() { + string::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + string::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + string::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +// /// `google.protobuf.StringValue` +// impl Name for String { +// const NAME: &'static str = "StringValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.BytesValue` +impl Message for Vec { + fn encode_raw(&self, buf: &mut impl BufMut) { + if !self.is_empty() { + bytes::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + bytes::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + bytes::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +// /// `google.protobuf.BytesValue` +// impl Name for Vec { +// const 
NAME: &'static str = "BytesValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.BytesValue` +impl Message for Bytes { + fn encode_raw(&self, buf: &mut impl BufMut) { + if !self.is_empty() { + bytes::encode(FieldNumber1, self, buf) + } + } + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + if number == FieldNumber1 { + bytes::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, number, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + bytes::encoded_len(FieldNumber1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +// /// `google.protobuf.BytesValue` +// impl Name for Bytes { +// const NAME: &'static str = "BytesValue"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +/// `google.protobuf.Empty` +impl Message for () { + fn encode_raw(&self, _buf: &mut impl BufMut) {} + fn merge_field( + &mut self, + number: NonZeroU32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + skip_field(wire_type, number, buf, ctx) + } + fn encoded_len(&self) -> usize { + 0 + } + fn clear(&mut self) {} +} + +// /// `google.protobuf.Empty` +// impl Name for () { +// const NAME: &'static str = "Empty"; +// const PACKAGE: &'static str = "google.protobuf"; + +// fn type_url() -> String { +// googleapis_type_url_for::() +// } +// } + +// /// Compute the type URL for the given `google.protobuf` type, using `type.googleapis.com` as the +// /// authority for the URL. 
+// fn googleapis_type_url_for() -> String { +// format!("type.googleapis.com/{}.{}", T::PACKAGE, T::NAME) +// } + +// #[cfg(test)] +// mod tests { +// use super::*; + +// #[test] +// fn test_impl_name() { +// assert_eq!("BoolValue", bool::NAME); +// assert_eq!("google.protobuf", bool::PACKAGE); +// assert_eq!("google.protobuf.BoolValue", bool::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.BoolValue", +// bool::type_url() +// ); + +// assert_eq!("UInt32Value", u32::NAME); +// assert_eq!("google.protobuf", u32::PACKAGE); +// assert_eq!("google.protobuf.UInt32Value", u32::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.UInt32Value", +// u32::type_url() +// ); + +// assert_eq!("UInt64Value", u64::NAME); +// assert_eq!("google.protobuf", u64::PACKAGE); +// assert_eq!("google.protobuf.UInt64Value", u64::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.UInt64Value", +// u64::type_url() +// ); + +// assert_eq!("Int32Value", i32::NAME); +// assert_eq!("google.protobuf", i32::PACKAGE); +// assert_eq!("google.protobuf.Int32Value", i32::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.Int32Value", +// i32::type_url() +// ); + +// assert_eq!("Int64Value", i64::NAME); +// assert_eq!("google.protobuf", i64::PACKAGE); +// assert_eq!("google.protobuf.Int64Value", i64::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.Int64Value", +// i64::type_url() +// ); + +// assert_eq!("FloatValue", f32::NAME); +// assert_eq!("google.protobuf", f32::PACKAGE); +// assert_eq!("google.protobuf.FloatValue", f32::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.FloatValue", +// f32::type_url() +// ); + +// assert_eq!("DoubleValue", f64::NAME); +// assert_eq!("google.protobuf", f64::PACKAGE); +// assert_eq!("google.protobuf.DoubleValue", f64::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.DoubleValue", +// f64::type_url() +// ); + +// 
assert_eq!("StringValue", String::NAME); +// assert_eq!("google.protobuf", String::PACKAGE); +// assert_eq!("google.protobuf.StringValue", String::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.StringValue", +// String::type_url() +// ); + +// assert_eq!("BytesValue", Vec::::NAME); +// assert_eq!("google.protobuf", Vec::::PACKAGE); +// assert_eq!("google.protobuf.BytesValue", Vec::::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.BytesValue", +// Vec::::type_url() +// ); + +// assert_eq!("BytesValue", Bytes::NAME); +// assert_eq!("google.protobuf", Bytes::PACKAGE); +// assert_eq!("google.protobuf.BytesValue", Bytes::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.BytesValue", +// Bytes::type_url() +// ); + +// assert_eq!("Empty", <()>::NAME); +// assert_eq!("google.protobuf", <()>::PACKAGE); +// assert_eq!("google.protobuf.Empty", <()>::full_name()); +// assert_eq!( +// "type.googleapis.com/google.protobuf.Empty", +// <()>::type_url() +// ); +// } +// } diff --git a/patch/prost-derive/Cargo.toml b/patch/prost-derive/Cargo.toml new file mode 100644 index 0000000..c91b359 --- /dev/null +++ b/patch/prost-derive/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "prost-derive" +readme = "README.md" +description = "Generate encoding and decoding implementations for Prost annotated types." 
+version = "0.14.1" +authors = [ + "Dan Burkert ", + "Lucio Franco ", + "Casper Meijn ", + "Tokio Contributors ", +] +license = "Apache-2.0" +repository = "https://github.com/tokio-rs/prost" +edition = "2021" +rust-version = "1.71.1" + +[lib] +proc-macro = true + +[dependencies] +anyhow = "1.0.1" +itertools = ">=0.10.1, <=0.14" +proc-macro2 = "1.0" +quote = "1" +syn = { version = "2", features = ["extra-traits"] } diff --git a/patch/prost-derive/LICENSE b/patch/prost-derive/LICENSE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/patch/prost-derive/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/patch/prost-derive/README.md b/patch/prost-derive/README.md new file mode 100644 index 0000000..a51050e --- /dev/null +++ b/patch/prost-derive/README.md @@ -0,0 +1,16 @@ +[![Documentation](https://docs.rs/prost-derive/badge.svg)](https://docs.rs/prost-derive/) +[![Crate](https://img.shields.io/crates/v/prost-derive.svg)](https://crates.io/crates/prost-derive) + +# prost-derive + +`prost-derive` handles generating encoding and decoding implementations for Rust +types annotated with `prost` annotation. For the most part, users of `prost` +shouldn't need to interact with `prost-derive` directly. + +## License + +`prost-derive` is distributed under the terms of the Apache License (Version 2.0). + +See [LICENSE](../LICENSE) for details. 
+ +Copyright 2017 Dan Burkert diff --git a/patch/prost-derive/src/field/group.rs b/patch/prost-derive/src/field/group.rs new file mode 100644 index 0000000..4eae27b --- /dev/null +++ b/patch/prost-derive/src/field/group.rs @@ -0,0 +1,137 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{Meta, Path}; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { + let mut group = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("group", attr) { + set_bool(&mut group, "duplicate group attributes")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !group { + return Ok(None); + } + + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for group field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("group field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { + if let Some(mut field) = Field::new(attrs, None)? 
{ + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) }; + #prost_path::encoding::group::encode(TAG, msg, buf); + } + }, + Label::Required => quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::group::encode(TAG, &#ident, buf);} + }, + Label::Repeated => quote! { + for msg in &#ident { + const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) }; + #prost_path::encoding::group::encode(TAG, msg, buf); + } + }, + } + } + + pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + #prost_path::encoding::group::merge( + tag, + wire_type, + #ident.get_or_insert_with(::core::default::Default::default), + buf, + ctx, + ) + }, + Label::Required => quote! { + #prost_path::encoding::group::merge(tag, wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + #prost_path::encoding::group::merge_repeated(tag, wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::group::encoded_len(#tag, msg)}) + }, + Label::Required => quote! 
{ + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::group::encoded_len(TAG, &#ident)} + }, + Label::Repeated => quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::group::encoded_len_repeated(TAG, &#ident)} + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/patch/prost-derive/src/field/map.rs b/patch/prost-derive/src/field/map.rs new file mode 100644 index 0000000..a79c722 --- /dev/null +++ b/patch/prost-derive/src/field/map.rs @@ -0,0 +1,411 @@ +use anyhow::{bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token}; + +use crate::field::{scalar, set_option, tag_attr}; + +#[derive(Clone, Debug)] +pub enum MapTy { + HashMap, + BTreeMap, + IndexMap, +} + +impl MapTy { + fn from_str(s: &str) -> Option { + match s { + "map" | "hash_map" => Some(MapTy::HashMap), + "btree_map" => Some(MapTy::BTreeMap), + "index_map" => Some(MapTy::IndexMap), + _ => None, + } + } + + fn module(&self) -> Ident { + match *self { + MapTy::HashMap => Ident::new("hash_map", Span::call_site()), + MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()), + MapTy::IndexMap => Ident::new("index_map", Span::call_site()), + } + } + + fn lib(&self) -> TokenStream { + match self { + MapTy::HashMap => quote! { std::collections }, + MapTy::BTreeMap => quote! { prost::alloc::collections }, + MapTy::IndexMap => quote! 
{ prost::indexmap }, + } + } +} + +fn fake_scalar(ty: scalar::Ty) -> scalar::Field { + let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); + scalar::Field { + ty, + kind, + tag: 0, // Not used here + } +} + +#[derive(Clone)] +pub struct Field { + pub map_ty: MapTy, + pub key_ty: scalar::Ty, + pub value_ty: ValueTy, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { + let mut types = None; + let mut tag = None; + + for attr in attrs { + if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(map_ty) = attr + .path() + .get_ident() + .and_then(|i| MapTy::from_str(&i.to_string())) + { + let (k, v): (String, String) = match attr { + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Str(lit), .. + }), + .. + }) => { + let items = lit.value(); + let mut items = items.split(',').map(ToString::to_string); + let k = items.next().unwrap(); + let v = match items.next() { + Some(k) => k, + None => bail!("invalid map attribute: must have key and value types"), + }; + if items.next().is_some() { + bail!("invalid map attribute: {:?}", attr); + } + (k, v) + } + Meta::List(meta_list) => { + let nested = meta_list + .parse_args_with(Punctuated::::parse_terminated)? 
+ .into_iter() + .collect::>(); + if nested.len() != 2 { + bail!("invalid map attribute: must contain key and value types"); + } + (nested[0].to_string(), nested[1].to_string()) + } + _ => return Ok(None), + }; + set_option( + &mut types, + (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?), + "duplicate map type attribute", + )?; + } else { + return Ok(None); + } + } + + Ok(match (types, tag.or(inferred_tag)) { + (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field { + map_ty, + key_ty, + value_ty, + tag, + }), + _ => None, + }) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { + Field::new(attrs, None) + } + + /// Returns a statement which encodes the map field. + pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let ke = quote!(#prost_path::encoding::#key_mod::encode); + let kl = quote!(#prost_path::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encode_with_default( + #ke, + #kl, + #prost_path::encoding::int32::encode, + #prost_path::encoding::int32::encoded_len, + &(#default), + TAG, + &#ident, + buf, + );} + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let ve = quote!(#prost_path::encoding::#val_mod::encode); + let vl = quote!(#prost_path::encoding::#val_mod::encoded_len); + quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encode( + #ke, + #kl, + #ve, + #vl, + TAG, + &#ident, + buf, + );} + } + } + ValueTy::Message => quote! 
{ + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encode( + #ke, + #kl, + #prost_path::encoding::message::encode, + #prost_path::encoding::message::encoded_len, + TAG, + &#ident, + buf, + );} + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded key value pair + /// into the map. + pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let key_mod = self.key_ty.module(); + let km = quote!(#prost_path::encoding::#key_mod::merge); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + #prost_path::encoding::#module::merge_with_default( + #km, + #prost_path::encoding::int32::merge, + #default, + &mut #ident, + buf, + ctx, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vm = quote!(#prost_path::encoding::#val_mod::merge); + quote!(#prost_path::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) + } + ValueTy::Message => quote! { + #prost_path::encoding::#module::merge( + #km, + #prost_path::encoding::message::merge, + &mut #ident, + buf, + ctx, + ) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the map. + pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let kl = quote!(#prost_path::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! 
{ + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encoded_len_with_default( + #kl, + #prost_path::encoding::int32::encoded_len, + &(#default), + TAG, + &#ident, + )} + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vl = quote!(#prost_path::encoding::#val_mod::encoded_len); + quote!({const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encoded_len(#kl, #vl, TAG, &#ident)}) + } + ValueTy::Message => quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::#module::encoded_len( + #kl, + #prost_path::encoding::message::encoded_len, + TAG, + &#ident, + )} + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident.clear()) + } + + /// Returns methods to embed in the message. + pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option { + if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { + let key_ty = self.key_ty.rust_type(prost_path); + let key_ref_ty = self.key_ty.rust_ref_type(); + + let get = Ident::new(&format!("get_{ident}"), Span::call_site()); + let insert = Ident::new(&format!("insert_{ident}"), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; + + let get_doc = format!( + "Returns the enum value for the corresponding key in `{ident}`, \ + or `None` if the entry does not exist or it is not a valid enum value." + ); + let insert_doc = format!("Inserts a key value pair into `{ident}`."); + Some(quote! 
{ + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + }) + } else { + None + } + } + + /// Returns a newtype wrapper around the map, implementing nicer Debug + /// + /// The Debug tries to convert any enumerations met into the variants if possible, instead of + /// outputting the raw numbers. + pub fn debug(&self, prost_path: &Path, wrapper_name: TokenStream) -> TokenStream { + let type_name = match self.map_ty { + MapTy::HashMap => Ident::new("HashMap", Span::call_site()), + MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()), + MapTy::IndexMap => Ident::new("IndexMap", Span::call_site()), + }; + + // A fake field for generating the debug wrapper + let key_wrapper = fake_scalar(self.key_ty.clone()).debug(prost_path, quote!(KeyWrapper)); + let key = self.key_ty.rust_type(prost_path); + let value_wrapper = self.value_ty.debug(prost_path); + let libname = self.map_ty.lib(); + let fmt = quote! { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #key_wrapper + #value_wrapper + let mut builder = f.debug_map(); + for (k, v) in self.0 { + builder.entry(&KeyWrapper(k), &ValueWrapper(v)); + } + builder.finish() + } + }; + match &self.value_ty { + ValueTy::Scalar(ty) => { + if let scalar::Ty::Bytes(_) = *ty { + return quote! 
{ + struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + self.0.fmt(f) + } + } + }; + } + + let value = ty.rust_type(prost_path); + quote! { + struct #wrapper_name<'a>(&'a ::#libname::#type_name<#key, #value>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + #fmt + } + } + } + ValueTy::Message => quote! { + struct #wrapper_name<'a, V: 'a>(&'a ::#libname::#type_name<#key, V>); + impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + }, + } + } +} + +fn key_ty_from_str(s: &str) -> Result { + let ty = scalar::Ty::from_str(s)?; + match ty { + scalar::Ty::Int32 + | scalar::Ty::Int64 + | scalar::Ty::Uint32 + | scalar::Ty::Uint64 + | scalar::Ty::Sint32 + | scalar::Ty::Sint64 + | scalar::Ty::Fixed32 + | scalar::Ty::Fixed64 + | scalar::Ty::Sfixed32 + | scalar::Ty::Sfixed64 + | scalar::Ty::Bool + | scalar::Ty::String(..) => Ok(ty), + _ => bail!("invalid map key type: {}", s), + } +} + +/// A map value type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueTy { + Scalar(scalar::Ty), + Message, +} + +impl ValueTy { + fn from_str(s: &str) -> Result { + if let Ok(ty) = scalar::Ty::from_str(s) { + Ok(ValueTy::Scalar(ty)) + } else if s.trim() == "message" { + Ok(ValueTy::Message) + } else { + bail!("invalid map value type: {}", s); + } + } + + /// Returns a newtype wrapper around the ValueTy for nicer debug. + /// + /// If the contained value is enumeration, it tries to convert it to the variant. If not, it + /// just forwards the implementation. 
+ fn debug(&self, prost_path: &Path) -> TokenStream { + match self { + ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(prost_path, quote!(ValueWrapper)), + ValueTy::Message => quote!( + fn ValueWrapper(v: T) -> T { + v + } + ), + } + } +} diff --git a/patch/prost-derive/src/field/message.rs b/patch/prost-derive/src/field/message.rs new file mode 100644 index 0000000..8a0a289 --- /dev/null +++ b/patch/prost-derive/src/field/message.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{Meta, Path}; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { + let mut message = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("message", attr) { + set_bool(&mut message, "duplicate message attribute")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attribute")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !message { + return Ok(None); + } + + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for message field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("message field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { + if let Some(mut field) = Field::new(attrs, None)? 
{ + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) }; + #prost_path::encoding::message::encode(TAG, msg, buf); + } + }, + Label::Required => quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::message::encode(TAG, &#ident, buf);} + }, + Label::Repeated => quote! { + for msg in &#ident { + const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) }; + #prost_path::encoding::message::encode(TAG, msg, buf); + } + }, + } + } + + pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + #prost_path::encoding::message::merge(wire_type, + #ident.get_or_insert_with(::core::default::Default::default), + buf, + ctx) + }, + Label::Required => quote! { + #prost_path::encoding::message::merge(wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + #prost_path::encoding::message::merge_repeated(wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::message::encoded_len(TAG, msg)}) + }, + Label::Required => quote! 
{ + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::message::encoded_len(TAG, &#ident)} + }, + Label::Repeated => quote! { + {const TAG: ::core::num::NonZeroU32 = unsafe { ::core::num::NonZeroU32::new_unchecked(#tag) };#prost_path::encoding::message::encoded_len_repeated(TAG, &#ident)} + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/patch/prost-derive/src/field/mod.rs b/patch/prost-derive/src/field/mod.rs new file mode 100644 index 0000000..d3922b1 --- /dev/null +++ b/patch/prost-derive/src/field/mod.rs @@ -0,0 +1,356 @@ +mod group; +mod map; +mod message; +mod oneof; +mod scalar; + +use std::fmt; +use std::slice; + +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::Path; +use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; + +#[derive(Clone)] +pub enum Field { + /// A scalar field. + Scalar(scalar::Field), + /// A message field. + Message(message::Field), + /// A map field. + Map(map::Field), + /// A oneof field. + Oneof(oneof::Field), + /// A group field. + Group(group::Field), +} + +impl Field { + /// Creates a new `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new(attrs: Vec, inferred_tag: Option) -> Result, Error> { + let attrs = prost_attrs(attrs)?; + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? 
{ + Field::Message(field) + } else if let Some(field) = map::Field::new(&attrs, inferred_tag)? { + Field::Map(field) + } else if let Some(field) = oneof::Field::new(&attrs)? { + Field::Oneof(field) + } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? { + Field::Group(field) + } else { + bail!("no type attribute"); + }; + + Ok(Some(field)) + } + + /// Creates a new oneof `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new_oneof(attrs: Vec) -> Result, Error> { + let attrs = prost_attrs(attrs)?; + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new_oneof(&attrs)? { + Field::Message(field) + } else if let Some(field) = map::Field::new_oneof(&attrs)? { + Field::Map(field) + } else if let Some(field) = group::Field::new_oneof(&attrs)? { + Field::Group(field) + } else { + bail!("no type attribute for oneof field"); + }; + + Ok(Some(field)) + } + + pub fn tags(&self) -> Vec { + match *self { + Field::Scalar(ref scalar) => vec![scalar.tag], + Field::Message(ref message) => vec![message.tag], + Field::Map(ref map) => vec![map.tag], + Field::Oneof(ref oneof) => oneof.tags.clone(), + Field::Group(ref group) => vec![group.tag], + } + } + + /// Returns a statement which encodes the field. + pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encode(prost_path, ident), + Field::Message(ref message) => message.encode(prost_path, ident), + Field::Map(ref map) => map.encode(prost_path, ident), + Field::Oneof(ref oneof) => oneof.encode(ident), + Field::Group(ref group) => group.encode(prost_path, ident), + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// value into the field. 
+ pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.merge(prost_path, ident), + Field::Message(ref message) => message.merge(prost_path, ident), + Field::Map(ref map) => map.merge(prost_path, ident), + Field::Oneof(ref oneof) => oneof.merge(ident), + Field::Group(ref group) => group.merge(prost_path, ident), + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encoded_len(prost_path, ident), + Field::Map(ref map) => map.encoded_len(prost_path, ident), + Field::Message(ref msg) => msg.encoded_len(prost_path, ident), + Field::Oneof(ref oneof) => oneof.encoded_len(ident), + Field::Group(ref group) => group.encoded_len(prost_path, ident), + } + } + + /// Returns a statement which clears the field. + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.clear(ident), + Field::Message(ref message) => message.clear(ident), + Field::Map(ref map) => map.clear(ident), + Field::Oneof(ref oneof) => oneof.clear(ident), + Field::Group(ref group) => group.clear(ident), + } + } + + pub fn default(&self, prost_path: &Path) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.default(prost_path), + _ => quote!(::core::default::Default::default()), + } + } + + /// Produces the fragment implementing debug for the given field. + pub fn debug(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => { + let wrapper = scalar.debug(prost_path, quote!(ScalarWrapper)); + quote! { + { + #wrapper + ScalarWrapper(&#ident) + } + } + } + Field::Map(ref map) => { + let wrapper = map.debug(prost_path, quote!(MapWrapper)); + quote! 
{ + { + #wrapper + MapWrapper(&#ident) + } + } + } + _ => quote!(&#ident), + } + } + + pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option { + match *self { + Field::Scalar(ref scalar) => scalar.methods(ident), + Field::Map(ref map) => map.methods(prost_path, ident), + _ => None, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Label { + /// An optional field. + Optional, + /// A required field. + Required, + /// A repeated field. + Repeated, +} + +impl Label { + fn as_str(self) -> &'static str { + match self { + Label::Optional => "optional", + Label::Required => "required", + Label::Repeated => "repeated", + } + } + + fn variants() -> slice::Iter<'static, Label> { + const VARIANTS: &[Label] = &[Label::Optional, Label::Required, Label::Repeated]; + VARIANTS.iter() + } + + /// Parses a string into a field label. + /// If the string doesn't match a field label, `None` is returned. + fn from_attr(attr: &Meta) -> Option