You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
autoflow-server-mgmt/src/main/java/kr/re/etri/autoflow/controllers/ExperimentsController.java

264 lines
13 KiB

package kr.re.etri.autoflow.controllers;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import kr.re.etri.autoflow.entity.ExperimentsEntity;
import kr.re.etri.autoflow.payload.request.ProjectBaseSearchRequest;
import kr.re.etri.autoflow.service.ExperimentsService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * REST endpoints for experiments.
 *
 * <p>Read, search, update and delete operate on the local store through {@link ExperimentsService}.
 * Creation additionally registers the experiment in Kubeflow and in MLflow and copies the
 * identifiers and metadata returned by those services back onto the stored entity.
 */
@Tag(name = "Experiments", description = "Kubeflow 및 MLflow Experiment API")
@RestController
@RequestMapping("/api/experiments")
@RequiredArgsConstructor
@Slf4j
public class ExperimentsController {

    /** Zone used to convert remote timestamps into the entity's local date-times. */
    private static final ZoneId SEOUL = ZoneId.of("Asia/Seoul");

    private final ExperimentsService experimentsService;
    private final WebClient.Builder webClientBuilder;

    /** Kubeflow base URL, e.g. http://192.168.10.135:32473/ (a trailing slash is tolerated). */
    @Value("${kubeflow.url}")
    private String kubeflowBaseUrl;

    /** MLflow tracking server base URL (a trailing slash is tolerated). */
    @Value("${mlflow.url}")
    private String mlflowBaseUrl;

    /** User name for HTTP Basic authentication against MLflow. */
    @Value("${mlflow.user}")
    private String mlflowUser;

    /** Password for HTTP Basic authentication against MLflow. */
    @Value("${mlflow.password}")
    private String mlflowPassword;

    /**
     * Lists all experiments.
     *
     * @return 200 with every experiment returned by {@link ExperimentsService#findAll()}
     */
    @Operation(summary = "모든 Experiments 조회")
    @GetMapping
    public ResponseEntity<List<ExperimentsEntity>> getAllExperiments() {
        return ResponseEntity.ok(experimentsService.findAll());
    }

