// backend\nsfw_detector.go package main import ( "bytes" "encoding/base64" "fmt" "image" _ "image/jpeg" _ "image/png" "math" "os" "path/filepath" "sort" "strings" "sync" ort "github.com/yalue/onnxruntime_go" xdraw "golang.org/x/image/draw" ) const ( nsfwInputSize = 320 nsfwNumClasses = 18 nsfwNumAnchors = 2100 // 320er YOLOv8: 40*40 + 20*20 + 10*10 nsfwConfThresh = 0.20 nsfwNMSThresh = 0.45 ) var nsfwLabels = []string{ "female_genitalia_covered", "face_female", "buttocks_exposed", "female_breast_exposed", "female_genitalia_exposed", "male_breast_exposed", "anus_exposed", "feet_exposed", "belly_covered", "feet_covered", "armpits_covered", "armpits_exposed", "face_male", "belly_exposed", "male_genitalia_exposed", "anus_covered", "female_breast_covered", "buttocks_covered", } type nsfwDetector struct { mu sync.Mutex initialized bool runtimeRoot string modelPath string dllPath string inputTensor *ort.Tensor[float32] outputTensor *ort.Tensor[float32] session *ort.AdvancedSession } type yoloDet struct { classID int score float32 x1 float32 y1 float32 x2 float32 y2 float32 } var globalNSFW nsfwDetector func initNSFWDetector() error { globalNSFW.mu.Lock() defer globalNSFW.mu.Unlock() if globalNSFW.initialized { return nil } root, err := ensureNSFWAssetsExtracted() if err != nil { return err } dllPath := filepath.Join(root, "onnxruntime.dll") modelPath := filepath.Join(root, "320n.onnx") if _, err := os.Stat(dllPath); err != nil { return fmt.Errorf("onnxruntime.dll nicht gefunden: %w", err) } if _, err := os.Stat(modelPath); err != nil { return fmt.Errorf("320n.onnx nicht gefunden: %w", err) } ort.SetSharedLibraryPath(dllPath) if err := ort.InitializeEnvironment(); err != nil { return fmt.Errorf("onnxruntime init fehlgeschlagen: %w", err) } inputShape := ort.NewShape(1, 3, nsfwInputSize, nsfwInputSize) inputData := make([]float32, 1*3*nsfwInputSize*nsfwInputSize) inputTensor, err := ort.NewTensor(inputShape, inputData) if err != nil { ort.DestroyEnvironment() return fmt.Errorf("input tensor fehlgeschlagen: %w", err) } outputShape := ort.NewShape(1, 4+nsfwNumClasses, nsfwNumAnchors) outputTensor, err := ort.NewEmptyTensor[float32](outputShape) if err != nil { inputTensor.Destroy() ort.DestroyEnvironment() return fmt.Errorf("output tensor fehlgeschlagen: %w", err) } session, err := ort.NewAdvancedSession( modelPath, []string{"images"}, []string{"output0"}, []ort.Value{inputTensor}, []ort.Value{outputTensor}, nil, ) if err != nil { outputTensor.Destroy() inputTensor.Destroy() ort.DestroyEnvironment() return fmt.Errorf("onnx session fehlgeschlagen: %w", err) } globalNSFW.runtimeRoot = root globalNSFW.modelPath = modelPath globalNSFW.dllPath = dllPath globalNSFW.inputTensor = inputTensor globalNSFW.outputTensor = outputTensor globalNSFW.session = session globalNSFW.initialized = true return nil } func closeNSFWDetector() error { globalNSFW.mu.Lock() defer globalNSFW.mu.Unlock() if !globalNSFW.initialized { return nil } if globalNSFW.session != nil { globalNSFW.session.Destroy() globalNSFW.session = nil } if globalNSFW.outputTensor != nil { globalNSFW.outputTensor.Destroy() globalNSFW.outputTensor = nil } if globalNSFW.inputTensor != nil { globalNSFW.inputTensor.Destroy() globalNSFW.inputTensor = nil } ort.DestroyEnvironment() globalNSFW.initialized = false return nil } func detectNSFWFromBase64(imageB64 string) ([]NsfwFrameResult, error) { globalNSFW.mu.Lock() defer globalNSFW.mu.Unlock() if !globalNSFW.initialized || globalNSFW.session == nil { return nil, fmt.Errorf("nsfw detector nicht initialisiert") } img, err := decodeBase64Image(imageB64) if err != nil { return nil, err } fillInputTensor(globalNSFW.inputTensor.GetData(), img) if err := globalNSFW.session.Run(); err != nil { return nil, fmt.Errorf("onnx run fehlgeschlagen: %w", err) } raw := globalNSFW.outputTensor.GetData() dets := parseYOLOOutput(raw, nsfwConfThresh) dets = applyNMS(dets, nsfwNMSThresh) bestByLabel := map[string]float64{} for _, d := range dets { if d.classID < 0 || d.classID >= len(nsfwLabels) { continue } label := nsfwLabels[d.classID] score := float64(d.score) if score > bestByLabel[label] { bestByLabel[label] = score } } out := make([]NsfwFrameResult, 0, len(bestByLabel)) for label, score := range bestByLabel { out = append(out, NsfwFrameResult{ Label: label, Score: score, }) } sort.Slice(out, func(i, j int) bool { return out[i].Score > out[j].Score }) return out, nil } func decodeBase64Image(imageB64 string) (image.Image, error) { raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(imageB64)) if err != nil { return nil, fmt.Errorf("base64 decode fehlgeschlagen: %w", err) } img, _, err := image.Decode(bytes.NewReader(raw)) if err != nil { return nil, fmt.Errorf("bild decode fehlgeschlagen: %w", err) } return img, nil } func fillInputTensor(dst []float32, src image.Image) { rgba, scale, padX, padY := letterboxToRGBA(src, nsfwInputSize, nsfwInputSize) hw := nsfwInputSize * nsfwInputSize for y := 0; y < nsfwInputSize; y++ { for x := 0; x < nsfwInputSize; x++ { i := y*rgba.Stride + x*4 r := float32(rgba.Pix[i+0]) / 255.0 g := float32(rgba.Pix[i+1]) / 255.0 b := float32(rgba.Pix[i+2]) / 255.0 idx := y*nsfwInputSize + x dst[idx] = r dst[hw+idx] = g dst[2*hw+idx] = b } } _ = scale _ = padX _ = padY } func letterboxToRGBA(src image.Image, dstW, dstH int) (*image.RGBA, float64, int, int) { sb := src.Bounds() sw := sb.Dx() sh := sb.Dy() scale := math.Min(float64(dstW)/float64(sw), float64(dstH)/float64(sh)) nw := int(math.Round(float64(sw) * scale)) nh := int(math.Round(float64(sh) * scale)) dst := image.NewRGBA(image.Rect(0, 0, dstW, dstH)) for y := 0; y < dstH; y++ { for x := 0; x < dstW; x++ { i := y*dst.Stride + x*4 dst.Pix[i+0] = 114 dst.Pix[i+1] = 114 dst.Pix[i+2] = 114 dst.Pix[i+3] = 255 } } resized := image.NewRGBA(image.Rect(0, 0, nw, nh)) xdraw.ApproxBiLinear.Scale(resized, resized.Bounds(), src, sb, xdraw.Over, nil) padX := (dstW - nw) / 2 padY := (dstH - nh) / 2 for y := 0; y < nh; y++ { copy( dst.Pix[(y+padY)*dst.Stride+padX*4:(y+padY)*dst.Stride+padX*4+nw*4], resized.Pix[y*resized.Stride:y*resized.Stride+nw*4], ) } return dst, scale, padX, padY } func parseYOLOOutput(raw []float32, confThresh float32) []yoloDet { // output0: [1, 22, 2100] = [batch, 4+18, anchors] out := make([]yoloDet, 0, 64) channels := 4 + nsfwNumClasses if len(raw) != channels*nsfwNumAnchors { return out } for a := 0; a < nsfwNumAnchors; a++ { cx := raw[0*nsfwNumAnchors+a] cy := raw[1*nsfwNumAnchors+a] w := raw[2*nsfwNumAnchors+a] h := raw[3*nsfwNumAnchors+a] bestClass := -1 bestScore := float32(0) for c := 0; c < nsfwNumClasses; c++ { s := raw[(4+c)*nsfwNumAnchors+a] if s > bestScore { bestScore = s bestClass = c } } if bestClass < 0 || bestScore < confThresh { continue } x1 := cx - w/2 y1 := cy - h/2 x2 := cx + w/2 y2 := cy + h/2 out = append(out, yoloDet{ classID: bestClass, score: bestScore, x1: x1, y1: y1, x2: x2, y2: y2, }) } return out } func applyNMS(dets []yoloDet, iouThresh float32) []yoloDet { if len(dets) == 0 { return dets } sort.Slice(dets, func(i, j int) bool { return dets[i].score > dets[j].score }) kept := make([]yoloDet, 0, len(dets)) used := make([]bool, len(dets)) for i := 0; i < len(dets); i++ { if used[i] { continue } kept = append(kept, dets[i]) for j := i + 1; j < len(dets); j++ { if used[j] || dets[i].classID != dets[j].classID { continue } if iou(dets[i], dets[j]) >= iouThresh { used[j] = true } } } return kept } func iou(a, b yoloDet) float32 { ix1 := maxf(a.x1, b.x1) iy1 := maxf(a.y1, b.y1) ix2 := minf(a.x2, b.x2) iy2 := minf(a.y2, b.y2) iw := maxf(0, ix2-ix1) ih := maxf(0, iy2-iy1) inter := iw * ih aw := maxf(0, a.x2-a.x1) ah := maxf(0, a.y2-a.y1) bw := maxf(0, b.x2-b.x1) bh := maxf(0, b.y2-b.y1) union := aw*ah + bw*bh - inter if union <= 0 { return 0 } return inter / union } func minf(a, b float32) float32 { if a < b { return a } return b } func maxf(a, b float32) float32 { if a > b { return a } return b }