nsfwapp/backend/nsfw_detector.go
2026-03-16 12:46:38 +01:00

406 lines
8.3 KiB
Go

// backend\nsfw_detector.go
package main
import (
"bytes"
"encoding/base64"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"path/filepath"
"sort"
"strings"
"sync"
ort "github.com/yalue/onnxruntime_go"
xdraw "golang.org/x/image/draw"
)
// Model geometry and post-processing thresholds for the bundled YOLOv8 NSFW
// detector (320n.onnx).
const (
	// nsfwInputSize is the edge length, in pixels, of the square model input.
	nsfwInputSize = 320
	// nsfwNumClasses is the number of class-score channels per anchor in
	// output0; it is expected to match len(nsfwLabels).
	nsfwNumClasses = 18
	// nsfwNumAnchors is the number of prediction columns in output0. The
	// YOLOv8 head predicts on strides 8, 16 and 32, so for a 320px input this
	// is 40*40 + 20*20 + 10*10 = 2100. Deriving it from nsfwInputSize keeps the
	// two in sync.
	nsfwNumAnchors = (nsfwInputSize/8)*(nsfwInputSize/8) +
		(nsfwInputSize/16)*(nsfwInputSize/16) +
		(nsfwInputSize/32)*(nsfwInputSize/32)
	// nsfwConfThresh is the minimum best-class score for an anchor to be kept.
	nsfwConfThresh = 0.20
	// nsfwNMSThresh is the IoU at or above which a lower-scoring box of the
	// same class is suppressed.
	nsfwNMSThresh = 0.45
)
// nsfwLabels maps a class index (the position of a class-score channel in the
// model output, as picked by parseYOLOOutput) to a human-readable label.
// The order is load-bearing: entry i names class channel i.
// NOTE(review): the order cannot be verified from this file — confirm it
// against the class list the 320n.onnx model was exported with.
// Its length is expected to equal nsfwNumClasses.
var nsfwLabels = []string{
	"female_genitalia_covered", // 0
	"face_female",              // 1
	"buttocks_exposed",         // 2
	"female_breast_exposed",    // 3
	"female_genitalia_exposed", // 4
	"male_breast_exposed",      // 5
	"anus_exposed",             // 6
	"feet_exposed",             // 7
	"belly_covered",            // 8
	"feet_covered",             // 9
	"armpits_covered",          // 10
	"armpits_exposed",          // 11
	"face_male",                // 12
	"belly_exposed",            // 13
	"male_genitalia_exposed",   // 14
	"anus_covered",             // 15
	"female_breast_covered",    // 16
	"buttocks_covered",         // 17
}
// nsfwDetector bundles the ONNX Runtime objects behind the NSFW classifier.
// All fields are read and written only while mu is held. The tensors are bound
// to session when it is created, so inference works by filling inputTensor's
// data, calling session.Run, and reading outputTensor's data in place.
type nsfwDetector struct {
	mu          sync.Mutex
	initialized bool

	// Filesystem locations resolved during initialization.
	runtimeRoot, modelPath, dllPath string

	// Pre-allocated input/output buffers wired into session.
	inputTensor, outputTensor *ort.Tensor[float32]
	session                   *ort.AdvancedSession
}
// yoloDet is one candidate box decoded from the YOLO output: the winning class
// index, its score, and the box corners (x1,y1)-(x2,y2). The coordinates are in
// whatever units the model emits (presumably letterboxed input pixels — confirm
// against the model export).
type yoloDet struct {
	classID        int
	score          float32
	x1, y1, x2, y2 float32
}
// globalNSFW is the process-wide detector instance. Its zero value is the
// "not initialized" state; initNSFWDetector fills it in and closeNSFWDetector
// tears it down, both while holding globalNSFW.mu.
var globalNSFW nsfwDetector
// initNSFWDetector loads ONNX Runtime and the 320n.onnx model from the extracted
// asset directory and stores a ready-to-run session in globalNSFW.
//
// Calling it again after a successful initialization is a no-op returning nil.
// If anything fails after the runtime environment has been initialized, every
// object created so far is destroyed again, in reverse creation order, so a
// later call starts from a clean state.
func initNSFWDetector() error {
	globalNSFW.mu.Lock()
	defer globalNSFW.mu.Unlock()

	if globalNSFW.initialized {
		return nil
	}

	root, err := ensureNSFWAssetsExtracted()
	if err != nil {
		return err
	}

	dllPath := filepath.Join(root, "onnxruntime.dll")
	if _, statErr := os.Stat(dllPath); statErr != nil {
		return fmt.Errorf("onnxruntime.dll nicht gefunden: %w", statErr)
	}
	modelPath := filepath.Join(root, "320n.onnx")
	if _, statErr := os.Stat(modelPath); statErr != nil {
		return fmt.Errorf("320n.onnx nicht gefunden: %w", statErr)
	}

	ort.SetSharedLibraryPath(dllPath)
	if envErr := ort.InitializeEnvironment(); envErr != nil {
		return fmt.Errorf("onnxruntime init fehlgeschlagen: %w", envErr)
	}

	// undo collects a teardown step for each object created below; unless the
	// detector is committed at the end, they run newest-first on return.
	var undo []func()
	committed := false
	defer func() {
		if committed {
			return
		}
		for i := len(undo) - 1; i >= 0; i-- {
			undo[i]()
		}
	}()
	undo = append(undo, func() { ort.DestroyEnvironment() })

	inputTensor, err := ort.NewTensor(
		ort.NewShape(1, 3, nsfwInputSize, nsfwInputSize),
		make([]float32, 3*nsfwInputSize*nsfwInputSize),
	)
	if err != nil {
		return fmt.Errorf("input tensor fehlgeschlagen: %w", err)
	}
	undo = append(undo, func() { inputTensor.Destroy() })

	outputTensor, err := ort.NewEmptyTensor[float32](ort.NewShape(1, 4+nsfwNumClasses, nsfwNumAnchors))
	if err != nil {
		return fmt.Errorf("output tensor fehlgeschlagen: %w", err)
	}
	undo = append(undo, func() { outputTensor.Destroy() })

	session, err := ort.NewAdvancedSession(
		modelPath,
		[]string{"images"},
		[]string{"output0"},
		[]ort.Value{inputTensor},
		[]ort.Value{outputTensor},
		nil,
	)
	if err != nil {
		return fmt.Errorf("onnx session fehlgeschlagen: %w", err)
	}

	globalNSFW.runtimeRoot = root
	globalNSFW.modelPath = modelPath
	globalNSFW.dllPath = dllPath
	globalNSFW.inputTensor = inputTensor
	globalNSFW.outputTensor = outputTensor
	globalNSFW.session = session
	globalNSFW.initialized = true
	committed = true
	return nil
}
// closeNSFWDetector releases the ONNX Runtime session, its input/output tensors
// and the runtime environment, and marks the global detector uninitialized.
// It returns nil immediately if the detector is not initialized.
//
// Teardown is best-effort. Every resource is released even if an earlier step
// fails, and the first failure is returned so callers can report it. The
// detector is marked uninitialized regardless, so handles are never reused
// after a partial teardown.
func closeNSFWDetector() error {
	globalNSFW.mu.Lock()
	defer globalNSFW.mu.Unlock()

	if !globalNSFW.initialized {
		return nil
	}

	var firstErr error
	record := func(what string, err error) {
		if err != nil && firstErr == nil {
			firstErr = fmt.Errorf("%s fehlgeschlagen: %w", what, err)
		}
	}

	// Destroy in reverse creation order: the session was created with the
	// tensors bound to it.
	if globalNSFW.session != nil {
		record("onnx session destroy", globalNSFW.session.Destroy())
		globalNSFW.session = nil
	}
	if globalNSFW.outputTensor != nil {
		record("output tensor destroy", globalNSFW.outputTensor.Destroy())
		globalNSFW.outputTensor = nil
	}
	if globalNSFW.inputTensor != nil {
		record("input tensor destroy", globalNSFW.inputTensor.Destroy())
		globalNSFW.inputTensor = nil
	}
	record("onnxruntime destroy", ort.DestroyEnvironment())

	globalNSFW.initialized = false
	return firstErr
}
// detectNSFWFromBase64 runs the NSFW model on one base64-encoded image. It
// returns one result per detected label, carrying that label's highest score
// among the boxes that passed the confidence threshold and NMS.
//
// Results are ordered by descending score, with ties broken by label, so the
// output is deterministic. It returns an error if the image cannot be decoded,
// the detector is not initialized, or inference fails.
func detectNSFWFromBase64(imageB64 string) ([]NsfwFrameResult, error) {
	// Decoding touches no shared state, so do it before taking the global lock.
	// Concurrent callers then only serialize on the model run itself.
	img, err := decodeBase64Image(imageB64)
	if err != nil {
		return nil, err
	}

	globalNSFW.mu.Lock()
	defer globalNSFW.mu.Unlock()

	if !globalNSFW.initialized || globalNSFW.session == nil {
		return nil, fmt.Errorf("nsfw detector nicht initialisiert")
	}

	// The input tensor is bound to the session, so filling its backing slice is
	// how the image is handed to Run.
	fillInputTensor(globalNSFW.inputTensor.GetData(), img)
	if err := globalNSFW.session.Run(); err != nil {
		return nil, fmt.Errorf("onnx run fehlgeschlagen: %w", err)
	}

	dets := parseYOLOOutput(globalNSFW.outputTensor.GetData(), nsfwConfThresh)
	dets = applyNMS(dets, nsfwNMSThresh)

	// Keep only the best score per label.
	bestByLabel := make(map[string]float64, len(dets))
	for _, d := range dets {
		if d.classID < 0 || d.classID >= len(nsfwLabels) {
			continue
		}
		label := nsfwLabels[d.classID]
		if score := float64(d.score); score > bestByLabel[label] {
			bestByLabel[label] = score
		}
	}

	out := make([]NsfwFrameResult, 0, len(bestByLabel))
	for label, score := range bestByLabel {
		out = append(out, NsfwFrameResult{Label: label, Score: score})
	}
	// Map iteration order is random, so break score ties by label to keep the
	// result order stable across calls.
	sort.Slice(out, func(i, j int) bool {
		if out[i].Score != out[j].Score {
			return out[i].Score > out[j].Score
		}
		return out[i].Label < out[j].Label
	})
	return out, nil
}
// decodeBase64Image decodes a base64-encoded image into an image.Image. Any
// format whose decoder is registered with package image is accepted; this file
// registers JPEG and PNG.
//
// Surrounding whitespace is ignored. An optional data-URL header such as
// "data:image/jpeg;base64," (what browsers produce for canvas/FileReader
// exports) is stripped before decoding. The payload itself must use the
// standard, padded base64 alphabet.
func decodeBase64Image(imageB64 string) (image.Image, error) {
	payload := strings.TrimSpace(imageB64)
	if strings.HasPrefix(payload, "data:") {
		// Everything up to the first comma is the media-type header.
		if _, rest, ok := strings.Cut(payload, ","); ok {
			payload = rest
		}
	}

	raw, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return nil, fmt.Errorf("base64 decode fehlgeschlagen: %w", err)
	}
	img, _, err := image.Decode(bytes.NewReader(raw))
	if err != nil {
		return nil, fmt.Errorf("bild decode fehlgeschlagen: %w", err)
	}
	return img, nil
}
// fillInputTensor letterboxes src onto an nsfwInputSize x nsfwInputSize canvas
// and writes it into dst as planar (CHW) RGB floats in [0, 1]. With
// hw = nsfwInputSize², dst[0:hw] receives R, dst[hw:2hw] G and dst[2hw:3hw] B.
// The alpha channel is dropped.
//
// dst must hold at least 3*nsfwInputSize*nsfwInputSize values; otherwise it
// panics. The letterbox scale and padding are discarded because box
// coordinates are never mapped back to the source image.
func fillInputTensor(dst []float32, src image.Image) {
	rgba, _, _, _ := letterboxToRGBA(src, nsfwInputSize, nsfwInputSize)

	const hw = nsfwInputSize * nsfwInputSize
	rPlane := dst[0*hw : 1*hw]
	gPlane := dst[1*hw : 2*hw]
	bPlane := dst[2*hw : 3*hw]

	for y := 0; y < nsfwInputSize; y++ {
		row := rgba.Pix[y*rgba.Stride : y*rgba.Stride+nsfwInputSize*4]
		base := y * nsfwInputSize
		for x := 0; x < nsfwInputSize; x++ {
			px := row[x*4 : x*4+4]
			rPlane[base+x] = float32(px[0]) / 255.0
			gPlane[base+x] = float32(px[1]) / 255.0
			bPlane[base+x] = float32(px[2]) / 255.0
		}
	}
}
// letterboxToRGBA scales src uniformly to fit a dstW x dstH canvas, preserving
// the aspect ratio, and centers it on an opaque gray (114,114,114) background.
//
// It returns the canvas, the scale factor applied to src, and the left and top
// padding in pixels. A scaled dimension that would round to zero is clamped to
// one pixel, so extremely thin images still contribute. If src or the canvas is
// empty, the canvas is returned all gray with scale 0 and zero padding; this
// avoids dividing by zero.
func letterboxToRGBA(src image.Image, dstW, dstH int) (*image.RGBA, float64, int, int) {
	dst := image.NewRGBA(image.Rect(0, 0, dstW, dstH))
	// Paint the whole canvas with the padding color.
	for i := 0; i+3 < len(dst.Pix); i += 4 {
		dst.Pix[i+0] = 114
		dst.Pix[i+1] = 114
		dst.Pix[i+2] = 114
		dst.Pix[i+3] = 255
	}

	sb := src.Bounds()
	sw, sh := sb.Dx(), sb.Dy()
	if sw <= 0 || sh <= 0 || dstW <= 0 || dstH <= 0 {
		return dst, 0, 0, 0
	}

	scale := math.Min(float64(dstW)/float64(sw), float64(dstH)/float64(sh))
	nw := int(math.Round(float64(sw) * scale))
	nh := int(math.Round(float64(sh) * scale))
	// Very thin sources can round to zero on one axis; keep at least one pixel
	// so they are not silently dropped, and never exceed the canvas.
	if nw < 1 {
		nw = 1
	}
	if nh < 1 {
		nh = 1
	}
	if nw > dstW {
		nw = dstW
	}
	if nh > dstH {
		nh = dstH
	}

	resized := image.NewRGBA(image.Rect(0, 0, nw, nh))
	xdraw.ApproxBiLinear.Scale(resized, resized.Bounds(), src, sb, xdraw.Over, nil)

	padX := (dstW - nw) / 2
	padY := (dstH - nh) / 2
	rowBytes := nw * 4
	for y := 0; y < nh; y++ {
		d := (y+padY)*dst.Stride + padX*4
		s := y * resized.Stride
		copy(dst.Pix[d:d+rowBytes], resized.Pix[s:s+rowBytes])
	}
	return dst, scale, padX, padY
}
// parseYOLOOutput decodes the flat output0 tensor, laid out channel-major as
// [1, 4+nsfwNumClasses, nsfwNumAnchors]: all anchors of channel 0, then all of
// channel 1, and so on. Channels 0-3 hold the box center x, center y, width and
// height. The remaining channels hold one score per class.
//
// For each anchor the highest positive class score is selected. Anchors with no
// positive score, or whose best score is below confThresh, are dropped.
// Surviving boxes are returned as corner coordinates. If raw does not have the
// expected length, an empty (non-nil) slice is returned.
func parseYOLOOutput(raw []float32, confThresh float32) []yoloDet {
	const stride = nsfwNumAnchors
	dets := make([]yoloDet, 0, 64)
	if len(raw) != (4+nsfwNumClasses)*stride {
		return dets
	}

	for anchor := 0; anchor < stride; anchor++ {
		cls, best := -1, float32(0)
		for c := 0; c < nsfwNumClasses; c++ {
			if s := raw[(4+c)*stride+anchor]; s > best {
				cls, best = c, s
			}
		}
		if cls == -1 || best < confThresh {
			continue
		}

		cx, cy := raw[anchor], raw[stride+anchor]
		halfW, halfH := raw[2*stride+anchor]/2, raw[3*stride+anchor]/2
		dets = append(dets, yoloDet{
			classID: cls,
			score:   best,
			x1:      cx - halfW,
			y1:      cy - halfH,
			x2:      cx + halfW,
			y2:      cy + halfH,
		})
	}
	return dets
}
// applyNMS performs greedy per-class non-maximum suppression. It sorts dets in
// place by descending score. Each surviving box then suppresses every
// lower-ranked box of the same class whose IoU with it is at least iouThresh.
// The survivors are returned in descending score order.
func applyNMS(dets []yoloDet, iouThresh float32) []yoloDet {
	if len(dets) == 0 {
		return dets
	}
	sort.Slice(dets, func(a, b int) bool { return dets[a].score > dets[b].score })

	suppressed := make([]bool, len(dets))
	kept := make([]yoloDet, 0, len(dets))
	for i, cand := range dets {
		if suppressed[i] {
			continue
		}
		kept = append(kept, cand)
		for j := i + 1; j < len(dets); j++ {
			if suppressed[j] {
				continue
			}
			if dets[j].classID == cand.classID && iou(cand, dets[j]) >= iouThresh {
				suppressed[j] = true
			}
		}
	}
	return kept
}
// iou returns the intersection-over-union of two axis-aligned boxes given by
// their corner coordinates. A box with inverted corners contributes zero area.
// The result is 0 when the union area is not positive.
func iou(a, b yoloDet) float32 {
	overlapW := maxf(0, minf(a.x2, b.x2)-maxf(a.x1, b.x1))
	overlapH := maxf(0, minf(a.y2, b.y2)-maxf(a.y1, b.y1))
	overlap := overlapW * overlapH

	areaA := maxf(0, a.x2-a.x1) * maxf(0, a.y2-a.y1)
	areaB := maxf(0, b.x2-b.x1) * maxf(0, b.y2-b.y1)
	union := areaA + areaB - overlap
	if union <= 0 {
		return 0
	}
	return overlap / union
}
// minf returns a if a < b, otherwise b. When the values are equal, or either is
// NaN, b is returned.
func minf(a, b float32) float32 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
// maxf returns a if a > b, otherwise b. When the values are equal, or either is
// NaN, b is returned.
func maxf(a, b float32) float32 {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}