You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
autoflow-server-mgmt/src/main/java/kr/re/etri/autoflow/controllers/ExperimentsController.java

237 lines
11 KiB

package kr.re.etri.autoflow.controllers;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.transaction.Transactional;
import kr.re.etri.autoflow.entity.ExperimentsEntity;
import kr.re.etri.autoflow.payload.request.ProjectBaseSearchRequest;
import kr.re.etri.autoflow.service.ExperimentsService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
 * REST API for experiments that are stored locally and mirrored to Kubeflow Pipelines and MLflow.
 *
 * <p>Read, search, update and delete operate only on the local store through
 * {@link ExperimentsService}. Creation additionally registers the experiment with Kubeflow and
 * MLflow, and persists it locally only after both remote registrations succeed.
 */
@Tag(name = "Experiments", description = "Kubeflow 및 MLflow Experiment API")
@RestController
@RequestMapping("/api/experiments")
@RequiredArgsConstructor
@Slf4j
public class ExperimentsController {

    /** Zone used to turn Kubeflow's instant timestamps into the stored LocalDateTime values. */
    private static final ZoneId LOCAL_ZONE = ZoneId.of("Asia/Seoul");

    /** Kubeflow namespace that new experiments are created in. */
    private static final String KUBEFLOW_NAMESPACE = "default";

    /** Artifact location requested for newly created MLflow experiments. */
    private static final String MLFLOW_ARTIFACT_LOCATION = "/default/artifacts";

    private final ExperimentsService experimentsService;
    private final WebClient.Builder webClientBuilder;

    /** Kubeflow Pipelines base URL, e.g. {@code http://192.168.10.135:32473/} (trailing slash tolerated). */
    @Value("${kubeflow.url}")
    private String kubeflowBaseUrl;

    /** MLflow tracking server base URL, e.g. {@code http://192.168.10.135:32473/} (trailing slash tolerated). */
    @Value("${mlflow.url}")
    private String mlflowBaseUrl;

    /** User name for HTTP Basic authentication against MLflow. */
    @Value("${mlflow.user}")
    private String mlflowUser;

    /** Password for HTTP Basic authentication against MLflow. */
    @Value("${mlflow.password}")
    private String mlflowPassword;

    /**
     * Lists all experiments.
     *
     * @return 200 with every experiment returned by {@link ExperimentsService#findAll()}
     */
    @Operation(summary = "모든 Experiments 조회")
    @GetMapping
    public ResponseEntity<List<ExperimentsEntity>> getAllExperiments() {
        return ResponseEntity.ok(experimentsService.findAll());
    }

