diff --git a/codegenerator/common/rust.py b/codegenerator/common/rust.py index 4e50516..161917b 100644 --- a/codegenerator/common/rust.py +++ b/codegenerator/common/rust.py @@ -468,11 +468,7 @@ class EnumKind(BaseModel): @property def type_hint(self): if isinstance(self.data_type, Struct): - print(f"Getting type hint of {self.data_type}") - try: - return self.data_type.name + self.data_type.static_lifetime - except Exception as ex: - print(f"Error {ex}") + return self.data_type.name + self.data_type.static_lifetime return self.data_type.type_hint @property @@ -488,12 +484,12 @@ class Enum(BaseCompoundType): _kind_type_class = EnumKind @property - def derive_container_macros(self) -> str: + def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property - def serde_container_macros(self) -> str: - return "#[serde(untagged)]" + def serde_container_macros(self) -> str | None: + return None @property def type_hint(self): @@ -529,16 +525,15 @@ class StringEnum(BaseCompoundType): imports: set[str] = {"serde::Deserialize", "serde::Serialize"} lifetimes: set[str] = set() builder_container_macros: str | None = None - serde_macros: set[str] | None = None original_data_type: BaseCompoundType | BaseCompoundType | None = None @property - def derive_container_macros(self) -> str: + def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property - def serde_container_macros(self) -> str: - return "#[serde(untagged)]" + def serde_container_macros(self) -> str | None: + return None @property def type_hint(self): diff --git a/codegenerator/rust_types.py b/codegenerator/rust_types.py index 8ee29e5..b924094 100644 --- a/codegenerator/rust_types.py +++ b/codegenerator/rust_types.py @@ -50,9 +50,25 @@ class BoolString(common.BasePrimitiveType): clap_macros: set[str] = set() +class Enum(common_rust.Enum): + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + + +class StringEnum(common_rust.StringEnum): + @property + def serde_container_macros(self) -> str: + return "#[serde(untagged)]" + + class ResponseTypeManager(common_rust.TypeManager): primitive_type_mapping = {} - data_type_mapping = {model.Struct: common_rust.StructResponse} + data_type_mapping = { + model.Struct: common_rust.StructResponse, + model.Enum: Enum, + } + string_enum_class: Type[StringEnum] | StringEnum = StringEnum def get_model_name(self, model_ref: model.Reference | None) -> str: """Get the localized model type name diff --git a/codegenerator/templates/rust_types/impl.rs.j2 b/codegenerator/templates/rust_types/impl.rs.j2 index a513a7f..bf82284 100644 --- a/codegenerator/templates/rust_types/impl.rs.j2 +++ b/codegenerator/templates/rust_types/impl.rs.j2 @@ -27,7 +27,7 @@ use {{ mod }}; {%- if data_type.fields %} /// {{ target_class_name }} response representation #[derive(Clone, Deserialize, Serialize)] - struct {{ data_type.name }} { + pub struct {{ data_type.name }} { {%- for k, v in data_type.fields | dictsort %} {% if not (operation_type == "list" and k in ["links"]) %} {{ macros.docstring(v.description, indent=4) }} @@ -48,7 +48,7 @@ use {{ mod }}; /// {{ target_class_name }} response representation #[derive(Deserialize, Serialize)] #[derive(Clone)] - struct {{ class_name }}( + pub struct {{ class_name }}( {%- for field in data_type.tuple_fields %} {{ field.type_hint }}, {%- endfor %} @@ -57,7 +57,7 @@ use {{ mod }}; {%- elif data_type.__class__.__name__ == "HashMapResponse" %} /// Response data as HashMap type #[derive(Deserialize, Serialize)] - struct {{ class_name }}(HashMap); + pub struct {{ class_name }}(HashMap); {%- endif %} {%- endwith %} @@ -67,7 +67,7 @@ use {{ mod }}; /// `{{ subtype.name }}` type #[derive(Clone, Debug)] #[derive(Deserialize, Serialize)] - {{ subtype.base_type }} {{ subtype.name }} { + pub {{ subtype.base_type }} {{ subtype.name }} { {%- for k, v in subtype.fields | dictsort %} {{ v.local_name }}: {{ v.type_hint }}, {%- endfor %}