# Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. # import logging import re from typing import Type, Any, Generator, Tuple from pydantic import BaseModel from codegenerator.common import BasePrimitiveType from codegenerator.common import BaseCombinedType from codegenerator.common import BaseCompoundType from codegenerator import model from codegenerator import common CODEBLOCK_RE = re.compile(r"```(\w*)$") BASIC_FIELDS = [ "id", "name", "title", "created_at", "updated_at", "uuid", "state", "status", "operating_status", ] class Boolean(BasePrimitiveType): """Basic Boolean""" type_hint: str = "bool" imports: set[str] = set() clap_macros: set[str] = {"action=clap::ArgAction::Set"} original_data_type: BaseCompoundType | BaseCompoundType | None = None def get_sample(self): return "false" def get_sdk_setter( self, source_var_name: str, sdk_mod_path: str, into: bool = False ) -> str: if into: return f"*{source_var_name}" else: return f"*{source_var_name}" class Number(BasePrimitiveType): format: str | None = None imports: set[str] = set() clap_macros: set[str] = set() original_data_type: BaseCompoundType | BaseCompoundType | None = None @property def type_hint(self): if self.format == "float": return "f32" elif self.format == "double": return "f64" else: return "f32" def get_sample(self): return "123" class Integer(BasePrimitiveType): format: str | None = None imports: set[str] = set() clap_macros: set[str] = set() original_data_type: BaseCompoundType | BaseCompoundType | None = None @property def type_hint(self): if self.format == "int32": return "i32" elif self.format == "int64": return "i64" return "i32" def get_sample(self): return "123" class Null(BasePrimitiveType): type_hint: str = "Value" imports: set[str] = {"serde_json::Value"} builder_macros: set[str] = set() clap_macros: set[str] = set() original_data_type: BaseCompoundType | BaseCompoundType | None = None def get_sample(self): return "Value::Null" class String(BasePrimitiveType): format: str | None = None type_hint: str = "String" builder_macros: set[str] = {"setter(into)"} # NOTE(gtema): it is not possible to override field with computed # property, thus it must be a property here @property def imports(self) -> set[str]: return set() def get_sample(self): return '"foo"' class SecretString(String): type_hint: str = "SecretString" @property def imports(self) -> set[str]: return { "secrecy::SecretString", "crate::api::common::serialize_sensitive_string", "crate::api::common::serialize_sensitive_optional_string", } class JsonValue(BasePrimitiveType): type_hint: str = "Value" builder_macros: set[str] = {"setter(into)"} def get_sample(self): return "json!({})" @property def imports(self): imports: set[str] = {"serde_json::Value"} return imports class Option(BaseCombinedType): base_type: str = "Option" item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType original_data_type: BaseCompoundType | BaseCompoundType | None = None @property def type_hint(self): return f"Option<{self.item_type.type_hint}>" @property def lifetimes(self): return self.item_type.lifetimes @property def imports(self): return self.item_type.imports @property def builder_macros(self): macros = {"setter(into)"} wrapped_macros = self.item_type.builder_macros if "private" in wrapped_macros: macros = wrapped_macros return macros @property def clap_macros(self): return self.item_type.clap_macros def get_sample(self): return self.item_type.get_sample() def get_sdk_setter( self, source_var_name: str, sdk_mod_path: str, into: bool = False ) -> str: if into: return f"{self.item_type.get_sdk_setter(source_var_name, sdk_mod_path, into=into)}" else: return f"{source_var_name}.clone().map(Into::into)" class Array(BaseCombinedType): base_type: str = "vec" item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType @property def type_hint(self): return f"Vec<{self.item_type.type_hint}>" @property def lifetimes(self): return self.item_type.lifetimes @property def imports(self): return self.item_type.imports @property def builder_macros(self): if isinstance(self.item_type, Array): macros = {"private"} else: macros = {"setter(into)"} return macros def get_sample(self): return ( "Vec::from([" + self.item_type.get_sample() + (".into()" if isinstance(self.item_type, String) else "") + "])" ) @property def clap_macros(self) -> set[str]: return self.item_type.clap_macros @property def requires_builder_private_setter(self): if isinstance(self.item_type, Array): return True else: return False class CommaSeparatedList(BaseCombinedType): item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType @property def type_hint(self): return f"CommaSeparatedList<{self.item_type.type_hint}>" @property def lifetimes(self): return self.item_type.lifetimes @property def imports(self): imports: set[str] = set() imports.update(self.item_type.imports) return imports @property def clap_macros(self) -> set[str]: return set() class BTreeSet(BaseCombinedType): item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType builder_macros: set[str] = {"setter(into)"} @property def type_hint(self): return f"BTreeSet<{self.item_type.type_hint}>" @property def lifetimes(self): return self.item_type.lifetimes @property def imports(self): imports = self.item_type.imports imports.add("std::collections::BTreeSet") return imports class Dictionary(BaseCompoundType): base_type: str = "dict" value_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType @property def imports(self): imports: set[str] = {"std::collections::BTreeMap"} imports.update(self.value_type.imports) imports.add("structable::{StructTable, StructTableOptions}") return imports @property def type_hint(self): return f"BTreeMap" @property def lifetimes(self): return set() class StructField(BaseModel): local_name: str remote_name: str _description: str | None = None data_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType is_optional: bool = True is_nullable: bool = False def __init__(self, description: str | None = None, **data): super().__init__(**data) if description is not None: self._description = description @property def type_hint(self): typ_hint = self.data_type.type_hint if self.is_optional: typ_hint = f"Option<{typ_hint}>" return typ_hint @property def description(self): """Description getter necessary for being able to override the property""" return self._description class Struct(BaseCompoundType): base_type: str = "struct" fields: dict[str, StructField] = {} field_type_class_: Type[StructField] | StructField = StructField additional_fields_type: ( BasePrimitiveType | BaseCombinedType | BaseCompoundType | None ) = None pattern_properties: ( BasePrimitiveType | BaseCombinedType | BaseCompoundType | None ) = None @property def type_hint(self): return self.name + ( f"<{', '.join(self.lifetimes)}>" if self.lifetimes else "" ) @property def imports(self): imports: set[str] = set() field_types = [x.data_type for x in self.fields.values()] if len(field_types) > 1 or ( len(field_types) == 1 and not isinstance(field_types[0], Null) and not isinstance(field_types[0], Dictionary) and not isinstance(field_types[0], Array) ): # We use structure only if it is not consisting from only Null imports.add("serde::Deserialize") imports.add("serde::Serialize") for field_type in field_types: imports.update(field_type.imports) if self.additional_fields_type: imports.add("std::collections::BTreeMap") imports.update(self.additional_fields_type.imports) return imports @property def lifetimes(self): lifetimes_: set[str] = set() for field in self.fields.values(): if field.data_type.lifetimes: lifetimes_.update(field.data_type.lifetimes) return lifetimes_ @property def clap_macros(self) -> set[str]: return set() class StructFieldResponse(StructField): """Response Structure Field""" @property def type_hint(self): typ_hint = self.data_type.type_hint if self.is_optional and not typ_hint.startswith("Option<"): typ_hint = f"Option<{typ_hint}>" return typ_hint @property def serde_macros(self): macros = set() try: macros.update(self.data_type.get_serde_macros(self.is_optional)) except Exception: pass if self.is_optional: macros.add("default") if self.local_name != self.remote_name: macros.add(f'rename="{self.remote_name}"') if len(macros) > 0: return f"#[serde({', '.join(sorted(macros))})]" return "" def get_structable_macros( self, struct: "StructResponse", service_name: str, resource_name: str, operation_type: str, ): macros = set() if self.is_optional or self.data_type.type_hint.startswith("Option<"): macros.add("optional") if self.local_name != self.remote_name: macros.add(f'title="{self.remote_name}"') # Fully Qualified Attribute Name fqan: str = ".".join( [service_name, resource_name, self.remote_name] ).lower() # Check the known alias of the field by FQAN alias = common.FQAN_ALIAS_MAP.get(fqan) if operation_type in ["list", "list_from_struct"]: if ( "id" in struct.fields.keys() and not ( self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS ) ) or ( "id" not in struct.fields.keys() and (self.local_name not in list(struct.fields.keys())[-10:]) and not ( self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS ) ): # Only add "wide" flag if field is not in the basic fields AND # there is at least "id" field existing in the struct OR the # field is not in the first 10 macros.add("wide") if ( self.local_name == "state" and "status" not in struct.fields.keys() ): macros.add("status") elif ( self.local_name == "operating_status" and "status" not in struct.fields.keys() ): macros.add("status") if not ( # option of primitive isinstance(self.data_type, Option) and isinstance(self.data_type.item_type, BasePrimitiveType) ) and ( # not primitive not isinstance(self.data_type, BasePrimitiveType) # or explicitly Json or isinstance(self.data_type, JsonValue) ): macros.add("serialize") return f"#[structable({', '.join(sorted(macros))})]" class StructResponse(Struct): field_type_class_: Type[StructField] = StructFieldResponse @property def imports(self): imports: set[str] = { "serde::Deserialize", "serde::Serialize", "structable::{StructTable, StructTableOptions}", } for field in self.fields.values(): imports.update(field.data_type.imports) # In difference to the SDK and Input we do not currently handle # additional_fields of the struct in response # if self.additional_fields_type: # imports.add("std::collections::BTreeMap") # imports.update(self.additional_fields_type.imports) return imports @property def static_lifetime(self): """Return Rust `<'lc>` lifetimes representation""" return f"<{', '.join(self.lifetimes)}>" if self.lifetimes else "" class EnumKind(BaseModel): name: str description: str | None = None data_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType @property def type_hint(self): if isinstance(self.data_type, Struct): return self.data_type.name + self.data_type.static_lifetime return self.data_type.type_hint @property def clap_macros(self) -> set[str]: return set() class Enum(BaseCompoundType): base_type: str = "enum" kinds: dict[str, EnumKind] literals: list[Any] | None = None original_data_type: BaseCompoundType | BaseCompoundType | None = None _kind_type_class = EnumKind @property def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property def serde_container_macros(self) -> str | None: return None @property def type_hint(self): return self.name + ( f"<{', '.join(self.lifetimes)}>" if self.lifetimes else "" ) @property def imports(self): imports: set[str] = set() imports.add("serde::Deserialize") imports.add("serde::Serialize") for kind in self.kinds.values(): imports.update(kind.data_type.imports) return imports @property def lifetimes(self): lifetimes_: set[str] = set() for kind in self.kinds.values(): if kind.data_type.lifetimes: lifetimes_.update(kind.data_type.lifetimes) return lifetimes_ @property def clap_macros(self) -> set[str]: return set() class StringEnum(BaseCompoundType): base_type: str = "enum" variants: dict[str, set[str]] = {} imports: set[str] = {"serde::Deserialize", "serde::Serialize"} lifetimes: set[str] = set() builder_container_macros: str | None = None original_data_type: BaseCompoundType | BaseCompoundType | None = None allows_arbitrary_value: bool = False @property def derive_container_macros(self) -> str | None: return "#[derive(Debug, Deserialize, Clone, Serialize)]" @property def serde_container_macros(self) -> str | None: if self.allows_arbitrary_value: return "#[serde(untagged)]" return None @property def type_hint(self): """Get type hint""" return self.name @property def clap_macros(self) -> set[str]: """Return clap macros""" return set() @property def builder_macros(self) -> set[str]: """Return builder macros""" return set() def get_sample(self): """Generate sample data""" variant = sorted(self.variants.keys())[0] return f"{self.name}::{variant}" def variant_serde_macros(self, variant: str): """Return serde macros""" macros = set() vals = self.variants[variant] if len(vals) > 1: macros.add(f'rename(serialize = "{sorted(vals)[0]}")') for val in vals: macros.add(f'alias="{val}"') else: macros.add(f'rename = "{list(vals)[0]}"') return "#[serde(" + ", ".join(sorted(macros)) + ")]" class HashMapResponse(Dictionary): """Wrapper around a simple dictionary to implement Display trait""" # name: str | None = None lifetimes: set[str] = set() @property def type_hint(self): return f"HashMapString{self.value_type.type_hint.replace('<', '').replace('>', '')}" @property def imports(self): imports = self.value_type.imports imports.add("std::collections::BTreeMap") imports.add("structable::{StructTable, StructTableOptions}") return imports class TupleStruct(Struct): """Rust tuple struct without named fields""" base_type: str = "struct" tuple_fields: list[StructField] = [] @property def imports(self): imports: set[str] = set() for field in self.tuple_fields: imports.update(field.data_type.imports) imports.add("structable::{StructTable, StructTableOptions}") return imports class RequestParameter(BaseModel): """OpenAPI request parameter in the Rust SDK form""" remote_name: str local_name: str location: str data_type: BaseCombinedType | BasePrimitiveType | BaseCompoundType description: str | None = None is_required: bool = False is_flag: bool = False resource_link: str | None = None setter_name: str | None = None setter_type: str | None = None @property def type_hint(self): if not self.is_required and not isinstance(self.data_type, BTreeSet): return f"Option<{self.data_type.type_hint}>" return self.data_type.type_hint @property def lifetimes(self): return self.data_type.lifetimes class TypeManager: """Rust type manager The class is responsible for converting ADT models into types suitable for Rust. """ models: list = [] refs: dict[ model.Reference, BasePrimitiveType | BaseCombinedType | BaseCompoundType, ] = {} parameters: dict[str, Type[RequestParameter] | RequestParameter] = {} #: Base mapping of the primitive data-types base_primitive_type_mapping: dict[ Type[model.PrimitiveType], Type[BasePrimitiveType] | Type[BaseCombinedType], ] = { model.PrimitiveString: String, model.ConstraintString: String, model.PrimitiveNumber: Number, model.ConstraintNumber: Number, model.ConstraintInteger: Integer, model.PrimitiveBoolean: Boolean, model.PrimitiveNull: Null, model.PrimitiveAny: JsonValue, } #: Extension for primitives data-type mapping primitive_type_mapping: dict[ Type[model.PrimitiveType], Type[BasePrimitiveType] | Type[BaseCombinedType], ] #: Extensions of the data-type mapping data_type_mapping: dict[ Type[model.ADT], Type[BaseCombinedType] | Type[BaseCompoundType] ] #: Base data-type mapping base_data_type_mapping: dict[ Type[model.ADT], Type[BaseCombinedType] | Type[BaseCompoundType] ] = { model.Dictionary: Dictionary, model.Enum: Enum, model.Struct: Struct, model.Array: Array, model.CommaSeparatedList: CommaSeparatedList, model.Set: BTreeSet, } #: RequestParameter Type class request_parameter_class: Type[RequestParameter] = RequestParameter #: Option Type class option_type_class: Type[Option] | Option = Option #: StringEnum Type class string_enum_class: Type[StringEnum] | StringEnum = StringEnum #: List of the models to be ignored ignored_models: list[model.Reference] = [] root_name: str | None = "Body" def __init__(self): self.models = [] self.refs = {} self.parameters = {} # Set base mapping entries into the data_type_mapping for k, v in self.base_primitive_type_mapping.items(): if k not in self.primitive_type_mapping: self.primitive_type_mapping[k] = v for k, v in self.base_data_type_mapping.items(): if k not in self.data_type_mapping: self.data_type_mapping[k] = v def get_local_attribute_name(self, name: str) -> str: """Get localized attribute name""" name = name.replace(".", "_") attr_name = "_".join( x.lower() for x in re.split(common.SPLIT_NAME_RE, name) ) if attr_name in ["type", "self", "enum", "ref", "default"]: attr_name = f"_{attr_name}" return attr_name def get_remote_attribute_name(self, name: str) -> str: """Get remote attribute name This method can be used on the client side to be able to override remote attribute name as a local name on the SDK side. """ return name def get_model_name(self, model_ref: model.Reference | None) -> str: """Get the localized model type name""" if not model_ref: return "Request" name = "".join( x.capitalize() for x in re.split(common.SPLIT_NAME_RE, model_ref.name) ) if name[0].isdigit(): return "x" + name return name def _get_adt_by_reference(self, model_ref): for model_ in self.models: if model_.reference == model_ref: return model_ raise RuntimeError(f"Cannot find reference {model_ref}") def convert_model( self, type_model: model.PrimitiveType | model.ADT | model.Reference ) -> BasePrimitiveType | BaseCombinedType | BaseCompoundType: """Get local destination type from the ModelType""" # logging.debug("Get RustSDK type for %s", type_model) typ: BasePrimitiveType | BaseCombinedType | BaseCompoundType | None = ( None ) model_ref: model.Reference | None = None if isinstance(type_model, model.Reference): model_ref = type_model type_model = self._get_adt_by_reference(type_model) elif isinstance(type_model, model.ADT): # Direct composite type model_ref = type_model.reference else: # Primitive xtyp = self.primitive_type_mapping.get(type_model.__class__) if not xtyp: raise RuntimeError(f"No mapping for {type_model}") return xtyp(**type_model.model_dump()) # Composite/Compound type if model_ref and model_ref in self.refs: return self.refs[model_ref] if isinstance(type_model, model.Array): typ = self._get_array_type(type_model) elif isinstance(type_model, model.Struct): typ = self._get_struct_type(type_model) elif isinstance(type_model, model.OneOfType): typ = self._get_one_of_type(type_model) elif isinstance(type_model, model.Dictionary): typ = self.data_type_mapping[model.Dictionary]( name=self.get_model_name(type_model.reference), value_type=self.convert_model(type_model.value_type), ) elif isinstance(type_model, model.CommaSeparatedList): typ = self.data_type_mapping[model.CommaSeparatedList]( item_type=self.convert_model(type_model.item_type) ) elif isinstance(type_model, model.Set): typ = self.data_type_mapping[model.Set]( item_type=self.convert_model(type_model.item_type) ) elif isinstance(type_model, model.Enum): if len(type_model.base_types) > 1: if model.PrimitiveBoolean in type_model.base_types: # enum literals supporting also bools are most likely # bool + string -> just keep bool on the Rust side typ = Boolean() else: raise RuntimeError( "Rust model does not support multitype enums yet" f" {type_model}" ) elif len(type_model.base_types) == 1: base_type = type_model.base_types[0] if base_type is model.ConstraintString: variants: dict[str, set[str]] = {} try: if None in type_model.literals: # TODO(gtema): make parent nullable or add "null" # as enum value type_model.literals.remove(None) for lit in {x.lower() for x in type_model.literals}: val = "".join( [ x.capitalize() for x in re.split( common.SPLIT_NAME_RE, lit ) ] ) if val and val[0].isdigit(): val = "_" + val vals = variants.setdefault(val, set()) for orig_val in type_model.literals: if orig_val.lower() == lit: vals.add(orig_val) typ = self.string_enum_class( name=self.get_model_name(type_model.reference), variants=variants, ) except Exception: logging.exception( "Error processing enum: %s", type_model ) elif base_type is model.ConstraintInteger: typ = self.primitive_type_mapping[ model.ConstraintInteger ]() elif base_type is model.ConstraintNumber: typ = self.primitive_type_mapping[model.ConstraintNumber]() elif base_type is model.PrimitiveBoolean: typ = self.primitive_type_mapping[model.PrimitiveBoolean]() if not typ: raise RuntimeError( f"Cannot map model type {type_model.__class__.__name__} to" f" Rust type [{type_model}]" ) if not model_ref: model_ref = model.Reference( name=self.root_name, type=typ.__class__ ) self.refs[model_ref] = typ return typ def _get_array_type(self, type_model: model.Array) -> Array: """Convert `model.Array` into corresponding Rust SDK model""" return self.data_type_mapping[model.Array]( name=self.get_model_name(type_model.reference), item_type=self.convert_model(type_model.item_type), ) def _get_one_of_type( self, type_model: model.OneOfType ) -> BaseCompoundType | BaseCombinedType | BasePrimitiveType: """Convert `model.OneOfType` into Rust model""" kinds: list[dict] = [] is_nullable: bool = False result_data_type = None for kind in type_model.kinds: if isinstance(kind, model.PrimitiveNull): # Remove null from candidates and instead wrap with Option is_nullable = True continue kind_type = self.convert_model(kind) is_type_already_present = False for processed_kind_type in kinds: if ( isinstance(kind_type, BasePrimitiveType) and processed_kind_type["local"] == kind_type ): logging.debug( "Simplifying oneOf with same mapped type %s [%s]", kind, type_model, ) is_type_already_present = True break if not is_type_already_present: kinds.append( { "model": kind, "local": kind_type, "class": kind_type.__class__, } ) # Simplify certain oneOf combinations self._simplify_oneof_combinations(type_model, kinds) if len(kinds) == 2: list_type = [ x["local"] for x in kinds if x["class"] == self.data_type_mapping[model.Array] ] if list_type: lt: BaseCombinedType = list_type[0] # Typ + list[Typ] => Vec item_type = [ x["local"] for x in kinds if x["class"] != self.data_type_mapping[model.Array] ][0] if item_type.__class__ == lt.item_type.__class__: result_data_type = self.data_type_mapping[model.Array]( item_type=item_type, description=sanitize_rust_docstrings( type_model.description ), ) # logging.debug("Replacing Typ + list[Typ] with list[Typ]") elif len(kinds) == 1: result_data_type = kinds[0]["local"] if not result_data_type: enum_class = self.data_type_mapping[model.Enum] result_data_type = enum_class( name=self.get_model_name(type_model.reference), kinds={} ) cnt: int = 0 for kind_data in kinds: cnt += 1 kind_data_type = kind_data["local"] kind_description: str | None = None if isinstance(kind_data["model"], model.ADT): kind_name = self.get_model_name(kind_data["model"]) kind_description = kind_data["model"].description else: kind_name = f"F{cnt}" enum_kind = enum_class._kind_type_class( name=kind_name, description=sanitize_rust_docstrings(kind_description), data_type=kind_data_type, ) result_data_type.kinds[enum_kind.name] = enum_kind if is_nullable: result_data_type = self.option_type_class( item_type=result_data_type ) return result_data_type def _get_struct_type(self, type_model: model.Struct) -> Struct: """Convert model.Struct into Rust `Struct`""" struct_class = self.data_type_mapping[model.Struct] mod = struct_class( name=self.get_model_name(type_model.reference), description=sanitize_rust_docstrings(type_model.description), ) field_class = mod.field_type_class_ for field_name, field in type_model.fields.items(): is_nullable: bool = False field_data_type = self.convert_model(field.data_type) if isinstance(field_data_type, self.option_type_class): # Unwrap Option into "is_nullable" NOTE: but perhaps # Option