    /**
     * Fetches a single experiment by its database id.
     *
     * @param id experiment id
     * @return 200 with the experiment, or 404 if no experiment has this id
     */
    @Operation(summary = "Experiment 단건 조회")
    @GetMapping("/{id}")
    public ResponseEntity<ExperimentsEntity> getExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id) {
        return experimentsService.findById(id)
                .map(ResponseEntity::ok)
                .orElse(ResponseEntity.notFound().build());
    }

    /**
     * Searches experiments and returns one page of results.
     *
     * @param request search criteria and paging parameters bound from the query string
     * @return 200 with the page produced by {@link ExperimentsService#search}
     */
    @Operation(summary = "Experiment 검색 및 페이지네이션")
    @GetMapping("/search")
    public ResponseEntity<Page<ExperimentsEntity>> searchExperiments(
            @ParameterObject @ModelAttribute ProjectBaseSearchRequest request) {
        Page<ExperimentsEntity> page = experimentsService.search(request);
        return ResponseEntity.ok(page);
    }

    /**
     * Stores an experiment, then registers it in Kubeflow and MLflow.
     *
     * <p>Steps:
     * <ol>
     *   <li>Save the entity locally.</li>
     *   <li>Create the Kubeflow experiment. Copy its id, timestamps and storage state onto the
     *       entity, then save.</li>
     *   <li>Create the MLflow experiment and fetch it back. Copy its id, artifact location,
     *       lifecycle stage and creation time, then save.</li>
     * </ol>
     *
     * <p>If a remote call fails, the error is logged and propagated to the caller. Changes
     * saved by earlier steps are not rolled back.
     *
     * @param experiment experiment to create
     * @return 200 with the stored entity, including the remote identifiers
     */
    @Operation(summary = "Experiment 등록 (Kubeflow + MLflow)")
    @PostMapping
    public Mono<ResponseEntity<ExperimentsEntity>> createExperiment(@RequestBody ExperimentsEntity experiment) {
        ExperimentsEntity saved = experimentsService.save(experiment);
        return registerInKubeflow(saved)
                .flatMap(this::registerInMlflow)
                .map(ResponseEntity::ok)
                .doOnError(e -> log.error("Experiment 등록 실패", e));
    }

    /**
     * Replaces an existing experiment with the request body.
     *
     * @param id         id of the experiment to replace; it overrides any id in the body
     * @param experiment new state of the experiment
     * @return 200 with the saved experiment, or 404 if no experiment has this id
     */
    @Operation(summary = "Experiment 수정")
    @PutMapping("/{id}")
    public ResponseEntity<ExperimentsEntity> updateExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id,
            @RequestBody ExperimentsEntity experiment) {
        return experimentsService.findById(id)
                .map(existing -> {
                    experiment.setId(id);
                    return ResponseEntity.ok(experimentsService.save(experiment));
                })
                .orElse(ResponseEntity.notFound().build());
    }

    /**
     * Deletes an experiment through {@link ExperimentsService#deleteById}.
     *
     * <p>This endpoint itself makes no Kubeflow or MLflow calls.
     *
     * @param id experiment id
     * @return 204 when {@code deleteById} returns true, otherwise 404
     */
    @Operation(summary = "Experiment 삭제")
    @DeleteMapping("/{id}")
    public ResponseEntity<Void> deleteExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id) {
        if (experimentsService.deleteById(id)) {
            return ResponseEntity.noContent().build();
        }
        return ResponseEntity.notFound().build();
    }

    /**
     * Creates the Kubeflow experiment for {@code entity} and copies the response onto it.
     *
     * @param entity locally saved experiment
     * @return the entity as returned by the follow-up save
     */
    private Mono<ExperimentsEntity> registerInKubeflow(ExperimentsEntity entity) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("display_name", entity.getDisplayName());
        payload.put("description", entity.getDescription());
        payload.put("namespace", "default");

        return webClientBuilder.build()
                .post()
                .uri(joinUrl(kubeflowBaseUrl, "/apis/v2beta1/experiments"))
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(payload)
                .retrieve()
                .bodyToMono(Map.class)
                // NOTE(review): experimentsService.save runs inside the WebClient response callback.
                // If it is a blocking (e.g. JPA) call, consider publishOn(Schedulers.boundedElastic())
                // so it does not occupy the HTTP client's I/O thread.
                .map(resp -> {
                    applyKubeflowResponse(entity, resp);
                    ExperimentsEntity updated = experimentsService.save(entity);
                    log.info("Kubeflow experiment 등록 완료: {}", resp);
                    return updated;
                });
    }

    /**
     * Creates the MLflow experiment for {@code entity} and re-reads it to obtain its metadata.
     *
     * <p>The metadata is copied onto the entity, which is then saved.
     *
     * @param entity experiment already registered in Kubeflow
     * @return the entity as returned by the final save; errors if MLflow returns no experiment id
     */
    private Mono<ExperimentsEntity> registerInMlflow(ExperimentsEntity entity) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("name", entity.getDisplayName());
        payload.put("artifact_location", "/default/artifacts");

        WebClient client = webClientBuilder.build();
        return client.post()
                .uri(joinUrl(mlflowBaseUrl, "/ajax-api/2.0/mlflow/experiments/create"))
                .headers(headers -> headers.setBasicAuth(mlflowUser, mlflowPassword))
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(payload)
                .retrieve()
                .bodyToMono(Map.class)
                .flatMap(createResp -> {
                    log.info("MLflow experiment 등록 완료: {}", createResp);
                    String mlflowExpId = asString(createResp.get("experiment_id"));
                    if (mlflowExpId == null) {
                        return Mono.error(new IllegalStateException(
                                "MLflow create response has no experiment_id: " + createResp));
                    }
                    // Pass the id as a URI template variable so it is encoded.
                    return client.get()
                            .uri(joinUrl(mlflowBaseUrl, "/ajax-api/2.0/mlflow/experiments/get?experiment_id={id}"),
                                    mlflowExpId)
                            .headers(headers -> headers.setBasicAuth(mlflowUser, mlflowPassword))
                            .retrieve()
                            .bodyToMono(Map.class);
                })
                .map(getResp -> {
                    log.info("MLflow experiment 상세 조회 완료: {}", getResp);
                    applyMlflowResponse(entity, getResp);
                    return experimentsService.save(entity);
                });
    }

    /**
     * Copies the fields of a Kubeflow create-experiment response onto {@code entity}.
     *
     * <p>Reads {@code experiment_id}, {@code created_at}, {@code last_run_created_at} and
     * {@code storage_state}. Absent or null values are skipped.
     */
    private static void applyKubeflowResponse(ExperimentsEntity entity, Map<?, ?> resp) {
        String kubeflowId = asString(resp.get("experiment_id"));
        if (kubeflowId != null) {
            entity.setKubeFlowId(kubeflowId);
        }
        Instant createdAt = parseInstant(resp.get("created_at"));
        if (createdAt != null) {
            entity.setKubeflowCreatedAt(createdAt.atZone(SEOUL).toLocalDateTime());
        }
        Instant lastRunCreatedAt = parseInstant(resp.get("last_run_created_at"));
        if (lastRunCreatedAt != null) {
            entity.setKubeflowLastRunCreatedAt(lastRunCreatedAt.atZone(SEOUL).toLocalDateTime());
        }
        String storageState = asString(resp.get("storage_state"));
        if (storageState != null) {
            entity.setKubeflowStorageState(storageState);
        }
    }

    /**
     * Copies the nested {@code experiment} object of an MLflow get-experiment response onto
     * {@code entity}.
     *
     * <p>Absent or null values are skipped, as is a missing or non-object {@code experiment}.
     */
    private static void applyMlflowResponse(ExperimentsEntity entity, Map<?, ?> getResp) {
        Object experiment = getResp.get("experiment");
        if (!(experiment instanceof Map)) {
            return;
        }
        Map<?, ?> exp = (Map<?, ?>) experiment;

        String mlflowId = asString(exp.get("experiment_id"));
        if (mlflowId != null) {
            entity.setMlFlowId(mlflowId);
        }
        String artifactLocation = asString(exp.get("artifact_location"));
        if (artifactLocation != null) {
            entity.setMlflow_artifactLocation(artifactLocation);
        }
        String lifecycleStage = asString(exp.get("lifecycle_stage"));
        if (lifecycleStage != null) {
            entity.setMlflowLifecycleStage(lifecycleStage);
        }
        // The MLflow REST API calls this field creation_time (epoch millis).
        // created_at is kept as a fallback for compatibility with the previous lookup.
        Object rawCreated = exp.containsKey("creation_time") ? exp.get("creation_time") : exp.get("created_at");
        Instant createdAt = parseInstant(rawCreated);
        if (createdAt != null) {
            entity.setMlflowCreatedAt(createdAt.atZone(SEOUL).toLocalDateTime());
        }
    }

    /**
     * Converts a JSON timestamp value to an {@link Instant}.
     *
     * <p>Numbers and all-digit strings are treated as epoch milliseconds. Other strings are
     * parsed as ISO-8601 date-times with an offset (e.g. {@code 2024-01-01T00:00:00Z}).
     *
     * @return the instant, or null for a null or blank value
     * @throws java.time.format.DateTimeParseException if a non-numeric string is not ISO-8601
     */
    private static Instant parseInstant(Object value) {
        if (value instanceof Number) {
            return Instant.ofEpochMilli(((Number) value).longValue());
        }
        if (value == null) {
            return null;
        }
        String text = value.toString().trim();
        if (text.isEmpty()) {
            return null;
        }
        if (text.chars().allMatch(Character::isDigit)) {
            return Instant.ofEpochMilli(Long.parseLong(text));
        }
        return OffsetDateTime.parse(text).toInstant();
    }

    /** Returns {@code value.toString()}, or null when {@code value} is null. */
    private static String asString(Object value) {
        return value == null ? null : value.toString();
    }

    /**
     * Appends an absolute API path to a configured base URL.
     *
     * <p>Trailing slashes on the base are dropped so the result never contains {@code //}.
     */
    private static String joinUrl(String baseUrl, String path) {
        int end = baseUrl.length();
        while (end > 0 && baseUrl.charAt(end - 1) == '/') {
            end--;
        }
        return baseUrl.substring(0, end) + path;
    }
}