Files
codegenerator/codegenerator/common/rust.py
Artem Goncharov 721705d837 Rename operation_name with action_name in the metadata
Currently we document the operation_name attribute in the metadata as
being used as an action name. This only creates confusion, especially if
we want to use something different as the operation_name (e.g.
operation_name or operation_type for the neutron router results in
"action"). So, in addition to renaming the metadata attribute,
explicitly pass the metadata operation key as the operation_name parameter
into the generator (when unset).

Change-Id: Ic04eafe5b6dea012ca18b9835cd5c86fefa87055
Signed-off-by: Artem Goncharov <artem.goncharov@gmail.com>
2025-06-05 15:02:09 +00:00

1460 lines
52 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import logging
import re
from typing import Type, Any, Generator, Tuple
from pydantic import BaseModel
from codegenerator.common import BasePrimitiveType
from codegenerator.common import BaseCombinedType
from codegenerator.common import BaseCompoundType
from codegenerator import model
from codegenerator import common
#: Matches a Markdown code fence at the end of the string and captures its
#: (possibly empty) language tag, e.g. "```rust" -> "rust".
CODEBLOCK_RE = re.compile(r"```(\w*)$")

#: Field names treated as "basic": in list operations they are never marked
#: with the structable ``wide`` flag (see ``get_structable_macros``).
BASIC_FIELDS = (
    "id name title created_at updated_at uuid state status operating_status"
).split()
class Boolean(BasePrimitiveType):
    """Rust ``bool`` primitive."""

    type_hint: str = "bool"
    imports: set[str] = set()
    clap_macros: set[str] = {"action=clap::ArgAction::Set"}
    original_data_type: BaseCompoundType | BaseCompoundType | None = None

    def get_sample(self):
        """Return a Rust literal usable as a sample value."""
        return "false"

    def get_sdk_setter(
        self, source_var_name: str, sdk_mod_path: str, into: bool = False
    ) -> str:
        """Return an expression dereferencing ``source_var_name``.

        The expression is the same whether or not ``into`` is requested.
        """
        return f"*{source_var_name}"
class Number(BasePrimitiveType):
    """Floating point number primitive."""

    format: str | None = None
    imports: set[str] = set()
    clap_macros: set[str] = set()
    original_data_type: BaseCompoundType | BaseCompoundType | None = None

    @property
    def type_hint(self):
        """Rust float type: ``f64`` for the "double" format, else ``f32``."""
        return "f64" if self.format == "double" else "f32"

    def get_sample(self):
        """Return a Rust literal usable as a sample value."""
        return "123"
class Integer(BasePrimitiveType):
    """Integer primitive."""

    format: str | None = None
    imports: set[str] = set()
    clap_macros: set[str] = set()
    original_data_type: BaseCompoundType | BaseCompoundType | None = None

    @property
    def type_hint(self):
        """Rust integer type: ``i64`` for the "int64" format, else ``i32``."""
        return "i64" if self.format == "int64" else "i32"

    def get_sample(self):
        """Return a Rust literal usable as a sample value."""
        return "123"
class Null(BasePrimitiveType):
    """JSON ``null``, represented on the Rust side as ``serde_json::Value``."""

    type_hint: str = "Value"
    imports: set[str] = {"serde_json::Value"}
    builder_macros: set[str] = set()
    clap_macros: set[str] = set()
    original_data_type: BaseCompoundType | BaseCompoundType | None = None

    def get_sample(self):
        """Return the Rust ``Value::Null`` literal."""
        sample = "Value::Null"
        return sample
class String(BasePrimitiveType):
    """Rust ``String`` primitive."""

    format: str | None = None
    type_hint: str = "String"
    builder_macros: set[str] = {"setter(into)"}

    # NOTE(gtema): it is not possible to override field with computed
    # property, thus it must be a property here
    @property
    def imports(self) -> set[str]:
        """``String`` is part of the Rust prelude: no imports required."""
        no_imports: set[str] = set()
        return no_imports

    def get_sample(self):
        """Return a Rust string literal usable as a sample value."""
        return '"foo"'
class SecretString(String):
    """Sensitive string rendered as ``secrecy::SecretString``."""

    type_hint: str = "SecretString"

    @property
    def imports(self) -> set[str]:
        """Imports for the secret type and its serialization helpers."""
        required = (
            "secrecy::SecretString",
            "crate::api::common::serialize_sensitive_string",
            "crate::api::common::serialize_sensitive_optional_string",
        )
        return set(required)
class JsonValue(BasePrimitiveType):
    """Arbitrary JSON, represented as ``serde_json::Value``."""

    type_hint: str = "Value"
    builder_macros: set[str] = {"setter(into)"}

    def get_sample(self):
        """Return an empty JSON object literal as sample value."""
        return "json!({})"

    @property
    def imports(self):
        """A fresh set containing the ``serde_json::Value`` import."""
        return {"serde_json::Value"}
class Option(BaseCombinedType):
    """Rust ``Option<T>`` wrapping another type."""

    base_type: str = "Option"
    item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType
    original_data_type: BaseCompoundType | BaseCompoundType | None = None

    @property
    def type_hint(self):
        inner_hint = self.item_type.type_hint
        return f"Option<{inner_hint}>"

    # Lifetimes, imports, clap macros and samples are those of the wrapped
    # type.
    @property
    def lifetimes(self):
        return self.item_type.lifetimes

    @property
    def imports(self):
        return self.item_type.imports

    @property
    def builder_macros(self):
        """``setter(into)``, unless the wrapped type needs a private setter."""
        wrapped = self.item_type.builder_macros
        if "private" in wrapped:
            return wrapped
        return {"setter(into)"}

    @property
    def clap_macros(self):
        return self.item_type.clap_macros

    def get_sample(self):
        return self.item_type.get_sample()

    def get_sdk_setter(
        self, source_var_name: str, sdk_mod_path: str, into: bool = False
    ) -> str:
        """Return the SDK setter expression for this optional value.

        With ``into`` the wrapped type's setter is used; otherwise the value
        is cloned and converted with ``Into::into`` inside the Option.
        """
        if not into:
            return f"{source_var_name}.clone().map(Into::into)"
        inner_setter = self.item_type.get_sdk_setter(
            source_var_name, sdk_mod_path, into=into
        )
        return f"{inner_setter}"
class Array(BaseCombinedType):
    """Rust ``Vec<T>``."""

    base_type: str = "vec"
    item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType

    @property
    def type_hint(self):
        return "Vec<" + self.item_type.type_hint + ">"

    @property
    def lifetimes(self):
        return self.item_type.lifetimes

    @property
    def imports(self):
        return self.item_type.imports

    @property
    def builder_macros(self):
        """Nested vectors get a private builder setter, others ``into``."""
        return (
            {"private"}
            if isinstance(self.item_type, Array)
            else {"setter(into)"}
        )

    def get_sample(self):
        """Return a one-element ``Vec`` built from the item sample."""
        # String samples are &str literals and must be converted
        suffix = ".into()" if isinstance(self.item_type, String) else ""
        return "".join(
            ["Vec::from([", self.item_type.get_sample(), suffix, "])"]
        )

    @property
    def clap_macros(self) -> set[str]:
        return self.item_type.clap_macros

    @property
    def requires_builder_private_setter(self):
        """True for nested vectors (``Vec<Vec<..>>``)."""
        return isinstance(self.item_type, Array)