    /**
     * Looks up a single experiment by its local id.
     *
     * @param id local experiment id
     * @return 200 with the experiment, or 404 when no experiment has that id
     */
    @Operation(summary = "Experiment 단건 조회")
    @GetMapping("/{id}")
    public ResponseEntity<ExperimentsEntity> getExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id) {
        return experimentsService.findById(id)
                .map(ResponseEntity::ok)
                .orElse(ResponseEntity.notFound().build());
    }

    /**
     * Searches experiments with criteria bound from the request's query parameters.
     *
     * @param request search and paging criteria
     * @return 200 with the matching page
     */
    @Operation(summary = "Experiment 검색 및 페이지네이션")
    @GetMapping("/search")
    public ResponseEntity<Page<ExperimentsEntity>> searchExperiments(
            @ParameterObject @ModelAttribute ProjectBaseSearchRequest request) {
        Page<ExperimentsEntity> page = experimentsService.search(request);
        return ResponseEntity.ok(page);
    }

    /**
     * Registers a new experiment in Kubeflow and MLflow, then stores it locally.
     *
     * <p>Only {@code displayName}, {@code description}, {@code projectId} and {@code regUserId} are
     * taken from the request body. The remaining fields are filled from the Kubeflow create
     * response and from the MLflow experiment read back after creation. Any remote failure
     * is logged and propagated, and nothing is saved locally in that case.
     *
     * <p>This method is intentionally not {@code @Transactional}. The save happens asynchronously
     * after the remote calls, outside the method invocation such an annotation would wrap.
     *
     * <p>NOTE(review): if the MLflow step fails after Kubeflow succeeded, the Kubeflow experiment
     * is not deleted again; it stays orphaned.
     *
     * @param experiment request body; only the fields listed above are used
     * @return 200 with the persisted experiment
     */
    @Operation(summary = "Experiment 등록")
    @PostMapping
    public Mono<ResponseEntity<ExperimentsEntity>> createExperiment(@RequestBody ExperimentsEntity experiment) {
        // Kept in memory only until every remote registration has succeeded.
        ExperimentsEntity draft = new ExperimentsEntity();
        draft.setDisplayName(experiment.getDisplayName());
        draft.setDescription(experiment.getDescription());
        draft.setProjectId(experiment.getProjectId());
        draft.setRegUserId(experiment.getRegUserId());

        return registerKubeflowExperiment(draft)
                .flatMap(this::createMlflowExperiment)
                .flatMap(mlflowExperimentId -> loadMlflowExperiment(mlflowExperimentId, draft))
                .flatMap(this::saveOffEventLoop)
                .map(ResponseEntity::ok)
                .doOnError(e -> log.error("Experiment 등록 실패", e));
    }

    /**
     * Creates the experiment in Kubeflow and copies the returned {@code experiment_id} and
     * {@code created_at} (converted to {@link #LOCAL_ZONE} local time) onto {@code target}.
     *
     * @return a Mono emitting {@code target} after the Kubeflow response has been applied
     */
    private Mono<ExperimentsEntity> registerKubeflowExperiment(ExperimentsEntity target) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("display_name", target.getDisplayName());
        payload.put("description", target.getDescription());
        payload.put("namespace", KUBEFLOW_NAMESPACE);

        return webClientBuilder.build()
                .post()
                .uri(joinUrl(kubeflowBaseUrl, "/apis/v2beta1/experiments"))
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(payload)
                .retrieve()
                .bodyToMono(Map.class)
                .map(resp -> {
                    Object kubeflowId = resp.get("experiment_id");
                    if (kubeflowId != null) {
                        target.setKubeFlowId(kubeflowId.toString());
                    }
                    // Skip a missing/null timestamp instead of failing with an NPE in Instant.parse.
                    Object createdAt = resp.get("created_at");
                    if (createdAt != null) {
                        target.setKubeflowCreatedAt(
                                Instant.parse(createdAt.toString()).atZone(LOCAL_ZONE).toLocalDateTime());
                    }
                    log.info("Kubeflow experiment 등록 완료: {}", resp);
                    return target;
                });
    }

    /**
     * Creates an MLflow experiment named after {@code target}'s display name.
     *
     * @return a Mono emitting the MLflow {@code experiment_id}; it errors with
     *     {@link IllegalStateException} if the create response carries no id
     */
    private Mono<String> createMlflowExperiment(ExperimentsEntity target) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("name", target.getDisplayName());
        payload.put("artifact_location", MLFLOW_ARTIFACT_LOCATION);

        return webClientBuilder.build()
                .post()
                .uri(joinUrl(mlflowBaseUrl, "/ajax-api/2.0/mlflow/experiments/create"))
                .headers(headers -> headers.setBasicAuth(mlflowUser, mlflowPassword))
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(payload)
                .retrieve()
                .bodyToMono(Map.class)
                .map(resp -> {
                    log.info("MLflow experiment 등록 완료: {}", resp);
                    Object experimentId = resp.get("experiment_id");
                    if (experimentId == null) {
                        // Fail fast rather than issuing a follow-up GET with "experiment_id=null".
                        throw new IllegalStateException("MLflow create response has no experiment_id: " + resp);
                    }
                    return experimentId.toString();
                });
    }

    /**
     * Reads the MLflow experiment back and copies its id, artifact location and lifecycle stage
     * onto {@code target}. If the response has no {@code experiment} object, nothing is copied.
     *
     * @return a Mono emitting {@code target}
     */
    private Mono<ExperimentsEntity> loadMlflowExperiment(String mlflowExperimentId, ExperimentsEntity target) {
        return webClientBuilder.build()
                .get()
                // Pass the id as a URI variable so WebClient encodes it, rather than concatenating raw text.
                .uri(joinUrl(mlflowBaseUrl, "/ajax-api/2.0/mlflow/experiments/get") + "?experiment_id={id}",
                        mlflowExperimentId)
                .headers(headers -> headers.setBasicAuth(mlflowUser, mlflowPassword))
                .retrieve()
                .bodyToMono(Map.class)
                .map(resp -> {
                    Object node = resp.get("experiment");
                    if (node instanceof Map) {
                        Map<?, ?> exp = (Map<?, ?>) node;
                        target.setMlFlowId(Objects.toString(exp.get("experiment_id"), null));
                        target.setMlflow_artifactLocation(Objects.toString(exp.get("artifact_location"), null));
                        target.setMlflowLifecycleStage(Objects.toString(exp.get("lifecycle_stage"), null));
                    }
                    return target;
                });
    }

    /**
     * Runs the synchronous {@link ExperimentsService#save} call on Reactor's bounded-elastic
     * scheduler, so it does not block the thread that delivered the HTTP client response.
     */
    private Mono<ExperimentsEntity> saveOffEventLoop(ExperimentsEntity entity) {
        return Mono.fromCallable(() -> experimentsService.save(entity))
                .subscribeOn(Schedulers.boundedElastic());
    }

    /** Joins a configured base URL and an absolute path without producing a double slash. */
    private static String joinUrl(String baseUrl, String path) {
        return baseUrl.endsWith("/") ? baseUrl.substring(0, baseUrl.length() - 1) + path : baseUrl + path;
    }

    /**
     * Replaces an existing experiment with the request body.
     *
     * <p>NOTE(review): the body is saved as-is with only its id overwritten. Fields the client
     * omits (e.g. Kubeflow/MLflow ids) are presumably written back as null; confirm that
     * full-replacement semantics are intended.
     *
     * @param id local experiment id
     * @param experiment new experiment state
     * @return 200 with the saved experiment, or 404 when no experiment has that id
     */
    @Operation(summary = "Experiment 수정")
    @PutMapping("/{id}")
    public ResponseEntity<ExperimentsEntity> updateExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id,
            @RequestBody ExperimentsEntity experiment) {
        return experimentsService.findById(id)
                .map(existing -> {
                    experiment.setId(id);
                    return ResponseEntity.ok(experimentsService.save(experiment));
                })
                .orElse(ResponseEntity.notFound().build());
    }

    /**
     * Deletes an experiment from the local store.
     *
     * <p>NOTE(review): only {@link ExperimentsService#deleteById} is called here; the Kubeflow and
     * MLflow experiments are not deleted by this endpoint.
     *
     * @param id local experiment id
     * @return 204 when deleted, 404 when the service reports nothing was deleted
     */
    @Operation(summary = "Experiment 삭제")
    @DeleteMapping("/{id}")
    public ResponseEntity<Void> deleteExperiment(
            @Parameter(description = "Experiment ID", example = "1") @PathVariable("id") Long id) {
        if (experimentsService.deleteById(id)) {
            return ResponseEntity.noContent().build();
        }
        return ResponseEntity.notFound().build();
    }
}