You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
autoflow-server-mgmt/src/main/java/kr/re/etri/autoflow/controllers/MlflowController.java

153 lines
6.8 KiB

package kr.re.etri.autoflow.controllers;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import org.springframework.web.server.ResponseStatusException;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.Map;
@Tag(name = "MLflow API", description = "MLflow Experiment 및 Run 조회 API")
@RestController
@RequestMapping("/api/mlflow")
public class MlflowController {
private final WebClient webClient;
public MlflowController() {
this.webClient = WebClient.builder()
.baseUrl("http://192.168.10.135:30128/api/2.0/mlflow")
.defaultHeaders(headers -> headers.setBasicAuth("user", "WjWjIi13KEkO"))
.build();
}
@Operation(
summary = "Experiment 조회",
description = "Experiment 이름으로 MLflow Experiment 정보를 조회합니다.",
responses = {
@ApiResponse(responseCode = "200", description = "Experiment 정보 조회 성공"),
@ApiResponse(responseCode = "404", description = "Experiment를 찾을 수 없음"),
@ApiResponse(responseCode = "500", description = "서버 오류 발생")
}
)
@GetMapping(value = "/experiment", produces = MediaType.APPLICATION_JSON_VALUE)
public Map<String, Object> getExperimentByName(
@Parameter(description = "조회할 Experiment 이름", required = true, example = "MyExperiment")
@RequestParam String experimentName) {
Map response = webClient.get()
.uri(uriBuilder -> uriBuilder
.path("/experiments/get-by-name")
.queryParam("experiment_name", experimentName)
.build())
.retrieve()
.bodyToMono(Map.class)
.block();
if (response == null || !response.containsKey("experiment")) {
throw new RuntimeException("Experiment not found: " + experimentName);
}
return (Map<String, Object>) response.get("experiment");
}
@Operation(
summary = "Run 단건 조회",
description = "주어진 Run ID의 상세 정보를 조회합니다. MLflow API `/runs/get`를 호출하여 Run 정보를 반환합니다.",
responses = {
@ApiResponse(responseCode = "200", description = "Run 정보 조회 성공"),
@ApiResponse(responseCode = "500", description = "서버 오류 발생")
}
)
@GetMapping(value = "/run", produces = MediaType.APPLICATION_JSON_VALUE)
public Mono<ResponseEntity<String>> getRun(
@Parameter(description = "조회할 Run ID", required = true, example = "59e4f75b29eb4354b9e9e2ec9d93e2e3")
@RequestParam String runId) {
String uri = String.format("/runs/get?run_id=%s", runId);
return webClient.get()
.uri(uri)
.retrieve()
.bodyToMono(String.class)
.map(ResponseEntity::ok)
.onErrorResume(e -> Mono.just(ResponseEntity.internalServerError().body(e.getMessage())));
}
@Operation(
summary = "Experiment의 Run 목록 조회",
description = "주어진 Experiment ID의 Run을 최대 1000개까지 조회합니다. MLflow API `/runs/search`를 호출합니다.",
responses = {
@ApiResponse(responseCode = "200", description = "Run 목록 조회 성공"),
@ApiResponse(responseCode = "500", description = "서버 오류 발생")
}
)
@GetMapping(value = "/runs", produces = MediaType.APPLICATION_JSON_VALUE)
public Mono<ResponseEntity<String>> getRuns(
@Parameter(description = "조회할 Experiment ID", required = true, example = "1234567890abcdef")
@RequestParam String experimentId) {
Map<String, Object> body = Map.of(
"experiment_ids", Collections.singletonList(experimentId),
"order_by", Collections.singletonList("attribute.start_time DESC"),
"max_results", 1000
);
return webClient.post()
.uri("/runs/search")
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(body)
.retrieve()
.bodyToMono(String.class)
.map(ResponseEntity::ok)
.onErrorResume(e -> Mono.just(ResponseEntity.internalServerError().body(e.getMessage())));
}
@Operation(
summary = "Run의 Artifact 목록 조회",
description = """
Run ID Artifact .
MLflow API `/artifacts/list` , `path` Artifact .
""",
responses = {
@ApiResponse(responseCode = "200", description = "Artifact 목록 조회 성공"),
@ApiResponse(responseCode = "404", description = "Run을 찾을 수 없음"),
@ApiResponse(responseCode = "500", description = "서버 오류 발생")
}
)
@GetMapping(value = "/artifacts", produces = MediaType.APPLICATION_JSON_VALUE)
public Mono<ResponseEntity<String>> listArtifacts(
@Parameter(description = "조회할 Run ID", required = true, example = "9d08fa7973cf4c39a0979bb4d70c640b")
@RequestParam String runId,
@Parameter(description = "조회할 경로 (선택)", example = "models")
@RequestParam(required = false) String path,
@Parameter(description = "페이징 토큰 (선택)", example = "MjAyNS0xMC0yN1QxMjo0NjozMlo=")
@RequestParam(required = false, name = "page_token") String pageToken
) {
return webClient.get()
.uri(uriBuilder -> {
var builder = uriBuilder.path("/artifacts/list")
.queryParam("run_id", runId);
if (path != null && !path.isBlank()) {
builder.queryParam("path", path);
}
if (pageToken != null && !pageToken.isBlank()) {
builder.queryParam("page_token", pageToken);
}
return builder.build();
})
.retrieve()
.bodyToMono(String.class)
.map(ResponseEntity::ok)
.onErrorResume(e -> Mono.just(ResponseEntity.internalServerError().body(e.getMessage())));
}
}