class CommaSeparatedList(BaseCombinedType):
    """Rust ``CommaSeparatedList<T>`` of the wrapped item type."""

    item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType

    @property
    def type_hint(self):
        return "CommaSeparatedList<" + self.item_type.type_hint + ">"

    @property
    def lifetimes(self):
        return self.item_type.lifetimes

    @property
    def imports(self):
        """A copy of the item type imports."""
        return set(self.item_type.imports)

    @property
    def clap_macros(self) -> set[str]:
        return set()
class BTreeSet(BaseCombinedType):
    """Rust ``BTreeSet<T>`` of the wrapped item type."""

    item_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType
    builder_macros: set[str] = {"setter(into)"}

    @property
    def type_hint(self):
        return f"BTreeSet<{self.item_type.type_hint}>"

    @property
    def lifetimes(self):
        return self.item_type.lifetimes

    @property
    def imports(self):
        """Return the item type imports plus ``std::collections::BTreeSet``.

        A new set is built. The item type's ``imports`` may be a stored field
        (e.g. on ``StringEnum`` or ``Integer``), and adding to it in place
        would leak the ``BTreeSet`` import into every other use of that
        item type.
        """
        imports: set[str] = set(self.item_type.imports)
        imports.add("std::collections::BTreeSet")
        return imports
class Dictionary(BaseCompoundType):
    """Rust ``BTreeMap<String, V>`` for a free-form object."""

    base_type: str = "dict"
    value_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType

    @property
    def imports(self):
        """BTreeMap and structable imports plus those of the value type."""
        return {
            "std::collections::BTreeMap",
            "structable::{StructTable, StructTableOptions}",
            *self.value_type.imports,
        }

    @property
    def type_hint(self):
        value_hint = self.value_type.type_hint
        return f"BTreeMap<String, {value_hint}>"

    @property
    def lifetimes(self):
        """Maps never carry lifetimes."""
        return set()
class StructField(BaseModel):
    """A named field of a Rust structure."""

    local_name: str
    remote_name: str
    _description: str | None = None
    data_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType
    is_optional: bool = True
    is_nullable: bool = False

    def __init__(self, description: str | None = None, **data):
        super().__init__(**data)
        # The description lives in a private attribute so that subclasses
        # can override the public `description` property.
        if description is None:
            return
        self._description = description

    @property
    def type_hint(self):
        """Rust type of the field, wrapped in ``Option<>`` when optional."""
        inner_hint = self.data_type.type_hint
        return f"Option<{inner_hint}>" if self.is_optional else inner_hint

    @property
    def description(self):
        """Description getter necessary for being able to override the property"""
        return self._description
class Struct(BaseCompoundType):
    """Rust structure with named fields."""

    base_type: str = "struct"
    fields: dict[str, StructField] = {}
    field_type_class_: Type[StructField] | StructField = StructField
    additional_fields_type: (
        BasePrimitiveType | BaseCombinedType | BaseCompoundType | None
    ) = None
    pattern_properties: (
        BasePrimitiveType | BaseCombinedType | BaseCompoundType | None
    ) = None

    @property
    def type_hint(self):
        """Struct name, followed by ``<lifetimes>`` when there are any."""
        lifetimes = self.lifetimes
        suffix = f"<{', '.join(lifetimes)}>" if lifetimes else ""
        return self.name + suffix

    @property
    def imports(self):
        """Imports of all fields, serde derives and additional fields.

        Serde is only imported when the struct is not just a single
        ``Null``/``Dictionary``/``Array`` field (or has no fields at all).
        """
        field_types = [field.data_type for field in self.fields.values()]
        needs_serde = len(field_types) > 1 or (
            len(field_types) == 1
            and not isinstance(field_types[0], (Null, Dictionary, Array))
        )
        imports: set[str] = set()
        if needs_serde:
            imports |= {"serde::Deserialize", "serde::Serialize"}
        for field_type in field_types:
            imports.update(field_type.imports)
        if self.additional_fields_type:
            imports.add("std::collections::BTreeMap")
            imports.update(self.additional_fields_type.imports)
        return imports

    @property
    def lifetimes(self):
        """Union of the lifetimes of all field types."""
        return {
            lifetime
            for field in self.fields.values()
            for lifetime in (field.data_type.lifetimes or ())
        }

    @property
    def clap_macros(self) -> set[str]:
        return set()
class StructFieldResponse(StructField):
    """Response Structure Field

    Field of a response structure. Adds rendering of ``#[serde(...)]`` and
    ``#[structable(...)]`` attributes on top of ``StructField``.
    """

    @property
    def type_hint(self):
        # Wrap optional fields into Option<>, unless the data type already
        # renders as an Option (avoids Option<Option<...>>).
        typ_hint = self.data_type.type_hint
        if self.is_optional and not typ_hint.startswith("Option<"):
            typ_hint = f"Option<{typ_hint}>"
        return typ_hint

    @property
    def serde_macros(self):
        """Return the ``#[serde(...)]`` attribute for the field, or ``""``.

        Combines macros from ``data_type.get_serde_macros(is_optional)``,
        ``default`` for optional fields, and ``rename`` when the local name
        differs from the remote one. Macros are sorted for stable output.
        """
        macros = set()
        try:
            macros.update(self.data_type.get_serde_macros(self.is_optional))
        except Exception:
            # NOTE(review): best-effort; presumably not every data type
            # implements get_serde_macros. Any error here is silently
            # ignored.
            pass
        if self.is_optional:
            macros.add("default")
        if self.local_name != self.remote_name:
            macros.add(f'rename="{self.remote_name}"')
        if len(macros) > 0:
            return f"#[serde({', '.join(sorted(macros))})]"
        return ""

    def get_structable_macros(
        self,
        struct: "StructResponse",
        service_name: str,
        resource_name: str,
        operation_type: str,
    ):
        """Return the ``#[structable(...)]`` attribute for the field.

        :param struct: the structure containing this field.
        :param service_name: service name; together with ``resource_name``
            and the remote field name it forms the lower-cased fully
            qualified attribute name looked up in ``common.FQAN_ALIAS_MAP``.
        :param resource_name: resource name (see ``service_name``).
        :param operation_type: for "list" and "list_from_struct" the
            ``wide`` and ``status`` flags are computed as well.
        """
        macros = set()
        if self.is_optional or self.data_type.type_hint.startswith("Option<"):
            macros.add("optional")
        if self.local_name != self.remote_name:
            # Display the field under its remote (API) name
            macros.add(f'title="{self.remote_name}"')
        # Fully Qualified Attribute Name
        fqan: str = ".".join(
            [service_name, resource_name, self.remote_name]
        ).lower()
        # Check the known alias of the field by FQAN
        alias = common.FQAN_ALIAS_MAP.get(fqan)
        if operation_type in ["list", "list_from_struct"]:
            # NOTE(review): the slice [-10:] below checks the *last* 10
            # fields, while the comment inside speaks of the "first 10".
            # Confirm which one is intended.
            if (
                "id" in struct.fields.keys()
                and not (
                    self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS
                )
            ) or (
                "id" not in struct.fields.keys()
                and (self.local_name not in list(struct.fields.keys())[-10:])
                and not (
                    self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS
                )
            ):
                # Only add "wide" flag if field is not in the basic fields AND
                # there is at least "id" field existing in the struct OR the
                # field is not in the first 10
                macros.add("wide")
            # Without a real "status" field, "state" or "operating_status"
            # is used as the status column.
            if (
                self.local_name == "state"
                and "status" not in struct.fields.keys()
            ):
                macros.add("status")
            elif (
                self.local_name == "operating_status"
                and "status" not in struct.fields.keys()
            ):
                macros.add("status")
        # Everything except primitives (and Options of primitives) is
        # rendered via serialization; JsonValue is always serialized.
        if not (
            # option of primitive
            isinstance(self.data_type, Option)
            and isinstance(self.data_type.item_type, BasePrimitiveType)
        ) and (
            # not primitive
            not isinstance(self.data_type, BasePrimitiveType)
            # or explicitly Json
            or isinstance(self.data_type, JsonValue)
        ):
            macros.add("serialize")
        return f"#[structable({', '.join(sorted(macros))})]"
