Skip to content

Commit 63ba2e4

Browse files
authored
fix: Update API spec to implement core fixes, fix a edgecase. (#4)
* fix: Make a edgecase much nicer in the API spec. * fix: Update the API data based on the specification. * chore: Bump the types.
1 parent ff28376 commit 63ba2e4

File tree

5 files changed

+217
-30
lines changed

5 files changed

+217
-30
lines changed

autogen.py

Lines changed: 150 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
),
3232
]
3333

34+
# Endpoints that return bool based on HTTP status: 404 -> False, 2xx -> True, else raise.
35+
# Each entry is (METHOD, openapi_path_template).
36+
BOOLEAN_STATUS_ROUTES: list[tuple[str, str]] = [
37+
("GET", "/virtual_tag_configs/async/{request_id}"),
38+
]
39+
3440

3541
@dataclass
3642
class Parameter:
@@ -61,6 +67,7 @@ class Endpoint:
6167
is_multipart: bool = False
6268
response_handler: str | None = None # internal client method to call, if not the default
6369
response_handler_return_type: str | None = None
70+
boolean_status: bool = False # 404->False, 2xx->True, else raise VantageAPIError
6471

6572

6673
@dataclass
@@ -196,10 +203,16 @@ def preprocess_inline_models(schemas: dict[str, Any]) -> None:
196203
existing_names.add(model_name)
197204

198205

199-
def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> str:
206+
def openapi_type_to_python(
207+
schema: dict[str, Any],
208+
schemas: dict[str, Any],
209+
name_map: dict[str, str] | None = None,
210+
) -> str:
200211
"""Convert OpenAPI type to Python type hint."""
201212
if "$ref" in schema:
202213
ref_name = schema["$ref"].split("/")[-1]
214+
if name_map and ref_name in name_map:
215+
return name_map[ref_name]
203216
return to_pascal_case(ref_name)
204217

205218
schema_type = schema.get("type", "any")
@@ -216,12 +229,12 @@ def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> s
216229
return "bool"
217230
elif schema_type == "array":
218231
items = schema.get("items", {})
219-
item_type = openapi_type_to_python(items, schemas)
232+
item_type = openapi_type_to_python(items, schemas, name_map)
220233
return f"List[{item_type}]"
221234
elif schema_type == "object":
222235
additional = schema.get("additionalProperties")
223236
if additional:
224-
value_type = openapi_type_to_python(additional, schemas)
237+
value_type = openapi_type_to_python(additional, schemas, name_map)
225238
return f"Dict[str, {value_type}]"
226239
# Check if inline properties match an existing named schema
227240
inline_props = schema.get("properties")
@@ -230,14 +243,18 @@ def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> s
230243
for schema_name, schema_def in schemas.items():
231244
defined_keys = sorted(schema_def.get("properties", {}).keys())
232245
if defined_keys and inline_keys == defined_keys:
246+
if name_map and schema_name in name_map:
247+
return name_map[schema_name]
233248
return to_pascal_case(schema_name)
234249
return "Dict[str, Any]"
235250
else:
236251
return "Any"
237252

238253

