From 8f2a2d615028b06f05a6fd597b6a3a9f18cbf16a Mon Sep 17 00:00:00 2001 From: Artem Goncharov Date: Tue, 8 Apr 2025 14:38:48 +0200 Subject: [PATCH] Fix wrong enum serde macros In the previous change new serde macros were added to the reused class what resulted in sdk structs being partly corrupted. Make response structs public Change-Id: Ia53f7b21d58b23cc4647f891b5b9744c7fc5556d --- codegenerator/common/rust.py | 19 +++++++------------ codegenerator/rust_types.py | 18 +++++++++++++++++- codegenerator/templates/rust_types/impl.rs.j2 | 8 ++++---- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/codegenerator/common/rust.py b/codegenerator/common/rust.py index 4e50516..161917b 100644 --- a/codegenerator/common/rust.py +++ b/codegenerator/common/rust.py @@ -468,11 +468,7 @@ class EnumKind(BaseModel): @property def type_hint(self): if isinstance(self.data_type, Struct): - print(f"Getting type hint of {self.data_type}") - try: - return self.data_type.name + self.data_type.static_lifetime - except Exception as ex: - print(f"Error {ex}") + return self.data_type.name + self.data_type.static_lifetime return self.data_type.type_hint @property @@ -488,12 +484,12 @@ class Enum(BaseCompoundType): _kind_type_class = EnumKind @property - def derive_container_macros(self) -> str: + def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property - def serde_container_macros(self) -> str: - return "#[serde(untagged)]" + def serde_container_macros(self) -> str | None: + return None @property def type_hint(self): @@ -529,16 +525,15 @@ class StringEnum(BaseCompoundType): imports: set[str] = {"serde::Deserialize", "serde::Serialize"} lifetimes: set[str] = set() builder_container_macros: str | None = None - serde_macros: set[str] | None = None original_data_type: BaseCompoundType | BaseCompoundType | None = None @property - def derive_container_macros(self) -> str: + def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property - def serde_container_macros(self) -> str: - return "#[serde(untagged)]" + def serde_container_macros(self) -> str | None: + return None @property def type_hint(self): diff --git a/codegenerator/rust_types.py b/codegenerator/rust_types.py index 8ee29e5..b924094 100644 --- a/codegenerator/rust_types.py +++ b/codegenerator/rust_types.py @@ -50,9 +50,25 @@ class BoolString(common.BasePrimitiveType): clap_macros: set[str] = set() +class Enum(common_rust.Enum): + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + + +class StringEnum(common_rust.StringEnum): + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + + class ResponseTypeManager(common_rust.TypeManager): primitive_type_mapping = {} - data_type_mapping = {model.Struct: common_rust.StructResponse} + data_type_mapping = { + model.Struct: common_rust.StructResponse, + model.Enum: Enum, + } + string_enum_class: Type[StringEnum] | StringEnum = StringEnum def get_model_name(self, model_ref: model.Reference | None) -> str: """Get the localized model type name diff --git a/codegenerator/templates/rust_types/impl.rs.j2 b/codegenerator/templates/rust_types/impl.rs.j2 index a513a7f..bf82284 100644 --- a/codegenerator/templates/rust_types/impl.rs.j2 +++ b/codegenerator/templates/rust_types/impl.rs.j2 @@ -27,7 +27,7 @@ use {{ mod }}; {%- if data_type.fields %} /// {{ target_class_name }} response representation #[derive(Clone, Deserialize, Serialize)] - struct {{ data_type.name }} { + pub struct {{ data_type.name }} { {%- for k, v in data_type.fields | dictsort %} {% if not (operation_type == "list" and k in ["links"]) %} {{ macros.docstring(v.description, indent=4) }} @@ -48,7 +48,7 @@ use {{ mod }}; /// {{ target_class_name }} response representation #[derive(Deserialize, Serialize)] #[derive(Clone)] - struct {{ class_name }}( + pub struct {{ class_name }}( {%- for field in data_type.tuple_fields %} {{ field.type_hint }}, {%- endfor %} @@ -57,7 +57,7 @@ use {{ mod }}; {%- elif data_type.__class__.__name__ == "HashMapResponse" %} /// Response data as HashMap type #[derive(Deserialize, Serialize)] - struct {{ class_name }}(HashMap); + pub struct {{ class_name }}(HashMap); {%- endif %} {%- endwith %} @@ -67,7 +67,7 @@ use {{ mod }}; /// `{{ subtype.name }}` type #[derive(Clone, Debug)] #[derive(Deserialize, Serialize)] - {{ subtype.base_type }} {{ subtype.name }} { + pub {{ subtype.base_type }} {{ subtype.name }} { {%- for k, v in subtype.fields | dictsort %} {{ v.local_name }}: {{ v.type_hint }}, {%- endfor %}