class StructResponse(Struct):
    """Response structure rendered with serde and structable support."""

    field_type_class_: Type[StructField] = StructFieldResponse

    @property
    def imports(self):
        """Serde/structable imports plus the imports of every field type.

        Unlike the SDK and input structures, additional fields of a response
        structure are currently not handled, so their imports are not added.
        """
        return {
            "serde::Deserialize",
            "serde::Serialize",
            "structable::{StructTable, StructTableOptions}",
            *(
                import_
                for field in self.fields.values()
                for import_ in field.data_type.imports
            ),
        }

    @property
    def static_lifetime(self):
        """Return Rust `<'lc>` lifetimes representation"""
        lifetimes = self.lifetimes
        if not lifetimes:
            return ""
        return "<" + ", ".join(lifetimes) + ">"
class EnumKind(BaseModel):
    """A single variant (kind) of an ``Enum``, wrapping a data type."""

    name: str
    description: str | None = None
    data_type: BasePrimitiveType | BaseCombinedType | BaseCompoundType

    @property
    def type_hint(self):
        """Rust type of the variant payload.

        Structures providing ``static_lifetime`` (``StructResponse`` does,
        the base ``Struct`` does not) are rendered as name + that lifetime.
        Otherwise the data type's own ``type_hint`` is used; for a ``Struct``
        that is already the name followed by its lifetimes. Previously a
        plain ``Struct`` raised AttributeError here.
        """
        if isinstance(self.data_type, Struct):
            static_lifetime = getattr(self.data_type, "static_lifetime", None)
            if static_lifetime is not None:
                return self.data_type.name + static_lifetime
        return self.data_type.type_hint

    @property
    def clap_macros(self) -> set[str]:
        return set()
class Enum(BaseCompoundType):
    """Rust enum whose variants wrap different data types."""

    base_type: str = "enum"
    kinds: dict[str, EnumKind]
    literals: list[Any] | None = None
    original_data_type: BaseCompoundType | BaseCompoundType | None = None
    _kind_type_class = EnumKind

    @property
    def derive_container_macros(self) -> str | None:
        return "#[derive(Debug, Deserialize, Clone, Serialize)]"

    @property
    def serde_container_macros(self) -> str | None:
        return None

    @property
    def type_hint(self):
        """Enum name, followed by ``<lifetimes>`` when there are any."""
        lifetimes = self.lifetimes
        suffix = f"<{', '.join(lifetimes)}>" if lifetimes else ""
        return self.name + suffix

    @property
    def imports(self):
        """Serde derives plus the imports of every variant payload."""
        return {
            "serde::Deserialize",
            "serde::Serialize",
            *(
                import_
                for kind in self.kinds.values()
                for import_ in kind.data_type.imports
            ),
        }

    @property
    def lifetimes(self):
        """Union of the lifetimes of all variant payloads."""
        return {
            lifetime
            for kind in self.kinds.values()
            for lifetime in (kind.data_type.lifetimes or ())
        }

    @property
    def clap_macros(self) -> set[str]:
        return set()
class StringEnum(BaseCompoundType):
    """Rust enum of string variants.

    ``variants`` maps each Rust variant name to the set of original string
    values it represents.
    """

    base_type: str = "enum"
    variants: dict[str, set[str]] = {}
    imports: set[str] = {"serde::Deserialize", "serde::Serialize"}
    lifetimes: set[str] = set()
    builder_container_macros: str | None = None
    original_data_type: BaseCompoundType | BaseCompoundType | None = None
    allows_arbitrary_value: bool = False

    @property
    def derive_container_macros(self) -> str | None:
        return "#[derive(Debug, Deserialize, Clone, Serialize)]"

    @property
    def serde_container_macros(self) -> str | None:
        """``#[serde(untagged)]`` when arbitrary values are accepted."""
        return "#[serde(untagged)]" if self.allows_arbitrary_value else None

    @property
    def type_hint(self):
        """Get type hint"""
        return self.name

    @property
    def clap_macros(self) -> set[str]:
        """Return clap macros"""
        return set()

    @property
    def builder_macros(self) -> set[str]:
        """Return builder macros"""
        return set()

    def get_sample(self):
        """Sample value: the alphabetically first variant."""
        first_variant = sorted(self.variants)[0]
        return f"{self.name}::{first_variant}"

    def variant_serde_macros(self, variant: str):
        """Return the ``#[serde(...)]`` attribute for ``variant``.

        A variant with several original values serializes as the
        alphabetically first one and accepts all of them as aliases.
        """
        values = self.variants[variant]
        if len(values) <= 1:
            macros = {f'rename = "{list(values)[0]}"'}
        else:
            canonical = sorted(values)[0]
            macros = {f'rename(serialize = "{canonical}")'}
            macros |= {f'alias="{value}"' for value in values}
        return f"#[serde({', '.join(sorted(macros))})]"
class HashMapResponse(Dictionary):
    """Wrapper around a simple dictionary to implement Display trait"""

    lifetimes: set[str] = set()

    @property
    def type_hint(self):
        return f"HashMapString{self.value_type.type_hint.replace('<', '').replace('>', '')}"

    @property
    def imports(self):
        """Return the value type imports plus BTreeMap and structable.

        A new set is built. The value type's ``imports`` may be a stored
        field, and adding to it in place would leak these imports into every
        other use of that value type.
        """
        imports: set[str] = set(self.value_type.imports)
        imports.add("std::collections::BTreeMap")
        imports.add("structable::{StructTable, StructTableOptions}")
        return imports
class TupleStruct(Struct):
    """Rust tuple struct without named fields"""

    base_type: str = "struct"
    tuple_fields: list[StructField] = []

    @property
    def imports(self):
        """Structable imports plus those of every tuple member type."""
        collected: set[str] = {"structable::{StructTable, StructTableOptions}"}
        for member in self.tuple_fields:
            collected.update(member.data_type.imports)
        return collected
