406 lines
8.3 KiB
Go
406 lines
8.3 KiB
Go
// backend\nsfw_detector.go
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"image"
|
|
_ "image/jpeg"
|
|
_ "image/png"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
xdraw "golang.org/x/image/draw"
|
|
)
|
|
|
|
// Model geometry and post-processing thresholds for the 320-pixel detector.
const (
	// nsfwInputSize is the edge length of the square model input in pixels.
	nsfwInputSize = 320

	// nsfwNumClasses is the number of class scores the model emits per anchor.
	nsfwNumClasses = 18

	// nsfwNumAnchors is the anchor count of the 320-pixel YOLOv8 model:
	// 40*40 + 20*20 + 10*10 = 2100.
	nsfwNumAnchors = 40*40 + 20*20 + 10*10

	// nsfwConfThresh is the minimum class score for a detection to be kept.
	nsfwConfThresh = 0.20

	// nsfwNMSThresh is the IoU at or above which a lower-scored box of the
	// same class is suppressed.
	nsfwNMSThresh = 0.45
)
|
|
|
|
// nsfwLabels translates a class index into its label name. The position of a
// label in the slice is the class ID that detectNSFWFromBase64 uses to look
// it up, so the indices are spelled out explicitly.
var nsfwLabels = []string{
	0:  "female_genitalia_covered",
	1:  "face_female",
	2:  "buttocks_exposed",
	3:  "female_breast_exposed",
	4:  "female_genitalia_exposed",
	5:  "male_breast_exposed",
	6:  "anus_exposed",
	7:  "feet_exposed",
	8:  "belly_covered",
	9:  "feet_covered",
	10: "armpits_covered",
	11: "armpits_exposed",
	12: "face_male",
	13: "belly_exposed",
	14: "male_genitalia_exposed",
	15: "anus_covered",
	16: "female_breast_covered",
	17: "buttocks_covered",
}
|
|
|
|
type nsfwDetector struct {
|
|
mu sync.Mutex
|
|
initialized bool
|
|
runtimeRoot string
|
|
modelPath string
|
|
dllPath string
|
|
inputTensor *ort.Tensor[float32]
|
|
outputTensor *ort.Tensor[float32]
|
|
session *ort.AdvancedSession
|
|
}
|
|
|
|
// yoloDet is one candidate detection taken from the model output: the
// best-scoring class for an anchor together with its box corners.
type yoloDet struct {
	classID int     // index into nsfwLabels
	score   float32 // score of classID for this anchor

	// Box corners: (x1, y1) is derived as centre minus half the size,
	// (x2, y2) as centre plus half the size.
	x1, y1, x2, y2 float32
}
|
|
|
|
// globalNSFW is the package-level detector state shared by initNSFWDetector,
// closeNSFWDetector and detectNSFWFromBase64; each of them locks
// globalNSFW.mu before touching the other fields.
var globalNSFW nsfwDetector
|
|
|
|
func initNSFWDetector() error {
|
|
globalNSFW.mu.Lock()
|
|
defer globalNSFW.mu.Unlock()
|
|
|
|
if globalNSFW.initialized {
|
|
return nil
|
|
}
|
|
|
|
root, err := ensureNSFWAssetsExtracted()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dllPath := filepath.Join(root, "onnxruntime.dll")
|
|
modelPath := filepath.Join(root, "320n.onnx")
|
|
|
|
if _, err := os.Stat(dllPath); err != nil {
|
|
return fmt.Errorf("onnxruntime.dll nicht gefunden: %w", err)
|
|
}
|
|
if _, err := os.Stat(modelPath); err != nil {
|
|
return fmt.Errorf("320n.onnx nicht gefunden: %w", err)
|
|
}
|
|
|
|
ort.SetSharedLibraryPath(dllPath)
|
|
if err := ort.InitializeEnvironment(); err != nil {
|
|
return fmt.Errorf("onnxruntime init fehlgeschlagen: %w", err)
|
|
}
|
|
|
|
inputShape := ort.NewShape(1, 3, nsfwInputSize, nsfwInputSize)
|
|
inputData := make([]float32, 1*3*nsfwInputSize*nsfwInputSize)
|
|
inputTensor, err := ort.NewTensor(inputShape, inputData)
|
|
if err != nil {
|
|
ort.DestroyEnvironment()
|
|
return fmt.Errorf("input tensor fehlgeschlagen: %w", err)
|
|
}
|
|
|
|
outputShape := ort.NewShape(1, 4+nsfwNumClasses, nsfwNumAnchors)
|
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
if err != nil {
|
|
inputTensor.Destroy()
|
|
ort.DestroyEnvironment()
|
|
return fmt.Errorf("output tensor fehlgeschlagen: %w", err)
|
|
}
|
|
|
|
session, err := ort.NewAdvancedSession(
|
|
modelPath,
|
|
[]string{"images"},
|
|
[]string{"output0"},
|
|
[]ort.Value{inputTensor},
|
|
[]ort.Value{outputTensor},
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
outputTensor.Destroy()
|
|
inputTensor.Destroy()
|
|
ort.DestroyEnvironment()
|
|
return fmt.Errorf("onnx session fehlgeschlagen: %w", err)
|
|
}
|
|
|
|
globalNSFW.runtimeRoot = root
|
|
globalNSFW.modelPath = modelPath
|
|
globalNSFW.dllPath = dllPath
|
|
globalNSFW.inputTensor = inputTensor
|
|
globalNSFW.outputTensor = outputTensor
|
|
globalNSFW.session = session
|
|
globalNSFW.initialized = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func closeNSFWDetector() error {
|
|
globalNSFW.mu.Lock()
|
|
defer globalNSFW.mu.Unlock()
|
|
|
|
if !globalNSFW.initialized {
|
|
return nil
|
|
}
|
|
|
|
if globalNSFW.session != nil {
|
|
globalNSFW.session.Destroy()
|
|
globalNSFW.session = nil
|
|
}
|
|
if globalNSFW.outputTensor != nil {
|
|
globalNSFW.outputTensor.Destroy()
|
|
globalNSFW.outputTensor = nil
|
|
}
|
|
if globalNSFW.inputTensor != nil {
|
|
globalNSFW.inputTensor.Destroy()
|
|
globalNSFW.inputTensor = nil
|
|
}
|
|
|
|
ort.DestroyEnvironment()
|
|
globalNSFW.initialized = false
|
|
|
|
return nil
|
|
}
|
|
|
|
func detectNSFWFromBase64(imageB64 string) ([]NsfwFrameResult, error) {
|
|
globalNSFW.mu.Lock()
|
|
defer globalNSFW.mu.Unlock()
|
|
|
|
if !globalNSFW.initialized || globalNSFW.session == nil {
|
|
return nil, fmt.Errorf("nsfw detector nicht initialisiert")
|
|
}
|
|
|
|
img, err := decodeBase64Image(imageB64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fillInputTensor(globalNSFW.inputTensor.GetData(), img)
|
|
|
|
if err := globalNSFW.session.Run(); err != nil {
|
|
return nil, fmt.Errorf("onnx run fehlgeschlagen: %w", err)
|
|
}
|
|
|
|
raw := globalNSFW.outputTensor.GetData()
|
|
dets := parseYOLOOutput(raw, nsfwConfThresh)
|
|
dets = applyNMS(dets, nsfwNMSThresh)
|
|
|
|
bestByLabel := map[string]float64{}
|
|
for _, d := range dets {
|
|
if d.classID < 0 || d.classID >= len(nsfwLabels) {
|
|
continue
|
|
}
|
|
label := nsfwLabels[d.classID]
|
|
score := float64(d.score)
|
|
if score > bestByLabel[label] {
|
|
bestByLabel[label] = score
|
|
}
|
|
}
|
|
|
|
out := make([]NsfwFrameResult, 0, len(bestByLabel))
|
|
for label, score := range bestByLabel {
|
|
out = append(out, NsfwFrameResult{
|
|
Label: label,
|
|
Score: score,
|
|
})
|
|
}
|
|
|
|
sort.Slice(out, func(i, j int) bool {
|
|
return out[i].Score > out[j].Score
|
|
})
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// decodeBase64Image decodes a base64-encoded image into an image.Image.
//
// Surrounding whitespace is ignored. The input may be either bare standard
// base64 or a data URL such as "data:image/png;base64,<payload>"; for a data
// URL everything up to and including the first comma is dropped before
// decoding. Which image formats are accepted depends on the decoders
// registered with the image package.
//
// It returns a wrapped error when the payload is not valid base64 or the
// decoded bytes are not a recognized image.
func decodeBase64Image(imageB64 string) (image.Image, error) {
	payload := strings.TrimSpace(imageB64)

	// The base64 alphabet contains no comma, so cutting a data URL at its
	// first comma can never remove part of the payload.
	if strings.HasPrefix(payload, "data:") {
		if comma := strings.IndexByte(payload, ','); comma >= 0 {
			payload = strings.TrimSpace(payload[comma+1:])
		}
	}

	raw, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return nil, fmt.Errorf("base64 decode fehlgeschlagen: %w", err)
	}

	img, _, err := image.Decode(bytes.NewReader(raw))
	if err != nil {
		return nil, fmt.Errorf("bild decode fehlgeschlagen: %w", err)
	}
	return img, nil
}
|
|
|
|
func fillInputTensor(dst []float32, src image.Image) {
|
|
rgba, scale, padX, padY := letterboxToRGBA(src, nsfwInputSize, nsfwInputSize)
|
|
|
|
hw := nsfwInputSize * nsfwInputSize
|
|
for y := 0; y < nsfwInputSize; y++ {
|
|
for x := 0; x < nsfwInputSize; x++ {
|
|
i := y*rgba.Stride + x*4
|
|
r := float32(rgba.Pix[i+0]) / 255.0
|
|
g := float32(rgba.Pix[i+1]) / 255.0
|
|
b := float32(rgba.Pix[i+2]) / 255.0
|
|
|
|
idx := y*nsfwInputSize + x
|
|
dst[idx] = r
|
|
dst[hw+idx] = g
|
|
dst[2*hw+idx] = b
|
|
}
|
|
}
|
|
|
|
_ = scale
|
|
_ = padX
|
|
_ = padY
|
|
}
|
|
|
|
func letterboxToRGBA(src image.Image, dstW, dstH int) (*image.RGBA, float64, int, int) {
|
|
sb := src.Bounds()
|
|
sw := sb.Dx()
|
|
sh := sb.Dy()
|
|
|
|
scale := math.Min(float64(dstW)/float64(sw), float64(dstH)/float64(sh))
|
|
nw := int(math.Round(float64(sw) * scale))
|
|
nh := int(math.Round(float64(sh) * scale))
|
|
|
|
dst := image.NewRGBA(image.Rect(0, 0, dstW, dstH))
|
|
|
|
for y := 0; y < dstH; y++ {
|
|
for x := 0; x < dstW; x++ {
|
|
i := y*dst.Stride + x*4
|
|
dst.Pix[i+0] = 114
|
|
dst.Pix[i+1] = 114
|
|
dst.Pix[i+2] = 114
|
|
dst.Pix[i+3] = 255
|
|
}
|
|
}
|
|
|
|
resized := image.NewRGBA(image.Rect(0, 0, nw, nh))
|
|
xdraw.ApproxBiLinear.Scale(resized, resized.Bounds(), src, sb, xdraw.Over, nil)
|
|
|
|
padX := (dstW - nw) / 2
|
|
padY := (dstH - nh) / 2
|
|
|
|
for y := 0; y < nh; y++ {
|
|
copy(
|
|
dst.Pix[(y+padY)*dst.Stride+padX*4:(y+padY)*dst.Stride+padX*4+nw*4],
|
|
resized.Pix[y*resized.Stride:y*resized.Stride+nw*4],
|
|
)
|
|
}
|
|
|
|
return dst, scale, padX, padY
|
|
}
|
|
|
|
func parseYOLOOutput(raw []float32, confThresh float32) []yoloDet {
|
|
// output0: [1, 22, 2100] = [batch, 4+18, anchors]
|
|
out := make([]yoloDet, 0, 64)
|
|
channels := 4 + nsfwNumClasses
|
|
if len(raw) != channels*nsfwNumAnchors {
|
|
return out
|
|
}
|
|
|
|
for a := 0; a < nsfwNumAnchors; a++ {
|
|
cx := raw[0*nsfwNumAnchors+a]
|
|
cy := raw[1*nsfwNumAnchors+a]
|
|
w := raw[2*nsfwNumAnchors+a]
|
|
h := raw[3*nsfwNumAnchors+a]
|
|
|
|
bestClass := -1
|
|
bestScore := float32(0)
|
|
|
|
for c := 0; c < nsfwNumClasses; c++ {
|
|
s := raw[(4+c)*nsfwNumAnchors+a]
|
|
if s > bestScore {
|
|
bestScore = s
|
|
bestClass = c
|
|
}
|
|
}
|
|
|
|
if bestClass < 0 || bestScore < confThresh {
|
|
continue
|
|
}
|
|
|
|
x1 := cx - w/2
|
|
y1 := cy - h/2
|
|
x2 := cx + w/2
|
|
y2 := cy + h/2
|
|
|
|
out = append(out, yoloDet{
|
|
classID: bestClass,
|
|
score: bestScore,
|
|
x1: x1,
|
|
y1: y1,
|
|
x2: x2,
|
|
y2: y2,
|
|
})
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func applyNMS(dets []yoloDet, iouThresh float32) []yoloDet {
|
|
if len(dets) == 0 {
|
|
return dets
|
|
}
|
|
|
|
sort.Slice(dets, func(i, j int) bool {
|
|
return dets[i].score > dets[j].score
|
|
})
|
|
|
|
kept := make([]yoloDet, 0, len(dets))
|
|
used := make([]bool, len(dets))
|
|
|
|
for i := 0; i < len(dets); i++ {
|
|
if used[i] {
|
|
continue
|
|
}
|
|
kept = append(kept, dets[i])
|
|
|
|
for j := i + 1; j < len(dets); j++ {
|
|
if used[j] || dets[i].classID != dets[j].classID {
|
|
continue
|
|
}
|
|
if iou(dets[i], dets[j]) >= iouThresh {
|
|
used[j] = true
|
|
}
|
|
}
|
|
}
|
|
|
|
return kept
|
|
}
|
|
|
|
func iou(a, b yoloDet) float32 {
|
|
ix1 := maxf(a.x1, b.x1)
|
|
iy1 := maxf(a.y1, b.y1)
|
|
ix2 := minf(a.x2, b.x2)
|
|
iy2 := minf(a.y2, b.y2)
|
|
|
|
iw := maxf(0, ix2-ix1)
|
|
ih := maxf(0, iy2-iy1)
|
|
inter := iw * ih
|
|
|
|
aw := maxf(0, a.x2-a.x1)
|
|
ah := maxf(0, a.y2-a.y1)
|
|
bw := maxf(0, b.x2-b.x1)
|
|
bh := maxf(0, b.y2-b.y1)
|
|
|
|
union := aw*ah + bw*bh - inter
|
|
if union <= 0 {
|
|
return 0
|
|
}
|
|
return inter / union
|
|
}
|
|
|
|
// minf returns a when a is strictly less than b and b otherwise.
func minf(a, b float32) float32 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
|
|
|
|
// maxf returns a when a is strictly greater than b and b otherwise.
func maxf(a, b float32) float32 {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
|