157 lines
4.9 KiB
Python
157 lines
4.9 KiB
Python
import os
|
|
import cv2
|
|
import tarfile
|
|
import pydload
|
|
import logging
|
|
import numpy as np
|
|
from .video_utils import get_interest_frames_from_video
|
|
from .image_utils import load_images
|
|
from PIL import Image as pil_image
|
|
from resource_fetch import ResourceFetch
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
class Classifier:
    """
    Class for loading the model and running predictions.

    Use ``classify`` for one or more images and ``classify_video`` for a
    video file. Both return, per image/frame, a dict mapping each category
    name to the score the model produced for it.
    """

    nsfw_model = None

    def __init__(self):
        """
        model = Classifier()

        Resolves the model directory through ResourceFetch and loads the
        saved model's "predict" signature.
        """
        self.rf = ResourceFetch()
        model_path = os.path.join(
            self.rf.fetch("Intellivoid", "CoffeeHouseData-Porn"),
            "classifier_model_tf",
        )

        # NOTE: tf.contrib is only shipped with TensorFlow 1.x.
        self.nsfw_model = tf.contrib.predictor.from_saved_model(
            model_path, signature_def_key="predict"
        )

    @staticmethod
    def _check_batch_size(batch_size):
        """Raise ValueError when batch_size is smaller than 1."""
        # A batch size below 1 never shrinks the remaining input, so the
        # batching loop would not terminate; fail fast instead.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

    def _predict_batches(self, images, batch_size, categories):
        """
        Run the model over ``images`` in slices of ``batch_size``.

        inputs:
            images: sliceable sequence of inputs accepted by the model
            batch_size: number of images fed to the model per call
            categories: names for the model's output columns, by index

        returns:
            a list with one dict per image, in input order; each dict maps
            category name -> model score, inserted in ascending score order
        """
        scored = []
        for start in range(0, len(images), batch_size):
            outputs = self.nsfw_model({"images": images[start:start + batch_size]})[
                "output"
            ]
            # The output is treated as 2-D: one row of scores per image.
            rankings = np.argsort(outputs, axis=1).tolist()
            for row, ranking in zip(outputs, rankings):
                scored.append({categories[idx]: row[idx] for idx in ranking})
        return scored

    def classify_video(
        self,
        video_path,
        batch_size=4,
        image_size=(256, 256),
        categories=("unsafe", "safe"),
    ):
        """
        inputs:
            video_path: path of the video to classify
            batch_size: batch_size for running predictions
            image_size: size to which each frame needs to be resized
            categories: since the model predicts numbers, categories is the list of actual names of categories

        returns:
            {} when no frame could be loaded, otherwise a dict with
            "metadata" (fps, video_length, video_path) and "preds"
            (frame name -> {category: score})

        raises:
            ValueError: if batch_size is smaller than 1
        """
        self._check_batch_size(batch_size)

        frame_indices, frames, fps, video_length = get_interest_frames_from_video(
            video_path
        )
        logging.debug(
            f"VIDEO_PATH: {video_path}, FPS: {fps}, Important frame indices: {frame_indices}, Video length: {video_length}"
        )

        frames, frame_names = load_images(frames, image_size, image_names=frame_indices)

        if not frame_names:
            return {}

        frame_scores = self._predict_batches(frames, batch_size, categories)

        return {
            "metadata": {
                "fps": fps,
                "video_length": video_length,
                "video_path": video_path,
            },
            "preds": dict(zip(frame_names, frame_scores)),
        }

    def classify(
        self,
        image_paths=None,
        batch_size=4,
        image_size=(256, 256),
        categories=("unsafe", "safe"),
    ):
        """
        inputs:
            image_paths: list of image paths or can be a string too (for single image)
            batch_size: batch_size for running predictions
            image_size: size to which the image needs to be resized
            categories: since the model predicts numbers, categories is the list of actual names of categories

        returns:
            {} when no image could be loaded, otherwise a dict mapping each
            loaded image (its path when that is a string, else its position
            among the loaded images) to {category: score}

        raises:
            ValueError: if batch_size is smaller than 1
        """
        self._check_batch_size(batch_size)

        if image_paths is None:
            image_paths = []
        elif isinstance(image_paths, str):
            image_paths = [image_paths]

        loaded_images, loaded_image_paths = load_images(
            image_paths, image_size, image_names=image_paths
        )

        if not loaded_image_paths:
            return {}

        image_scores = self._predict_batches(loaded_images, batch_size, categories)

        images_preds = {}
        for position, (name, scores) in enumerate(
            zip(loaded_image_paths, image_scores)
        ):
            # Non-string names cannot serve as a readable key; fall back to
            # the image's position among the loaded images.
            key = name if isinstance(name, str) else position
            images_preds[key] = scores

        return images_preds
|
|
|
|
|
|
if __name__ == "__main__":
    # Interactive driver: load the model once, then classify whatever image
    # paths are typed on stdin until input ends or the user interrupts.
    m = Classifier()

    while True:
        print(
            "\n Enter single image path or multiple images seperated by || (2 pipes) \n"
        )
        try:
            raw_line = input()
        except (EOFError, KeyboardInterrupt):
            # No more input (or Ctrl-C): leave the loop instead of crashing.
            break

        images = [image.strip() for image in raw_line.split("||")]
        # Classifier has no predict() method; classify() is the image entry point.
        print(m.classify(images), "\n")
|