class RequestParameter(BaseModel):
    """OpenAPI request parameter in the Rust SDK form"""

    remote_name: str
    local_name: str
    location: str
    data_type: BaseCombinedType | BasePrimitiveType | BaseCompoundType
    description: str | None = None
    is_required: bool = False
    is_flag: bool = False
    resource_link: str | None = None
    setter_name: str | None = None
    setter_type: str | None = None

    @property
    def type_hint(self):
        """Rust type; optional parameters (except sets) use ``Option<>``."""
        hint = self.data_type.type_hint
        if self.is_required or isinstance(self.data_type, BTreeSet):
            return hint
        return f"Option<{hint}>"

    @property
    def lifetimes(self):
        return self.data_type.lifetimes
class TypeManager:
"""Rust type manager
The class is responsible for converting ADT models into types suitable
for Rust.
"""
models: list = []
refs: dict[
model.Reference,
BasePrimitiveType | BaseCombinedType | BaseCompoundType,
] = {}
parameters: dict[str, Type[RequestParameter] | RequestParameter] = {}
#: Base mapping of the primitive data-types
base_primitive_type_mapping: dict[
Type[model.PrimitiveType],
Type[BasePrimitiveType] | Type[BaseCombinedType],
] = {
model.PrimitiveString: String,
model.ConstraintString: String,
model.PrimitiveNumber: Number,
model.ConstraintNumber: Number,
model.ConstraintInteger: Integer,
model.PrimitiveBoolean: Boolean,
model.PrimitiveNull: Null,
model.PrimitiveAny: JsonValue,
}
#: Extension for primitives data-type mapping
primitive_type_mapping: dict[
Type[model.PrimitiveType],
Type[BasePrimitiveType] | Type[BaseCombinedType],
]
#: Extensions of the data-type mapping
data_type_mapping: dict[
Type[model.ADT], Type[BaseCombinedType] | Type[BaseCompoundType]
]
#: Base data-type mapping
base_data_type_mapping: dict[
Type[model.ADT], Type[BaseCombinedType] | Type[BaseCompoundType]
] = {
model.Dictionary: Dictionary,
model.Enum: Enum,
model.Struct: Struct,
model.Array: Array,
model.CommaSeparatedList: CommaSeparatedList,
model.Set: BTreeSet,
}
#: RequestParameter Type class
request_parameter_class: Type[RequestParameter] = RequestParameter
#: Option Type class
option_type_class: Type[Option] | Option = Option
#: StringEnum Type class
string_enum_class: Type[StringEnum] | StringEnum = StringEnum
#: List of the models to be ignored
ignored_models: list[model.Reference] = []
root_name: str | None = "Body"
def __init__(self):
    """Initialize per-instance state and complete the type mappings.

    Entries of ``base_primitive_type_mapping`` / ``base_data_type_mapping``
    are added to ``primitive_type_mapping`` / ``data_type_mapping`` unless
    a (subclass) mapping for the same model type already exists.
    """
    self.models = []
    self.refs = {}
    self.parameters = {}
    # Per-instance list. Otherwise `_simplify_oneof_combinations` would
    # append to the class-level default shared by all instances until
    # `set_models` rebinds it.
    self.ignored_models = []
    # The mappings are only annotated on the base class and defined as
    # class attributes by subclasses. Build per-instance copies so that
    # the base class works on its own and class-level dicts are not
    # mutated. Existing entries keep precedence and their order.
    primitive_mapping = dict(getattr(self, "primitive_type_mapping", {}))
    for k, v in self.base_primitive_type_mapping.items():
        primitive_mapping.setdefault(k, v)
    self.primitive_type_mapping = primitive_mapping
    data_mapping = dict(getattr(self, "data_type_mapping", {}))
    for k, v in self.base_data_type_mapping.items():
        data_mapping.setdefault(k, v)
    self.data_type_mapping = data_mapping
def get_local_attribute_name(self, name: str) -> str:
    """Get localized attribute name

    Dots become underscores. The name is split on ``common.SPLIT_NAME_RE``
    and the lower-cased parts are joined with ``_``. A fixed set of names
    ("type", "self", "enum", "ref", "default") is prefixed with ``_``.
    """
    parts = re.split(common.SPLIT_NAME_RE, name.replace(".", "_"))
    attr_name = "_".join(part.lower() for part in parts)
    if attr_name not in ("type", "self", "enum", "ref", "default"):
        return attr_name
    return "_" + attr_name
def get_remote_attribute_name(self, name: str) -> str:
    """Get remote attribute name

    Returns ``name`` unchanged. Subclasses (e.g. on the client side) may
    override this to use a different remote name than the SDK local name.
    """
    remote_name = name
    return remote_name
def get_model_name(self, model_ref: model.Reference | None) -> str:
    """Get the localized (CamelCase) Rust type name of a model reference.

    ``None`` (or a falsy reference) maps to ``"Request"``. The reference
    name is split on ``common.SPLIT_NAME_RE`` and each part capitalized.
    A name starting with a digit (invalid at the start of a Rust
    identifier) is prefixed with ``x``. An empty result is returned as-is
    instead of failing with IndexError.
    """
    if not model_ref:
        return "Request"
    name = "".join(
        x.capitalize()
        for x in re.split(common.SPLIT_NAME_RE, model_ref.name)
    )
    if name and name[0].isdigit():
        return "x" + name
    return name
def _get_adt_by_reference(self, model_ref):
    """Return the first known model whose ``reference`` equals ``model_ref``.

    :raises RuntimeError: when no such model exists in ``self.models``.
    """
    found = next(
        (candidate for candidate in self.models
         if candidate.reference == model_ref),
        None,
    )
    if found is None:
        raise RuntimeError(f"Cannot find reference {model_ref}")
    return found
def convert_model(
    self, type_model: model.PrimitiveType | model.ADT | model.Reference
) -> BasePrimitiveType | BaseCombinedType | BaseCompoundType:
    """Get local destination type from the ModelType

    Primitive types are mapped through ``primitive_type_mapping`` and are
    not cached. Composite types (ADTs, or references resolved through
    ``self.models``) are converted once and cached in ``self.refs`` under
    their reference. A composite without a reference is stored under a
    synthetic reference named ``self.root_name``.

    :raises RuntimeError: when no Rust mapping exists for the type.
    """
    typ: BasePrimitiveType | BaseCombinedType | BaseCompoundType | None = (
        None
    )
    model_ref: model.Reference | None = None
    if isinstance(type_model, model.Reference):
        model_ref = type_model
        type_model = self._get_adt_by_reference(type_model)
    elif isinstance(type_model, model.ADT):
        # Direct composite type
        model_ref = type_model.reference
    else:
        # Primitive
        xtyp = self.primitive_type_mapping.get(type_model.__class__)
        if not xtyp:
            raise RuntimeError(f"No mapping for {type_model}")
        return xtyp(**type_model.model_dump())

    # Composite/Compound type: reuse a previous conversion when possible
    if model_ref and model_ref in self.refs:
        return self.refs[model_ref]

    if isinstance(type_model, model.Array):
        typ = self._get_array_type(type_model)
    elif isinstance(type_model, model.Struct):
        typ = self._get_struct_type(type_model)
    elif isinstance(type_model, model.OneOfType):
        typ = self._get_one_of_type(type_model)
    elif isinstance(type_model, model.Dictionary):
        typ = self.data_type_mapping[model.Dictionary](
            name=self.get_model_name(type_model.reference),
            value_type=self.convert_model(type_model.value_type),
        )
    elif isinstance(type_model, model.CommaSeparatedList):
        typ = self.data_type_mapping[model.CommaSeparatedList](
            item_type=self.convert_model(type_model.item_type)
        )
    elif isinstance(type_model, model.Set):
        typ = self.data_type_mapping[model.Set](
            item_type=self.convert_model(type_model.item_type)
        )
    elif isinstance(type_model, model.Enum):
        typ = self._convert_enum_model(type_model)

    if not typ:
        raise RuntimeError(
            f"Cannot map model type {type_model.__class__.__name__} to"
            f" Rust type [{type_model}]"
        )
    if not model_ref:
        model_ref = model.Reference(name=self.root_name, type=typ.__class__)
    self.refs[model_ref] = typ
    return typ