239254
def extract_request_body_type(
240-
request_body: dict[str, Any] | None, schemas: dict[str, Any]
255+
request_body: dict[str, Any] | None,
256+
schemas: dict[str, Any],
257+
name_map: dict[str, str] | None = None,
241258
) -> tuple[str | None, bool]:
242259
"""Extract request body type and whether it's multipart."""
243260
if not request_body:
@@ -252,13 +269,15 @@ def extract_request_body_type(
252269
# Check for JSON
253270
if "application/json" in content:
254271
schema = content["application/json"].get("schema", {})
255-
return openapi_type_to_python(schema, schemas), False
272+
return openapi_type_to_python(schema, schemas, name_map), False
256273

257274
return None, False
258275

259276

260277
def extract_response_type(
261-
responses: dict[str, Any], schemas: dict[str, Any]
278+
responses: dict[str, Any],
279+
schemas: dict[str, Any],
280+
name_map: dict[str, str] | None = None,
262281
) -> str | None:
263282
"""Extract successful response type."""
264283
for code in ["200", "201", "202", "203"]:
@@ -268,15 +287,59 @@ def extract_response_type(
268287
content = response.get("content", {})
269288
if "application/json" in content:
270289
schema = content["application/json"].get("schema", {})
271-
return openapi_type_to_python(schema, schemas)
290+
return openapi_type_to_python(schema, schemas, name_map)
272291
return None
273292

274293

294+
def find_request_body_schemas(schema: dict[str, Any]) -> set[str]:
295+
"""Return the set of schema names referenced as request bodies in any endpoint."""
296+
result = set()
297+
paths = schema.get("paths", {})
298+
for path_item in paths.values():
299+
for method, spec in path_item.items():
300+
if method in ("parameters", "servers", "summary", "description"):
301+
continue
302+
request_body = spec.get("requestBody", {})
303+
content = request_body.get("content", {})
304+
for media_type in content.values():
305+
ref_schema = media_type.get("schema", {})
306+
if "$ref" in ref_schema:
307+
name = ref_schema["$ref"].split("/")[-1]
308+
result.add(name)
309+
return result
310+
311+
312+
def build_class_name_map(schemas: dict[str, Any], request_body_schemas: set[str]) -> dict[str, str]:
313+
"""Build a mapping from raw schema names to Python class names, resolving conflicts.
314+
315+
If two schema names map to the same PascalCase name, the one used as a
316+
request body is suffixed with 'Request'.
317+
"""
318+
initial = {name: to_pascal_case(name) for name in schemas}
319+
320+
by_class_name: dict[str, list[str]] = {}
321+
for raw_name, class_name in initial.items():
322+
by_class_name.setdefault(class_name, []).append(raw_name)
323+
324+
result: dict[str, str] = {}
325+
for class_name, raw_names in by_class_name.items():
326+
if len(raw_names) == 1:
327+
result[raw_names[0]] = class_name
328+
else:
329+
for raw_name in raw_names:
330+
if raw_name in request_body_schemas:
331+
result[raw_name] = class_name + "Request"
332+
else:
333+
result[raw_name] = class_name
334+
return result
335+
336+
275337
def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
276338
"""Parse all endpoints from OpenAPI schema."""
277339
endpoints = []
278340
paths = schema.get("paths", {})
279341
schemas = schema.get("components", {}).get("schemas", {})
342+
name_map = build_class_name_map(schemas, find_request_body_schemas(schema))
280343

281344
for path, methods in paths.items():
282345
for method, spec in methods.items():
@@ -288,7 +351,7 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
288351
parameters = []
289352
for param in spec.get("parameters", []):
290353
param_schema = param.get("schema", {})
291-
param_type = openapi_type_to_python(param_schema, schemas)
354+
param_type = openapi_type_to_python(param_schema, schemas, name_map)
292355
parameters.append(
293356
Parameter(
294357
name=param["name"],
@@ -301,9 +364,9 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
301364
)
302365

303366
request_body = spec.get("requestBody")
304-
body_type, is_multipart = extract_request_body_type(request_body, schemas)
367+
body_type, is_multipart = extract_request_body_type(request_body, schemas, name_map)
305368

306-
response_type = extract_response_type(spec.get("responses", {}), schemas)
369+
response_type = extract_response_type(spec.get("responses", {}), schemas, name_map)
307370

308371
description = spec.get("description")
309372

@@ -319,6 +382,10 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
319382
if response_handler:
320383
break
321384

385+
boolean_status = (method.upper(), path) in {
386+
(m.upper(), p) for m, p in BOOLEAN_STATUS_ROUTES
387+
}
388+
322389
endpoints.append(
323390
Endpoint(
324391
path=path,
@@ -336,6 +403,7 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
336403
is_multipart=is_multipart,
337404
response_handler=response_handler,
338405
response_handler_return_type=response_handler_return_type,
406+
boolean_status=boolean_status,
339407
)
340408
)
341409

@@ -469,6 +537,7 @@ def _append_response_mapping(lines: list[str], return_type: str, data_var: str)
469537
def generate_pydantic_models(schema: dict[str, Any]) -> str:
470538
"""Generate Pydantic models from OpenAPI schemas."""
471539
schemas = schema.get("components", {}).get("schemas", {})
540+
name_map = build_class_name_map(schemas, find_request_body_schemas(schema))
472541
lines = [
473542
'"""Auto-generated Pydantic models from OpenAPI schema."""',
474543
"",
@@ -482,7 +551,7 @@ def generate_pydantic_models(schema: dict[str, Any]) -> str:
482551
]
483552

484553
for name, spec in schemas.items():
485-
class_name = to_pascal_case(name)
554+
class_name = name_map.get(name, to_pascal_case(name))
486555
description = spec.get("description", "")
487556

488557
lines.append(f"class {class_name}(BaseModel):")
@@ -508,7 +577,7 @@ def generate_pydantic_models(schema: dict[str, Any]) -> str:
508577
python_name = python_name + "_"
509578
needs_alias = True
510579

511-
prop_type = openapi_type_to_python(prop_spec, schemas)
580+
prop_type = openapi_type_to_python(prop_spec, schemas, name_map)
512581

513582
# Handle nullable
514583
if prop_spec.get("x-nullable") or prop_spec.get("nullable"):
@@ -564,6 +633,21 @@ def _collect_handler_routes(resources: dict[str, Resource]) -> dict[str, list[tu
564633
return handler_routes
565634

566635

636+
def _collect_boolean_status_prefixes(resources: dict[str, Resource]) -> list[tuple[str, str]]:
637+
"""Collect (method, path_prefix) pairs for boolean-status endpoints.
638+
639+
The prefix is derived by taking everything before the first path parameter
640+
so it can be matched with str.startswith() at runtime.
641+
"""
642+
result = []
643+
for resource in resources.values():
644+
for endpoint in resource.endpoints:
645+
if endpoint.boolean_status:
646+
prefix = endpoint.path.split("{")[0]
647+
result.append((endpoint.method, prefix))
648+
return result
649+
650+
567651
def generate_sync_client(resources: dict[str, Resource]) -> str:
568652
"""Generate synchronous client code."""
569653
lines = [
@@ -652,6 +736,29 @@ def generate_sync_client(resources: dict[str, Resource]) -> str:
652736
" json=body,",
653737
" )",
654738
"",
739+
]
740+
)
741+
742+
# Inject boolean-status path checks (before the generic error check)
743+
boolean_prefixes = _collect_boolean_status_prefixes(resources)
744+
for method, prefix in boolean_prefixes:
745+
lines.extend([
746+
f' if method.upper() == "{method}" and path.startswith("{prefix}"):',
747+
" if response.status_code == 404:",
748+
" return False",
749+
" elif response.is_success:",
750+
" return True",
751+
" else:",
752+
" raise VantageAPIError(",
753+
" status=response.status_code,",
754+
" status_text=response.reason_phrase,",
755+
" body=response.text,",
756+
" )",
757+
"",
758+
])
759+
760+
lines.extend(
761+
[
655762
" if not response.is_success:",
656763
" raise VantageAPIError(",
657764
" status=response.status_code,",
@@ -757,7 +864,9 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:
757864

758865
# Method signature
759866
param_str = ", ".join(["self"] + params) if params else "self"
760-
if endpoint.response_handler:
867+
if endpoint.boolean_status:
868+
return_type = "bool"
869+
elif endpoint.response_handler:
761870
return_type = endpoint.response_handler_return_type or "Any"
762871
else:
763872
return_type = endpoint.response_type or "None"
@@ -801,7 +910,7 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:
801910
lines.append(" body_data = None")
802911

803912
# Make request and coerce response payload into typed models where possible
804-
if endpoint.response_handler:
913+
if endpoint.boolean_status or endpoint.response_handler:
805914
lines.append(
806915
f' return self._client.request("{endpoint.method}", path, params=params, body=body_data)'
807916
)
@@ -907,6 +1016,29 @@ def generate_async_client(resources: dict[str, Resource]) -> str:
9071016
" json=body,",
9081017
" )",
9091018
"",
1019+
]
1020+
)
1021+
1022+
# Inject boolean-status path checks (before the generic error check)
1023+
boolean_prefixes = _collect_boolean_status_prefixes(resources)
1024+
for method, prefix in boolean_prefixes:
1025+
lines.extend([
1026+
f' if method.upper() == "{method}" and path.startswith("{prefix}"):',
1027+
" if response.status_code == 404:",
1028+
" return False",
1029+
" elif response.is_success:",
1030+
" return True",
1031+
" else:",
1032+
" raise VantageAPIError(",
1033+
" status=response.status_code,",
1034+
" status_text=response.reason_phrase,",
1035+
" body=response.text,",
1036+
" )",
1037+
"",
1038+
])
1039+
1040+
lines.extend(
1041+
[
9101042
" if not response.is_success:",
9111043
" raise VantageAPIError(",
9121044
" status=response.status_code,",
@@ -1012,7 +1144,9 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:
10121144

10131145
# Method signature
10141146
param_str = ", ".join(["self"] + params) if params else "self"
1015-
if endpoint.response_handler:
1147+
if endpoint.boolean_status:
1148+
return_type = "bool"
1149+
elif endpoint.response_handler:
10161150
return_type = endpoint.response_handler_return_type or "Any"
10171151
else:
10181152
return_type = endpoint.response_type or "None"
@@ -1056,7 +1190,7 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:
10561190
lines.append(" body_data = None")
10571191

10581192
# Make request and coerce response payload into typed models where possible
1059-
if endpoint.response_handler:
1193+
if endpoint.boolean_status or endpoint.response_handler:
10601194
lines.append(
10611195
f' return await self._client.request("{endpoint.method}", path, params=params, body=body_data)'
10621196
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "vantage-python"
7-
version = "0.3.2"
7+
version = "0.3.3"
88
description = "Python SDK for the Vantage API"
99
readme = "README.md"
1010
license = "MIT"

0 commit comments

Comments
 (0)