Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 67 additions & 13 deletions src/flask_ml/flask_ml_cli/MLCli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
TextInput,
EnumParameterDescriptor,
TextParameterDescriptor,
FileResponse,
DirectoryResponse,
MarkdownResponse,
TextResponse,
BatchFileResponse,
BatchTextResponse,
BatchDirectoryResponse,
)


Expand All @@ -48,7 +55,7 @@ def get_input_argument_validator_func(input_type: InputType):
return str
case InputType.BATCHDIRECTORY:
return is_path_exists_or_creatable_portable_arg_parser
case _: # pragma: no cover
case _: # pragma: no cover
assert_never(input_type)


Expand All @@ -66,7 +73,7 @@ def get_parameter_argument_validator_func(parameter_schema: ParameterSchema):
return get_int_range_check_func_arg_parser(parameter_schema.value.range)
case IntParameterDescriptor():
return int
case _: # pragma: no cover
case _: # pragma: no cover
assert_never(parameter_schema.value)


Expand Down Expand Up @@ -118,9 +125,7 @@ def _add_parameter_argument_to_parser(parser: ArgumentParser, parameter_schema:
default=default_param_value,
type=get_parameter_argument_validator_func(parameter_schema),
choices=(
get_enum_parameter_choices(parameter_schema)
if parameter_type == ParameterType.ENUM
else None
get_enum_parameter_choices(parameter_schema) if parameter_type == ParameterType.ENUM else None
),
)

Expand Down Expand Up @@ -148,7 +153,7 @@ def func(args):
input_model = BatchDirectoryInput(
directories=[DirectoryInput(path=item) for item in cli_input]
)
case _: # pragma: no cover
case _: # pragma: no cover
assert_never(input_schema.input_type)
inputs[input_schema.key] = input_model
for parameter_schema in task_schema.parameters:
Expand All @@ -174,7 +179,51 @@ def _add_subparser(self, subparsers, endpoint: EndpointDetails):
for parameter_schema in parameter_schemas:
self._add_parameter_argument_to_parser(subcommand_parser, parameter_schema)
self._set_function_on_parser(subcommand_parser, task_schema, endpoint.func)


@staticmethod
def _print_response_body(response: ResponseBody):
response_model = response.root
print("\tResults:")
match response_model:
case BatchTextResponse():
for text_response in response_model.texts:
if text_response.title:
print(f"\t{text_response.title}")
print()
print(f"\t{text_response.value}")
case BatchFileResponse():
for file_response in response_model.files:
if file_response.title:
print(f"\t{file_response.title}")
print(f"\tFile Type: {file_response.file_type.value}")
print(f"\tPath: {file_response.path}")
print()
case BatchDirectoryResponse():
for directory_response in response_model.directories:
if directory_response.title:
print(f"\t{directory_response.title}")
print(f"\tPath: {directory_response.path}")
print()
case _:
if response_model.title:
print(f"\t{response_model.title}")
match response_model:
case TextResponse():
print(response_model.value.replace("\n", "\n\t"))
print()
case FileResponse():
print(f"\tFile Type: {response_model.file_type.value}")
print(f"\tPath: {response_model.path}")
print()
case DirectoryResponse():
print(f"\tPath: {response_model.path}")
print()
case MarkdownResponse():
print(response_model.value.replace("\n", "\n\t"))
print()
case _: # pragma: no cover
assert_never(response_model)

def _setup_cli(self):
schema_endpoints = [
endpoint for endpoint in self._server.endpoints if isinstance(endpoint, EndpointDetails)
Expand All @@ -188,21 +237,26 @@ def _setup_cli(self):
subparsers = self._parser.add_subparsers(help="Subcommands", required=True)
for endpoint in schema_endpoints:
self._add_subparser(subparsers, endpoint)

def _parse_args(self, args: Sequence[str] | None = None):
parsed_args = self._parser.parse_args(args)
return parsed_args

def _run_cli_and_return(self, parsed_args: Namespace) -> Optional[ResponseBody]:
def _run_cli_and_return(
self, parsed_args: Namespace, print_response: bool = True
) -> Optional[ResponseBody]:
if parsed_args.func:
return parsed_args.func(parsed_args)
raise SystemExit("FATAL: No function defined") # pragma: no cover
response = parsed_args.func(parsed_args)
if print_response:
self._print_response_body(response)
return response
raise SystemExit("FATAL: No function defined") # pragma: no cover

def run_cli(self, args: Sequence[str] | None = None):
self._setup_cli()
parsed_args = self._parse_args(args)
response_body = self._run_cli_and_return(parsed_args)
if response_body is None:
raise SystemExit("FATAL: No response body") # pragma: no cover
raise SystemExit("FATAL: No response body") # pragma: no cover
print()
print(json.dumps(json.loads(response_body.model_dump_json()), indent=4))
self._print_response_body(response_body)
15 changes: 15 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from flask_ml.flask_ml_server import MLServer
import pytest

from flask_ml.flask_ml_server.models import MarkdownResponse, ResponseBody


@pytest.fixture
def ml_cli(server: MLServer):
Expand Down Expand Up @@ -316,3 +318,16 @@ def test_process_directories_and_ranged_int_parameter_with_schema_non_int_parame
]
)
ml_cli._run_cli_and_return(parsed_args)

def test_print_markdown_response(ml_cli: MLCli):
ml_cli._print_response_body(
ResponseBody(
root=MarkdownResponse(
title="Markdown Title",
subtitle="Markdown Subtitle",
value=f"""
# Markdown Content
""",
)
)
)