def _convert_enum_model(self, type_model: model.Enum):
    """Convert a `model.Enum` into a Rust type, or None when unsupported.

    :raises RuntimeError: for multi-type enums not including a boolean.
    """
    if len(type_model.base_types) > 1:
        if model.PrimitiveBoolean in type_model.base_types:
            # enum literals supporting also bools are most likely
            # bool + string -> just keep bool on the Rust side
            # NOTE(review): this uses the module-level `Boolean` directly,
            # while the single-type boolean case below goes through
            # `primitive_type_mapping`; confirm which is intended.
            return Boolean()
        raise RuntimeError(
            "Rust model does not support multitype enums yet"
            f" {type_model}"
        )
    if len(type_model.base_types) == 1:
        base_type = type_model.base_types[0]
        if base_type is model.ConstraintString:
            return self._convert_string_enum_model(type_model)
        if base_type is model.ConstraintInteger:
            return self.primitive_type_mapping[model.ConstraintInteger]()
        if base_type is model.ConstraintNumber:
            return self.primitive_type_mapping[model.ConstraintNumber]()
        if base_type is model.PrimitiveBoolean:
            return self.primitive_type_mapping[model.PrimitiveBoolean]()
    return None

def _convert_string_enum_model(self, type_model: model.Enum):
    """Build a ``string_enum_class`` instance from string enum literals.

    Literals are grouped case-insensitively into CamelCase variant names
    (prefixed with ``_`` when starting with a digit). Each variant keeps
    the set of original spellings it represents. Variants are created in
    sorted order so the generated code does not depend on string hash
    randomization. Errors are logged and ``None`` is returned.
    """
    variants: dict[str, set[str]] = {}
    try:
        if None in type_model.literals:
            # TODO(gtema): make parent nullable or add "null"
            # as enum value
            # NOTE(review): this mutates the input model (and removes only
            # the first None).
            type_model.literals.remove(None)
        for lit in sorted({x.lower() for x in type_model.literals}):
            val = "".join(
                x.capitalize() for x in re.split(common.SPLIT_NAME_RE, lit)
            )
            if val and val[0].isdigit():
                val = "_" + val
            vals = variants.setdefault(val, set())
            for orig_val in type_model.literals:
                if orig_val.lower() == lit:
                    vals.add(orig_val)
        return self.string_enum_class(
            name=self.get_model_name(type_model.reference),
            variants=variants,
        )
    except Exception:
        logging.exception("Error processing enum: %s", type_model)
        return None
def _get_array_type(self, type_model: model.Array) -> Array:
    """Convert `model.Array` into the Rust array class of ``data_type_mapping``."""
    array_class = self.data_type_mapping[model.Array]
    array_name = self.get_model_name(type_model.reference)
    return array_class(
        name=array_name,
        item_type=self.convert_model(type_model.item_type),
    )
def _get_one_of_type(
    self, type_model: model.OneOfType
) -> BaseCompoundType | BaseCombinedType | BasePrimitiveType:
    """Convert `model.OneOfType` into Rust model

    Null kinds are dropped; instead the result is wrapped into
    ``option_type_class``. Primitive kinds that map to an equal Rust type
    are deduplicated, then ``_simplify_oneof_combinations`` is applied.
    The result is:

    * a ``Vec<T>`` when exactly two kinds remain, ``T`` and an array whose
      item has the same Rust class as ``T``;
    * the single remaining kind when only one is left;
    * otherwise an ``Enum`` (from ``data_type_mapping``) with one variant
      per kind, named after the ADT model or ``F<n>`` for non-ADT kinds.
    """
    # Each entry: {"model": source kind, "local": Rust type, "class": its class}
    kinds: list[dict] = []
    is_nullable: bool = False
    result_data_type = None
    for kind in type_model.kinds:
        if isinstance(kind, model.PrimitiveNull):
            # Remove null from candidates and instead wrap with Option
            is_nullable = True
            continue
        kind_type = self.convert_model(kind)
        is_type_already_present = False
        for processed_kind_type in kinds:
            if (
                isinstance(kind_type, BasePrimitiveType)
                and processed_kind_type["local"] == kind_type
            ):
                logging.debug(
                    "Simplifying oneOf with same mapped type %s [%s]",
                    kind,
                    type_model,
                )
                is_type_already_present = True
                break
        if not is_type_already_present:
            kinds.append(
                {
                    "model": kind,
                    "local": kind_type,
                    "class": kind_type.__class__,
                }
            )
    # Simplify certain oneOf combinations
    self._simplify_oneof_combinations(type_model, kinds)
    if len(kinds) == 2:
        list_type = [
            x["local"]
            for x in kinds
            if x["class"] == self.data_type_mapping[model.Array]
        ]
        if list_type:
            lt: BaseCombinedType = list_type[0]
            # Typ + list[Typ] => Vec<Typ>
            item_type = [
                x["local"]
                for x in kinds
                if x["class"] != self.data_type_mapping[model.Array]
            ][0]
            if item_type.__class__ == lt.item_type.__class__:
                result_data_type = self.data_type_mapping[model.Array](
                    item_type=item_type,
                    description=sanitize_rust_docstrings(
                        type_model.description
                    ),
                )
                # logging.debug("Replacing Typ + list[Typ] with list[Typ]")
    elif len(kinds) == 1:
        result_data_type = kinds[0]["local"]
    if not result_data_type:
        # Fall back to an enum with one variant per remaining kind
        enum_class = self.data_type_mapping[model.Enum]
        result_data_type = enum_class(
            name=self.get_model_name(type_model.reference), kinds={}
        )
        cnt: int = 0
        for kind_data in kinds:
            cnt += 1
            kind_data_type = kind_data["local"]
            kind_description: str | None = None
            if isinstance(kind_data["model"], model.ADT):
                # NOTE(review): get_model_name is annotated to take a
                # model.Reference (it reads `.name`) but receives the ADT
                # itself here; confirm ADTs expose `name` or whether
                # `.reference` was intended.
                kind_name = self.get_model_name(kind_data["model"])
                kind_description = kind_data["model"].description
            else:
                kind_name = f"F{cnt}"
            enum_kind = enum_class._kind_type_class(
                name=kind_name,
                description=sanitize_rust_docstrings(kind_description),
                data_type=kind_data_type,
            )
            result_data_type.kinds[enum_kind.name] = enum_kind
    if is_nullable:
        result_data_type = self.option_type_class(
            item_type=result_data_type
        )
    return result_data_type
