Conversion_Kitchen_Code/kitchen_counter/scriptAudit/audit_ai.py

77 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
import pandas as pd
import os
from openai import OpenAI
import sys
def classify_lines(input_file_path, audit_ai_csv, *, chunk_size: int = 20) -> pd.DataFrame:
    """Classify screenplay lines with an OpenAI model and return a cleaned DataFrame.

    The non-blank lines of ``input_file_path`` are sent to the model in batches
    of ``chunk_size`` lines. The model's labelled output is parsed by
    ``extract_labeled_lines``, written to ``audit_ai_csv`` under the header
    ``content, script_element``, read back with pandas, and then passed through
    the ``scriptAudit.utils`` cleanup helpers in a fixed order.

    Args:
        input_file_path: Path of the text file to classify (read as UTF-8).
        audit_ai_csv: Path of the CSV file to (over)write with the raw
            classification rows. It is also the file the DataFrame is loaded
            from.
        chunk_size: Number of script lines sent per model request. Must be a
            positive integer. Defaults to 20.

    Returns:
        The DataFrame produced by the last cleanup helper (``add_fade_in_out``).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    # Validate first: a zero step would make range() raise an opaque error and
    # a negative step would silently classify nothing.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Imported at call time, as in the original code (presumably to avoid a
    # circular import with scriptAudit.utils -- TODO confirm).
    from scriptAudit.utils import (
        add_fade_in_out,
        extract_labeled_lines,
        insert_blank_lines,
        merge_consecutive_action_lines_new,
        merge_consecutive_dialogue_lines,
        remove_asterisks,
        remove_empty_content,
        remove_emptyline_rows,
        remove_leading_numbers,
        remove_numeric_only_content,
        remove_trailing_speaker,
    )

    # Read with an explicit encoding so the result does not depend on the
    # platform locale; the CSV below is written as UTF-8 as well.
    with open(input_file_path, "r", encoding="utf-8") as f:
        raw_lines = [stripped for stripped in map(str.strip, f) if stripped]

    client = OpenAI(api_key=os.getenv('openai_key'))

    # Both prompt pieces are constant, so they are built once instead of on
    # every request.
    developer_prompt = (
        "You are a screenplay auditor. For each line below, classify it using one of these labels:\n"
        "slugline, speaker, dialogue, action, parenthetical, transition, special_term, title. "
        "Return each line followed by its label in curly braces.\n\n\n"
        "**Examples:**\n"
        "INT. ROOM NIGHT {slugline}\n"
        "KITCHEN DAY {slugline}\n"
        "JOHN {speaker}\n"
        "(quietly) {parenthetical}\n"
        "JOHN (O.S.) {speaker}\n"
        "JOHN (angrily) {speaker}\n"
        "I knew youd come. {dialogue}\n"
        "She turns away from the window. {action}\n"
        "FADE OUT. {transition}\n"
        "THE END {title}\n"
        "(V.O.) {special_term}\n"
        "John CONT'D {speaker}"
    )
    user_prefix = (
        "I need you to classify the lines below. "
        "Please provide the classification in the format: 'line {label}'\n\n"
    )

    chunked_results = []
    for start in range(0, len(raw_lines), chunk_size):
        chunk = raw_lines[start:start + chunk_size]
        response = client.responses.create(
            model="gpt-4o",
            input=[
                {"role": "developer", "content": developer_prompt},
                {"role": "user", "content": user_prefix + "\n".join(chunk)},
            ],
        )
        # Each output line is expected to look like "text {label}".
        chunked_results.extend(response.output_text.splitlines())

    # Presumably yields (content, label) rows matching the header below --
    # verify against scriptAudit.utils.extract_labeled_lines.
    extracted = extract_labeled_lines(chunked_results)
    with open(audit_ai_csv, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["content", "script_element"])
        writer.writerows(extracted)
    print("Classification completed.")

    # Reload from disk so the cleanup helpers work on exactly what was saved.
    df = pd.read_csv(audit_ai_csv)

    # The helpers are applied strictly in this order.
    cleanup_steps = (
        remove_empty_content,
        remove_asterisks,
        remove_leading_numbers,
        remove_numeric_only_content,
        remove_emptyline_rows,
        remove_trailing_speaker,
        merge_consecutive_action_lines_new,
        merge_consecutive_dialogue_lines,
        insert_blank_lines,
        add_fade_in_out,
    )
    for step in cleanup_steps:
        df = step(df)
    return df