Generate TUI response structure

Change-Id: I2b9f37ebb5e08e36b969aee148bc571ed587c7ae
This commit is contained in:
Artem Goncharov
2025-01-12 16:44:34 +01:00
parent 1c3775d15c
commit 8800529ad4
2 changed files with 277 additions and 33 deletions

View File

@@ -20,11 +20,24 @@ from codegenerator.base import BaseGenerator
from codegenerator import common
from codegenerator import model
from codegenerator.common import BaseCompoundType
from codegenerator.common import BaseCombinedType
from codegenerator.common import BasePrimitiveType
from codegenerator.common import rust as common_rust
from codegenerator.rust_sdk import TypeManager as SdkTypeManager
from codegenerator import rust_sdk
BASIC_FIELDS = [
"name",
"title",
"created_at",
"updated_at",
"state",
"status",
"operating_status",
]
class String(common_rust.String):
type_hint: str = "String"
@@ -172,6 +185,83 @@ class Struct(rust_sdk.Struct):
return result
class StructFieldResponse(common_rust.StructField):
"""Response Structure Field"""
@property
def type_hint(self):
typ_hint = self.data_type.type_hint
if self.is_optional and not typ_hint.startswith("Option<"):
typ_hint = f"Option<{typ_hint}>"
return typ_hint
@property
def serde_macros(self):
macros = set()
if self.local_name != self.remote_name:
macros.add(f'rename="{self.remote_name}"')
if self.is_optional or self.data_type.type_hint.startswith("Option<"):
macros.add("default")
return f"#[serde({', '.join(sorted(macros))})]"
def get_structable_macros(
self,
struct: "StructResponse",
service_name: str,
resource_name: str,
operation_type: str,
):
macros = set()
if self.is_optional or self.data_type.type_hint.startswith("Option<"):
macros.add("optional")
macros.add(f'title="{self.remote_name.upper()}"')
# Fully Qualified Attribute Name
fqan: str = ".".join(
[service_name, resource_name, self.remote_name]
).lower()
# Check the known alias of the field by FQAN
alias = common.FQAN_ALIAS_MAP.get(fqan)
if operation_type in ["list", "list_from_struct"]:
if (
"id" in struct.fields.keys()
and not (
self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS
)
) or (
"id" not in struct.fields.keys()
and (self.local_name not in list(struct.fields.keys())[-10:])
and not (
self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS
)
):
# Only add "wide" flag if field is not in the basic fields AND
# there is at least "id" field existing in the struct OR the
# field is not in the first 10
macros.add("wide")
if (
self.local_name == "state"
and "status" not in struct.fields.keys()
):
macros.add("status")
elif (
self.local_name == "operating_status"
and "status" not in struct.fields.keys()
):
macros.add("status")
return f"#[structable({', '.join(sorted(macros))})]"
class StructResponse(common_rust.Struct):
field_type_class_: Type[common_rust.StructField] = StructFieldResponse
@property
def imports(self):
imports: set[str] = {"serde::Deserialize"}
for field in self.fields.values():
imports.update(field.data_type.imports)
return imports
class TypeManager(common_rust.TypeManager):
"""Rust SDK type manager
@@ -241,6 +331,112 @@ class TypeManager(common_rust.TypeManager):
yield (v.item_type, sdk_type)
class ResponseTypeManager(common_rust.TypeManager):
primitive_type_mapping: dict[
Type[model.PrimitiveType], Type[BasePrimitiveType]
] = {
model.PrimitiveString: common_rust.String,
model.ConstraintString: common_rust.String,
}
data_type_mapping = {
model.Struct: StructResponse,
model.Array: common_rust.JsonValue,
model.Dictionary: common_rust.JsonValue,
}
def get_model_name(self, model_ref: model.Reference | None) -> str:
"""Get the localized model type name
In order to avoid collision between structures in request and
response we prefix all types with `Response`
:returns str: Type name
"""
if not model_ref:
return "Response"
return "Response" + "".join(
x.capitalize()
for x in re.split(common.SPLIT_NAME_RE, model_ref.name)
)
def _get_struct_type(self, type_model: model.Struct) -> common_rust.Struct:
"""Convert model.Struct into Rust `Struct`"""
struct_class = self.data_type_mapping[model.Struct]
mod = struct_class(
name=self.get_model_name(type_model.reference),
description=common_rust.sanitize_rust_docstrings(
type_model.description
),
)
field_class = mod.field_type_class_
for field_name, field in type_model.fields.items():
is_nullable: bool = False
field_data_type = self.convert_model(field.data_type)
if isinstance(field_data_type, self.option_type_class):
# Unwrap Option into "is_nullable" NOTE: but perhaps
# Option<Option> is better (not set vs set explicitly to None
# )
is_nullable = True
if isinstance(field_data_type.item_type, common_rust.Array):
# Unwrap Option<Option<Vec...>>
field_data_type = field_data_type.item_type
elif not isinstance(field_data_type, BasePrimitiveType):
field_data_type = common_rust.JsonValue(
**field_data_type.model_dump()
)
self.ignored_models.append(field.data_type)
f = field_class(
local_name=self.get_local_attribute_name(field_name),
remote_name=self.get_remote_attribute_name(field_name),
description=common_rust.sanitize_rust_docstrings(
field.description
),
data_type=field_data_type,
is_optional=not field.is_required,
is_nullable=is_nullable,
)
mod.fields[field_name] = f
if type_model.additional_fields:
definition = type_model.additional_fields
# Structure allows additional fields
if isinstance(definition, bool):
mod.additional_fields_type = self.primitive_type_mapping[
model.PrimitiveAny
]
else:
mod.additional_fields_type = self.convert_model(definition)
return mod
def get_subtypes(self):
"""Get all subtypes excluding TLA"""
emited_data: set[str] = set()
for k, v in self.refs.items():
if (
k
and isinstance(
v,
(
common_rust.Enum,
common_rust.Struct,
common_rust.StringEnum,
common_rust.Dictionary,
common_rust.Array,
),
)
and k.name != "Body"
):
key = v.base_type + v.type_hint
if key not in emited_data:
emited_data.add(key)
yield v
def get_imports(self):
"""Get complete set of additional imports required by all models in scope"""
imports: set[str] = super().get_imports()
imports.discard("crate::common::parse_json")
return imports
class RustTuiGenerator(BaseGenerator):
def __init__(self):
super().__init__()
@@ -278,9 +474,9 @@ class RustTuiGenerator(BaseGenerator):
openapi_spec, operation_id
)
# srv_name, res_name = res.split(".") if res else (None, None)
# srv_name, resource_name = res.split(".") if res else (None, None)
path_resources = common.get_resource_names_from_url(path)
res_name = path_resources[-1]
resource_name = path_resources[-1]
mime_type = None
openapi_parser = model.OpenAPISchemaParser()
@@ -298,8 +494,8 @@ class RustTuiGenerator(BaseGenerator):
# Respect path params that appear in path and not path params
param_ = openapi_parser.parse_parameter(param)
if param_.name in [
f"{res_name}_id",
f"{res_name.replace('_', '')}_id",
f"{resource_name}_id",
f"{resource_name.replace('_', '')}_id",
]:
path = path.replace(param_.name, "id")
# for i.e. routers/{router_id} we want local_name to be `id` and not `router_id`
@@ -330,6 +526,7 @@ class RustTuiGenerator(BaseGenerator):
# TODO(gtema): previously we were ensuring `router_id` path param
# is renamed to `id`
additional_imports = set()
if api_ver_matches:
api_ver = {
"major": api_ver_matches.group(1),
@@ -353,10 +550,17 @@ class RustTuiGenerator(BaseGenerator):
class_name = f"{service_name}{''.join(x.title() for x in path_resources)}{operation_name}".replace(
"_", ""
)
response_class_name = f"{service_name}{''.join(x.title() for x in path_resources)}".replace(
"_", ""
)
operation_body = operation_variant.get("body")
type_manager = TypeManager()
sdk_type_manager = SdkTypeManager()
type_manager.set_parameters(operation_params)
response_type_manager: common_rust.TypeManager = (
ResponseTypeManager()
)
sdk_type_manager.set_parameters(operation_params)
mod_name = "_".join(
x.lower()
@@ -411,34 +615,49 @@ class RustTuiGenerator(BaseGenerator):
)
response_key: str | None = None
if args.response_key:
response_key = (
args.response_key if args.response_key != "null" else None
result_def: dict = {}
response_def: dict | None = {}
resource_header_metadata: dict = {}
# Get basic information about response
if args.operation_type == "list":
response = common.find_response_schema(
spec["responses"],
args.response_key or resource_name,
(
args.operation_name
if args.operation_type == "action"
else None
),
)
else:
# Get basic information about response
if method.upper() != "HEAD":
for code, rspec in spec["responses"].items():
if not code.startswith("2"):
continue
content = rspec.get("content", {})
if "application/json" in content:
response_spec = content["application/json"]
try:
(_, response_key) = (
common.find_resource_schema(
response_spec["schema"],
None,
res_name.lower(),
)
)
except Exception:
# Most likely we have response which is oneOf.
# For the SDK it does not really harm to ignore
# this.
pass
# response_def = (None,)
response_key = None
if response:
if args.response_key:
response_key = (
args.response_key
if args.response_key != "null"
else None
)
else:
response_key = resource_name
response_def, _ = common.find_resource_schema(
response, None, response_key
)
if response_def:
if response_def.get("type", "object") == "object" or (
# BS metadata is defined with type: ["object",
# "null"]
isinstance(response_def.get("type"), list)
and "object" in response_def["type"]
):
(root, response_types) = openapi_parser.parse(
response_def
)
response_type_manager.set_models(response_types)
additional_imports.add("serde_json::Value")
sdk_mod_path_base = [
"openstack_sdk",
"api",
@@ -449,13 +668,15 @@ class RustTuiGenerator(BaseGenerator):
mod_suffix: str = ""
sdk_mod_path.append((args.sdk_mod_name or mod_name) + mod_suffix)
additional_imports = set()
additional_imports.add(
"::".join(sdk_mod_path) + "::RequestBuilder"
)
additional_imports.add(
"openstack_sdk::{AsyncOpenStack, api::QueryAsync}"
)
additional_imports.add("structable_derive::StructTable")
additional_imports.add("crate::utils::StructTable")
additional_imports.add("crate::utils::OutputConfig")
if args.operation_type == "list":
if "limit" in [
k for (k, _) in type_manager.get_parameters("query")
@@ -480,8 +701,10 @@ class RustTuiGenerator(BaseGenerator):
common.make_ascii_string(spec.get("description"))
),
"class_name": class_name,
"response_class_name": response_class_name,
"sdk_service_name": service_name,
"resource_name": res_name,
"resource_name": resource_name,
"response_type_manager": response_type_manager,
"url": path.lstrip("/").lstrip(ver_prefix).lstrip("/"),
"method": method,
"type_manager": type_manager,

View File

@@ -215,3 +215,24 @@ impl ConfirmableRequest for {{ class_name }} {
{%- endif %}
{%- endwith %}
{%- with data_type = response_type_manager.get_root_data_type() %}
{%- if data_type.__class__.__name__ == "StructResponse" %}
{%- if data_type.fields %}
/// {{ response_class_name }} response representation
#[derive(Deserialize, Serialize)]
#[derive(Clone, StructTable)]
struct {{ response_class_name }} {
{%- for k, v in data_type.fields | dictsort %}
{% if not (operation_type == "list" and k in ["links"]) %}
{{ macros.docstring(v.description, indent=4) }}
{{ v.serde_macros }}
{{ v.get_structable_macros(data_type, sdk_service_name, resource_name, operation_type) }}
{{ v.local_name }}: {{ v.type_hint }},
{%- endif %}
{%- endfor %}
}
{%- endif %}
{%- endif %}
{%- endwith %}