def _get_struct_type(self, type_model: model.Struct) -> Struct:
    """Convert model.Struct into Rust `Struct`

    Every model field becomes a ``field_type_class_`` instance. A field
    whose converted type is ``option_type_class`` is marked
    ``is_nullable``; when that Option wraps an ``Array`` the Option is
    dropped and the array used directly. ``additional_fields`` (``True``
    or a schema) sets ``additional_fields_type``.
    """
    struct_class = self.data_type_mapping[model.Struct]
    mod = struct_class(
        name=self.get_model_name(type_model.reference),
        description=sanitize_rust_docstrings(type_model.description),
    )
    field_class = mod.field_type_class_
    for field_name, field in type_model.fields.items():
        is_nullable: bool = False
        field_data_type = self.convert_model(field.data_type)
        if isinstance(field_data_type, self.option_type_class):
            # Unwrap Option into "is_nullable" NOTE: but perhaps
            # Option<Option> is better (not set vs set explicitly to None
            # )
            is_nullable = True
            if isinstance(field_data_type.item_type, Array):
                # Unwrap Option<Option<Vec...>>
                field_data_type = field_data_type.item_type
        f = field_class(
            local_name=self.get_local_attribute_name(field_name),
            remote_name=self.get_remote_attribute_name(field_name),
            description=sanitize_rust_docstrings(field.description),
            data_type=field_data_type,
            is_optional=not field.is_required,
            is_nullable=is_nullable,
        )
        mod.fields[field_name] = f
    if type_model.additional_fields:
        definition = type_model.additional_fields
        # Structure allows additional fields
        if isinstance(definition, bool):
            # Store an *instance* of the "any" type, like every other data
            # type. The class itself only exposes property objects (e.g.
            # `JsonValue.imports`), which broke `Struct.imports`.
            mod.additional_fields_type = self.primitive_type_mapping[
                model.PrimitiveAny
            ]()
        else:
            mod.additional_fields_type = self.convert_model(definition)
    return mod
def _simplify_oneof_combinations(self, type_model, kinds):
    """Simplify certain known oneOf combinations (modifies ``kinds`` in place).

    ``kinds`` is the list of candidate dicts built by ``_get_one_of_type``
    (keys ``model``, ``local``, ``class``). Rules, first match wins:

    * string + number -> string
    * string + integer -> integer when the enum name ends with ``size`` or
      ``count`` or the integer has a minimum/maximum, otherwise string
    * string + boolean -> boolean
    * string + dictionary -> a single JSON value (the original kind models
      are appended to ``self.ignored_models``)
    * only strings -> the first string kind
    * string enum + Option<String> -> the enum, accepting arbitrary values
    """
    kinds_classes = [x["class"] for x in kinds]
    string_klass = self.primitive_type_mapping[model.ConstraintString]
    number_klass = self.primitive_type_mapping[model.ConstraintNumber]
    integer_klass = self.primitive_type_mapping[model.ConstraintInteger]
    boolean_klass = self.primitive_type_mapping[model.PrimitiveBoolean]
    dict_klass = self.data_type_mapping[model.Dictionary]
    option_klass = self.option_type_class
    enum_name = type_model.reference.name if type_model.reference else None
    if string_klass in kinds_classes and number_klass in kinds_classes:
        # oneOf [string, number] => string
        for typ in list(kinds):
            if typ["class"] == number_klass:
                kinds.remove(typ)
    elif string_klass in kinds_classes and integer_klass in kinds_classes:
        # The integer kind is not necessarily a direct ConstraintInteger
        # (e.g. an integer enum also converts to the integer class), so
        # default to None instead of raising StopIteration.
        int_klass = next(
            (
                x
                for x in type_model.kinds
                if isinstance(x, model.ConstraintInteger)
            ),
            None,
        )
        if (
            # XX_size or XX_count is clearly an integer
            (
                enum_name
                and (
                    enum_name.endswith("size")
                    or enum_name.endswith("count")
                )
            )
            # There is certain limit (min/max) - it can be only integer
            or (
                int_klass
                and (
                    int_klass.minimum is not None
                    or int_klass.maximum is not None
                )
            )
        ):
            for typ in list(kinds):
                if typ["class"] == string_klass:
                    kinds.remove(typ)
        else:
            # oneOf [string, integer] => string
            # Reason: compute.server.flavorRef is string or integer. For
            # simplicity keep string
            for typ in list(kinds):
                if typ["class"] == integer_klass:
                    kinds.remove(typ)
    elif string_klass in kinds_classes and boolean_klass in kinds_classes:
        # oneOf [string, boolean] => boolean
        for typ in list(kinds):
            if typ["class"] == string_klass:
                kinds.remove(typ)
    elif string_klass in kinds_classes and dict_klass in kinds_classes:
        # oneOf [string, dummy object] => JsonValue
        # Simple string can be easily represented by JsonValue
        for c in kinds:
            # Discard dict
            self.ignored_models.append(c["model"])
        kinds.clear()
        jsonval_klass = self.primitive_type_mapping[model.PrimitiveAny]
        kinds.append({"local": jsonval_klass(), "class": jsonval_klass})
    elif len(set(kinds_classes)) == 1 and string_klass in kinds_classes:
        # in the output oneOf of same type (but maybe different formats)
        # makes no sense
        # Example is server addresses which are ipv4 or ipv6
        bck = kinds[0].copy()
        kinds.clear()
        kinds.append(bck)
    elif (
        self.string_enum_class in kinds_classes
        and option_klass in kinds_classes
    ):
        option = next(x for x in kinds if isinstance(x["local"], Option))
        enum = next(x for x in kinds if isinstance(x["local"], StringEnum))
        if option and isinstance(option["local"].item_type, String):
            # Enum + free-form string => enum accepting arbitrary values
            enum["local"].allows_arbitrary_value = True
            kinds.remove(option)
def set_models(self, models):
    """Process (translate) ADT models into Rust models

    Every model is converted with ``convert_model`` (results land in
    ``self.refs``). A compound type whose name clashes with an
    already-seen model of a different shape (different reference hash),
    or with ``self.root_name``, is renamed:

    * by prefixing the parent reference name, otherwise by suffixing the
      Rust class name;
    * for structs, as a last resort, by appending their field names.

    Finally every model in ``self.ignored_models`` is discarded.

    :raises RuntimeError: when no unique name can be derived.
    """
    self.models = models
    self.refs = {}
    self.ignored_models = []
    # A dictionary of model names to references to assign unique names
    unique_models: dict[str, model.Reference] = {}
    # iterate over all incoming models
    for model_ in models:
        # convert ADT based model into rust saving the result under self.refs
        model_data_type = self.convert_model(model_)
        # post process conversion results
        if not isinstance(model_data_type, BaseCompoundType):
            continue
        name = getattr(model_data_type, "name", None)
        if (
            name
            and model_.reference
            and (
                (
                    name in unique_models
                    and (
                        unique_models[name].hash_ != model_.reference.hash_
                        # maybe we previously decided to rename it
                        or self.refs[unique_models[name]].name != name
                    )
                )
                or name == self.root_name
            )
        ):
            # There is already a model with this name (i.e.
            # SessionPersistence.type vs HealthMonitor.type).
            if model_.reference and model_.reference.parent:
                # Try adding parent_name as prefix
                new_name = (
                    "".join(
                        x.title()
                        for x in model_.reference.parent.name.split("_")
                    )
                    + name
                )
            else:
                # Try adding suffix from datatype name
                new_name = name + model_data_type.__class__.__name__
            logging.debug(f"try renaming {name} to {new_name}")
            if new_name not in unique_models:
                # New name is still unused
                model_data_type.name = new_name
                self.refs[model_.reference].name = new_name
                unique_models[new_name] = model_.reference
                # rename original model to the same naming scheme
                other_model = unique_models.get(name)
                # Rename all other already processed models with name type
                # and hash matching current model or the other model
                for ref, some_model in self.refs.items():
                    if (
                        hasattr(some_model, "name")
                        and some_model.name == name
                    ):
                        if (
                            ref.type == model_.reference.type
                            and ref.parent
                            and ref.parent.name
                            and (
                                (
                                    other_model
                                    and ref.hash_ == other_model.hash_
                                )
                                or ref.hash_ == model_.reference.hash_
                            )
                        ):
                            new_other_name = (
                                "".join(
                                    x.title()
                                    for x in ref.parent.name.split("_")
                                )
                                + name
                            )
                            some_model.name = new_other_name
                            # Store the *reference*: values of unique_models
                            # are read via `.hash_` and used as keys of
                            # self.refs (the Rust type supports neither).
                            unique_models[new_other_name] = ref
                            logging.debug(
                                f"Renaming also {some_model} into"
                                f" {new_other_name} for consistency"
                            )
            else:
                if model_.reference.hash_ == unique_models[new_name].hash_:
                    if name != self.refs[unique_models[name]].name:
                        logging.debug(
                            "Found that same model"
                            f" {model_.reference} that we previously"
                            " renamed to"
                            f" {self.refs[unique_models[name]].name}"
                        )
                    # not sure whether the new name should be save
                    # somewhere to be properly used in cli
                    # self.ignored_models.append(model_.reference)
                elif isinstance(model_data_type, Struct):
                    # This is already an exceptional case (identity.mapping
                    # with remote being oneOf with multiple structs)
                    # Try to make a name consisting of props
                    props = model_data_type.fields.keys()
                    new_new_name = name + "".join(
                        x.title() for x in props
                    ).replace("_", "")
                    if new_new_name not in unique_models:
                        for other_ref, other_model in self.refs.items():
                            other_name = getattr(other_model, "name", None)
                            if not other_name:
                                continue
                            if other_name in [
                                name,
                                new_name,
                            ] and isinstance(other_model, Struct):
                                # rename first occurence to the same scheme
                                props = other_model.fields.keys()
                                new_other_name = name + "".join(
                                    x.title() for x in props
                                ).replace("_", "")
                                other_model.name = new_other_name
                                # NOTE(review): this records the *current*
                                # model's reference for the other model's
                                # new name; `other_ref` may be intended.
                                unique_models[new_other_name] = (
                                    model_.reference
                                )
                        model_data_type.name = new_new_name
                        unique_models[new_new_name] = model_.reference
                    else:
                        raise RuntimeError(
                            f"Model name {new_new_name} is already present"
                        )
                else:
                    raise RuntimeError(
                        f"Model name {new_name} is already present as"
                        f" {type(model_data_type)}"
                    )
        elif (
            name
            and name in unique_models
            and model_.reference
            and unique_models[name].hash_ == model_.reference.hash_
            # image.metadef.namespace have weird occurences of itself
            and model_.reference != unique_models[name]
        ):
            # We already have literally same model. Do nothing expecting
            # that filtering in the `get_subtypes` will do the rest.
            pass
        elif name:
            unique_models[name] = model_.reference
    for ignore_model in self.ignored_models:
        self.discard_model(ignore_model)
def get_subtypes(self):
    """Yield all named subtypes, excluding the top-level (TLA) type.

    Enums, structs and string enums registered in ``self.refs`` are
    yielded directly. For option-wrapped values the inner ``Enum`` is
    yielded instead. Deep nesting can register literally the same model
    more than once, so each candidate is fingerprinted and yielded only
    the first time it is seen.
    """
    seen: set[str] = set()
    for ref, data_type in self.refs.items():
        if not ref:
            # Unreferenced entries are never subtypes.
            continue
        if isinstance(data_type, (Enum, Struct, StringEnum)):
            if ref.name == self.root_name:
                continue
            fingerprint = (
                f"{ref.type}:{getattr(data_type, 'name', '')}:{ref.hash_}"
            )
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            yield data_type
        elif ref.name != self.root_name and isinstance(
            data_type, self.option_type_class
        ):
            inner = data_type.item_type
            if not isinstance(inner, Enum):
                continue
            fingerprint = (
                f"{inner}:{getattr(data_type, 'name', '')}:{ref.hash_}"
            )
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            yield inner
def get_root_data_type(self):
    """Get the top-level (TLA) data type of the operation.

    A root is either a model registered without a reference (``None``
    key) or a model whose reference name equals ``self.root_name``.

    * A root ``Struct`` is returned. If it has named fields and exactly one
      of them is optional, that field is forced to be required, because a
      body consisting of a single optional field makes no sense.
    * A referenced root that is a free-style ``Dictionary`` is returned
      as is. An unreferenced root of any other type is returned as is too.
    * When no root is registered, a dummy empty ``Request`` struct is
      returned.
    """
    for ref, data_type in self.refs.items():
        if ref and ref.name != self.root_name:
            continue
        if isinstance(data_type, Struct):
            # Tuple structs keep their fields as a list; only named
            # (dict) fields can be normalized here.
            if (
                isinstance(data_type.fields, dict)
                and len(data_type.fields) == 1
            ):
                (only_field,) = data_type.fields.values()
                if only_field.is_optional:
                    # A body with only one field can not normally be
                    # optional
                    logging.warning(
                        "Request body with single root field cannot be"
                        " optional"
                    )
                    only_field.is_optional = False
            return data_type
        # Checking ``isinstance`` before touching ``.fields`` avoids the
        # AttributeError a non-struct unreferenced root (e.g. a
        # free-style Dictionary response) used to raise.
        if not ref or isinstance(data_type, Dictionary):
            return data_type
    # No root has been found, make a dummy one
    return self.data_type_mapping[model.Struct](name="Request")
def get_imports(self):
    """Return the union of extra imports required by all models in scope.

    The sources are taken in this order: the root data type, every
    subtype from ``get_subtypes`` and the data type of every registered
    operation parameter.
    """
    providers = [self.get_root_data_type()]
    providers.extend(self.get_subtypes())
    providers.extend(
        parameter.data_type for parameter in self.parameters.values()
    )
    return set().union(*(provider.imports for provider in providers))
def get_request_static_lifetimes(self, request_model: "Struct"):
    """Return the generic lifetime list for the request structure.

    The request model's lifetimes are merged with the lifetimes of every
    parameter that is not located in a header. The result is rendered as
    ``<'a, 'b>``, sorted so that generated code is deterministic. An empty
    string is returned when no lifetimes are needed.

    The model's lifetime collection is copied before merging, so the
    request model itself is never mutated by this call.
    """
    lifetimes: set[str] = set(request_model.lifetimes or ())
    for param in self.parameters.values():
        if param.location == "header":
            continue
        param_lifetimes = param.lifetimes
        if param_lifetimes:
            lifetimes.update(param_lifetimes)
    if not lifetimes:
        return ""
    return f"<{', '.join(sorted(lifetimes))}>"
def subtype_requires_private_builders(self, subtype) -> bool:
    """Tell whether ``subtype`` needs private builder methods.

    Only instances of the struct class mapped for ``model.Struct``
    qualify. Such a struct qualifies when any of its fields carries a
    ``private`` builder macro, or when it is a ``Struct`` with an
    ``additional_fields_type`` set.
    """
    struct_class = self.data_type_mapping[model.Struct]
    if not isinstance(subtype, struct_class):
        return False
    if any(
        "private" in member.builder_macros
        for member in subtype.fields.values()
    ):
        return True
    return bool(
        isinstance(subtype, Struct) and subtype.additional_fields_type
    )
def set_parameters(self, parameters: list[model.RequestParameter]) -> None:
    """Convert OpenAPI operation parameters and register them.

    For each parameter, its data type is converted with
    ``self.convert_model`` and wrapped into
    ``self.request_parameter_class``. The result is stored in
    ``self.parameters`` under its local name.

    :raises RuntimeError: if two parameters resolve to the same local
        name.
    """
    for source in parameters:
        # Convert the data type first, as before building the wrapper.
        converted_type = self.convert_model(source.data_type)
        wrapped = self.request_parameter_class(
            remote_name=self.get_remote_attribute_name(source.name),
            local_name=self.get_local_attribute_name(source.name),
            data_type=converted_type,
            location=source.location,
            description=sanitize_rust_docstrings(source.description),
            is_required=source.is_required,
            is_flag=source.is_flag,
            resource_link=source.resource_link,
        )
        key = wrapped.local_name
        if key in self.parameters:
            raise RuntimeError(
                f"Parameter with the name {key} is already"
                " present"
            )
        self.parameters[key] = wrapped
def get_parameters(
    self, location: str
) -> Generator[Tuple[str, "RequestParameter"], None, None]:
    """Yield ``(local_name, parameter)`` pairs for one location.

    Parameters are yielded in the iteration order of ``self.parameters``,
    keeping only those whose ``location`` equals ``location`` (e.g.
    ``"header"``). The yielded values are parameter instances, not
    parameter classes.
    """
    for local_name, param in self.parameters.items():
        if param.location == location:
            yield (local_name, param)
def discard_model(
    self, type_model: model.PrimitiveType | model.ADT | model.Reference
):
    """Discard a model (and what it references) from the manager.

    A ``model.Reference`` is first resolved with
    ``self._get_adt_by_reference``. Objects without a ``reference``
    attribute are ignored. If the model's reference is registered in
    ``self.refs``, the models it points to are discarded recursively
    before its own entry is removed:

    * struct: the data types of all its fields;
    * oneOf: all its kinds;
    * array: its item type.

    NOTE(review): the entry is popped only after its children are
    processed. A model that (directly or indirectly) references itself
    would presumably recurse without bound. Confirm whether recursive
    schemas can reach this path.
    """
    logging.debug(f"Request to discard {type_model}")
    if isinstance(type_model, model.Reference):
        type_model = self._get_adt_by_reference(type_model)
    if not hasattr(type_model, "reference"):
        # Nothing registered under a reference -> nothing to discard
        return
    # Iterate over a snapshot, since recursive calls pop from self.refs
    for ref, data in list(self.refs.items()):
        if ref == type_model.reference:
            sub_ref: model.Reference | None = None
            if ref.type == model.Struct:
                logging.debug(
                    "Element is a struct. Purging also field types"
                )
                # For struct types, also discard all field types
                # (cascading)
                for v in type_model.fields.values():
                    # A field type is either a bare reference or an ADT
                    # carrying its own reference
                    if isinstance(v.data_type, model.Reference):
                        sub_ref = v.data_type
                    else:
                        sub_ref = getattr(v.data_type, "reference", None)
                    if sub_ref:
                        logging.debug(f"Need to purge also {sub_ref}")
                        self.discard_model(sub_ref)
            elif ref.type == model.OneOfType:
                logging.debug(
                    "Element is a OneOf. Purging also kinds types"
                )
                for v in type_model.kinds:
                    if isinstance(v, model.Reference):
                        sub_ref = v
                    else:
                        sub_ref = getattr(v, "reference", None)
                    if sub_ref:
                        logging.debug(f"Need to purge also {sub_ref}")
                        self.discard_model(sub_ref)
            elif ref.type == model.Array:
                logging.debug(
                    "Element is a Array. Purging also item type"
                    f" {type_model.item_type}"
                )
                if isinstance(type_model.item_type, model.Reference):
                    sub_ref = type_model.item_type
                else:
                    sub_ref = getattr(
                        type_model.item_type, "reference", None
                    )
                if sub_ref:
                    logging.debug(f"Need to purge also {sub_ref}")
                    self.discard_model(sub_ref)
            # Finally drop the model's own entry
            logging.debug(f"Purging {ref} from models")
            self.refs.pop(ref, None)
def is_operation_supporting_params(self) -> bool:
    """Report whether the operation accepts any kind of input.

    Any registered parameter makes the answer ``True``. Otherwise the
    root data type decides. A missing root, or a root ``Struct`` with
    neither fields nor an ``additional_fields_type``, means the operation
    takes nothing.
    """
    if self.parameters:
        return True
    root = self.get_root_data_type()
    if not root:
        return False
    is_empty_struct = (
        isinstance(root, Struct)
        and not root.fields
        and not root.additional_fields_type
    )
    return not is_empty_struct
def sanitize_rust_docstrings(doc: str | list[str] | None) -> str | None:
    """Sanitize the text so it is a valid Rust docstring.

    ``doc`` may be a single string or a list of strings. Every element is
    split into lines. Rustdoc treats a fenced code block without a
    language as Rust code, so every *opening* fence that lacks a language
    tag gets ``text`` appended. Opening and closing fences are told apart
    by tracking whether a block is currently open.

    Trailing whitespace (including the carriage return left over from
    CRLF line endings) is ignored when detecting fences. Otherwise an
    opening fence such as "``` " would go unnoticed, and the matching
    closing fence would be mistaken for an opening one.

    :returns: the sanitized text, or ``None`` when ``doc`` is empty.
    """
    if not doc:
        return None
    doc_lines: list[str] = []
    if isinstance(doc, list):
        for chunk in doc:
            doc_lines.extend(chunk.split("\n"))
    else:
        doc_lines = doc.split("\n")

    code_block_open: bool = False
    lines: list[str] = []
    for line in doc_lines:
        stripped = line.rstrip()
        m = CODEBLOCK_RE.search(stripped)
        if m:
            if code_block_open:
                code_block_open = False
            else:
                code_block_open = True
                # Rustdoc defaults to rust code for code blocks. To
                # prevent this explicitly add `text`
                if not m.group(1):
                    line = stripped + "text"
        lines.append(line)
    return "\n".join(lines)