* ONNX runtime
* use llm-guard 0.3.1
* google analytics tracking
* linter to fix code
- .pre-commit-config.yaml +38 -0
- Dockerfile +1 -1
- app.py +28 -19
- output.py +43 -27
- prompt.py +25 -45
- requirements.txt +4 -5
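Condensed, the change to the scanning path looks like the sketch below: scanners are now built with ONNX inference enabled, and each scanner run is timed individually instead of timing the whole pass. This is only an illustration distilled from the diffs that follow; the Toxicity scanner, the use_onnx flag, and the result keys come from this commit, while the sample prompt/output strings and the 0.5 threshold are made up.

    import time
    from datetime import timedelta

    from llm_guard.output_scanners import Toxicity

    prompt = "Summarize our refund policy."                 # illustrative input
    output = "Sure, refunds are issued within 30 days."     # illustrative LLM output

    # ONNX backend, as enabled throughout this commit; threshold value is illustrative
    scanner = Toxicity(threshold=0.5, use_onnx=True)

    start = time.monotonic()
    sanitized, is_valid, risk_score = scanner.scan(prompt, output)
    took = timedelta(seconds=time.monotonic() - start)

    # one record per scanner, in the shape the new scan() functions return
    results = [
        {
            "scanner": "Toxicity",
            "is_valid": is_valid,
            "risk_score": risk_score,
            "took_sec": round(took.total_seconds(), 2),
        }
    ]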
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,38 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: end-of-file-fixer
        types: [ python ]
      - id: requirements-txt-fixer

  - repo: https://github.com/psf/black
    rev: 23.7.0
    hooks:
      - id: black
        args: [ --line-length=100, --exclude="" ]

  # this is not technically always safe but usually is
  # use comments `# isort: off` and `# isort: on` to disable/re-enable isort
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: [ --line-length=100, --profile=black ]

  # this is slightly dangerous because python imports have side effects
  # and this tool removes unused imports, which may be providing
  # necessary side effects for the code to run
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.2.0
    hooks:
      - id: autoflake
        args:
          - "--in-place"
          - "--expand-star-imports"
          - "--remove-duplicate-keys"
          - "--remove-unused-variables"
          - "--remove-all-unused-imports"
Dockerfile
CHANGED
@@ -1,4 +1,4 @@
-FROM python:3.
+FROM python:3.11-slim

 RUN apt-get update && apt-get install -y \
     build-essential \
app.py
CHANGED
@@ -1,16 +1,33 @@
 import logging
-import time
 import traceback
-from datetime import timedelta

 import pandas as pd
 import streamlit as st
+from llm_guard.vault import Vault
+from streamlit.components.v1 import html
+
 from output import init_settings as init_output_settings
 from output import scan as scan_output
 from prompt import init_settings as init_prompt_settings
 from prompt import scan as scan_prompt

-
+
+def add_google_analytics(ga4_id):
+    """
+    Add Google Analytics 4 to a Streamlit app
+    """
+    ga_code = f"""
+    <script async src="https://www.googletagmanager.com/gtag/js?id={ga4_id}"></script>
+    <script>
+        window.dataLayer = window.dataLayer || [];
+        function gtag(){{dataLayer.push(arguments);}}
+        gtag('js', new Date());
+        gtag('config', '{ga4_id}');
+    </script>
+    """
+
+    html(ga_code)
+

 PROMPT = "prompt"
 OUTPUT = "output"
@@ -48,6 +65,8 @@ if scanner_type == PROMPT:
 elif scanner_type == OUTPUT:
     enabled_scanners, settings = init_output_settings()

+add_google_analytics("G-0HBVNHEZBW")
+
 # Main pannel
 with st.expander("About", expanded=False):
     st.info(
@@ -93,32 +112,24 @@ elif scanner_type == OUTPUT:
 st_result_text = None
 st_analysis = None
 st_is_valid = None
-st_time_delta = None

 try:
     with st.form("text_form", clear_on_submit=False):
         submitted = st.form_submit_button("Process")
         if submitted:
-
-            results_score = {}
+            results = {}

-            start_time = time.monotonic()
             if scanner_type == PROMPT:
-                st_result_text,
+                st_result_text, results = scan_prompt(
                     vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
                 )
             elif scanner_type == OUTPUT:
-                st_result_text,
+                st_result_text, results = scan_output(
                     vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
                 )
-            end_time = time.monotonic()
-            st_time_delta = timedelta(seconds=end_time - start_time)

-            st_is_valid = all(
-            st_analysis =
-                {"scanner": k, "is valid": results_valid[k], "risk score": results_score[k]}
-                for k in results_valid
-            ]
+            st_is_valid = all(item["is_valid"] for item in results)
+            st_analysis = results

 except Exception as e:
     logger.error(e)
@@ -127,9 +138,7 @@ except Exception as e:

 # After:
 if st_is_valid is not None:
-    st.subheader(
-        f"Results - {'valid' if st_is_valid else 'invalid'} ({round(st_time_delta.total_seconds())} seconds)"
-    )
+    st.subheader(f"Results - {'valid' if st_is_valid else 'invalid'}")

 col1, col2 = st.columns(2)

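The results list that scan_prompt() / scan_output() now return is a list of flat dicts, so it drops straight into a table. The hunks above stop at st_analysis = results; the rendering code itself is not part of this diff, so the following is only a sketch of one plausible way to display it, assuming the pandas import above is used for exactly this purpose. The example values are illustrative.

    import pandas as pd
    import streamlit as st

    # Example of the per-scanner records produced by the new scan() functions (values made up)
    st_analysis = [
        {"scanner": "PromptInjection", "is_valid": True, "risk_score": 0.0, "took_sec": 1.3},
        {"scanner": "Toxicity", "is_valid": True, "risk_score": 0.05, "took_sec": 0.4},
    ]

    # one row per scanner, including the took_sec timing added in this commit
    st.dataframe(pd.DataFrame(st_analysis))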
output.py
CHANGED
@@ -1,9 +1,9 @@
 import logging
+import time
+from datetime import timedelta
 from typing import Dict, List

 import streamlit as st
-from streamlit_tags import st_tags
-
 from llm_guard.input_scanners.anonymize import default_entity_types
 from llm_guard.output_scanners import (
     JSON,
@@ -12,11 +12,11 @@ from llm_guard.output_scanners import (
     Bias,
     Code,
     Deanonymize,
+    FactualConsistency,
     Language,
     LanguageSame,
     MaliciousURLs,
     NoRefusal,
-    Refutation,
     Regex,
     Relevance,
     Sensitive,
@@ -25,6 +25,7 @@ from llm_guard.output_scanners.relevance import all_models as relevance_models
 from llm_guard.output_scanners.sentiment import Sentiment
 from llm_guard.output_scanners.toxicity import Toxicity
 from llm_guard.vault import Vault
+from streamlit_tags import st_tags

 logger = logging.getLogger("llm-guard-playground")

@@ -41,7 +42,7 @@ def init_settings() -> (List, Dict):
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
-        "
+        "FactualConsistency",
        "Regex",
        "Relevance",
        "Sensitive",
@@ -163,7 +164,12 @@ def init_settings() -> (List, Dict):
            help="The minimum number of JSON elements that should be present",
        )

-
+        st_json_repair = st.checkbox("Repair", value=False, help="Attempt to repair the JSON")
+
+        settings["JSON"] = {
+            "required_elements": st_json_required_elements,
+            "repair": st_json_repair,
+        }

    if "Language" in st_enabled_scanners:
        st_lan_expander = st.sidebar.expander(
@@ -274,23 +280,23 @@ def init_settings() -> (List, Dict):

        settings["NoRefusal"] = {"threshold": st_no_ref_threshold}

-    if "
-
-            "
+    if "FactualConsistency" in st_enabled_scanners:
+        st_fc_expander = st.sidebar.expander(
+            "FactualConsistency",
            expanded=False,
        )

-        with
-
-                label="
+        with st_fc_expander:
+            st_fc_minimum_score = st.slider(
+                label="Minimum score",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
-                key="
+                key="fc_threshold",
            )

-        settings["
+        settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}

    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
@@ -359,7 +365,7 @@ def init_settings() -> (List, Dict):
            key="sensitive_entity_types",
        )
        st.caption(
-            "Check all supported entities: https://
+            "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
        )
        st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
        st_sens_threshold = st.slider(
@@ -434,13 +440,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
        return BanTopics(topics=settings["topics"], threshold=settings["threshold"])

    if scanner_name == "Bias":
-        return Bias(threshold=settings["threshold"])
+        return Bias(threshold=settings["threshold"], use_onnx=True)

    if scanner_name == "Deanonymize":
        return Deanonymize(vault=vault)

    if scanner_name == "JSON":
-        return JSON(required_elements=settings["required_elements"])
+        return JSON(required_elements=settings["required_elements"], repair=settings["repair"])

    if scanner_name == "Language":
        return Language(valid_languages=settings["valid_languages"])
@@ -458,16 +464,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
        elif mode == "denied":
            denied_languages = settings["languages"]

-        return Code(allowed=allowed_languages, denied=denied_languages)
+        return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)

    if scanner_name == "MaliciousURLs":
-        return MaliciousURLs(threshold=settings["threshold"])
+        return MaliciousURLs(threshold=settings["threshold"], use_onnx=True)

    if scanner_name == "NoRefusal":
        return NoRefusal(threshold=settings["threshold"])

-    if scanner_name == "
-        return
+    if scanner_name == "FactualConsistency":
+        return FactualConsistency(minimum_score=settings["minimum_score"])

    if scanner_name == "Regex":
        match_type = settings["type"]
@@ -491,13 +497,14 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
            entity_types=settings["entity_types"],
            redact=settings["redact"],
            threshold=settings["threshold"],
+            use_onnx=True,
        )

    if scanner_name == "Sentiment":
        return Sentiment(threshold=settings["threshold"])

    if scanner_name == "Toxicity":
-        return Toxicity(threshold=settings["threshold"])
+        return Toxicity(threshold=settings["threshold"], use_onnx=True)

    raise ValueError("Unknown scanner name")

@@ -509,10 +516,9 @@ def scan(
    prompt: str,
    text: str,
    fail_fast: bool = False,
-) -> (str,
+) -> (str, List[Dict[str, any]]):
    sanitized_output = text
-
-    results_score = {}
+    results = []

    status_text = "Scanning prompt..."
    if fail_fast:
@@ -524,13 +530,23 @@ def scan(
            scanner = get_scanner(
                scanner_name, vault, settings[scanner_name] if scanner_name in settings else {}
            )
+
+            start_time = time.monotonic()
            sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
-
-
+            end_time = time.monotonic()
+
+            results.append(
+                {
+                    "scanner": scanner_name,
+                    "is_valid": is_valid,
+                    "risk_score": risk_score,
+                    "took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
+                }
+            )

            if fail_fast and not is_valid:
                break

    status.update(label="Scanning complete", state="complete", expanded=False)

-    return sanitized_output,
+    return sanitized_output, results
prompt.py
CHANGED
@@ -1,9 +1,9 @@
 import logging
+import time
+from datetime import timedelta
 from typing import Dict, List

 import streamlit as st
-from streamlit_tags import st_tags
-
 from llm_guard.input_scanners import (
     Anonymize,
     BanSubstrings,
@@ -11,7 +11,6 @@ from llm_guard.input_scanners import (
     Code,
     Language,
     PromptInjection,
-    PromptInjectionV2,
     Regex,
     Secrets,
     Sentiment,
@@ -19,8 +18,9 @@ from llm_guard.input_scanners import (
     Toxicity,
 )
 from llm_guard.input_scanners.anonymize import default_entity_types
-from llm_guard.input_scanners.
+from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
 from llm_guard.vault import Vault
+from streamlit_tags import st_tags

 logger = logging.getLogger("llm-guard-playground")

@@ -33,7 +33,6 @@ def init_settings() -> (List, Dict):
        "Code",
        "Language",
        "PromptInjection",
-        "PromptInjectionV2",
        "Regex",
        "Secrets",
        "Sentiment",
@@ -67,7 +66,7 @@ def init_settings() -> (List, Dict):
            key="anon_entity_types",
        )
        st.caption(
-            "Check all supported entities: https://
+            "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
        )
        st_anon_hidden_names = st_tags(
            label="Hidden names to be anonymized",
@@ -101,11 +100,6 @@ def init_settings() -> (List, Dict):
            step=0.1,
            key="anon_threshold",
        )
-        st_anon_recognizer = st.selectbox(
-            "Recognizer",
-            [RECOGNIZER_SPACY_EN_PII_DISTILBERT, RECOGNIZER_SPACY_EN_PII_FAST],
-            index=1,
-        )

        settings["Anonymize"] = {
            "entity_types": st_anon_entity_types,
@@ -114,7 +108,6 @@ def init_settings() -> (List, Dict):
            "preamble": st_anon_preamble,
            "use_faker": st_anon_use_faker,
            "threshold": st_anon_threshold,
-            "recognizer": st_anon_recognizer,
        }

    if "BanSubstrings" in st_enabled_scanners:
@@ -286,26 +279,6 @@ def init_settings() -> (List, Dict):
            "threshold": st_pi_threshold,
        }

-    if "PromptInjectionV2" in st_enabled_scanners:
-        st_piv2_expander = st.sidebar.expander(
-            "Prompt Injection V2",
-            expanded=False,
-        )
-
-        with st_piv2_expander:
-            st_piv2_threshold = st.slider(
-                label="Threshold",
-                value=0.5,
-                min_value=0.0,
-                max_value=1.0,
-                step=0.05,
-                key="prompt_injection_v2_threshold",
-            )
-
-        settings["PromptInjectionV2"] = {
-            "threshold": st_piv2_threshold,
-        }
-
    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
            "Regex",
@@ -427,7 +400,7 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
            preamble=settings["preamble"],
            use_faker=settings["use_faker"],
            threshold=settings["threshold"],
-
+            use_onnx=True,
        )

    if scanner_name == "BanSubstrings":
@@ -452,16 +425,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
        elif mode == "denied":
            denied_languages = settings["languages"]

-        return Code(allowed=allowed_languages, denied=denied_languages)
+        return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)

    if scanner_name == "Language":
        return Language(valid_languages=settings["valid_languages"])

    if scanner_name == "PromptInjection":
-        return PromptInjection(threshold=settings["threshold"])
-
-    if scanner_name == "PromptInjectionV2":
-        return PromptInjectionV2(threshold=settings["threshold"])
+        return PromptInjection(threshold=settings["threshold"], models=PI_ALL_MODELS, use_onnx=True)

    if scanner_name == "Regex":
        match_type = settings["type"]
@@ -487,17 +457,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
        return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])

    if scanner_name == "Toxicity":
-        return Toxicity(threshold=settings["threshold"])
+        return Toxicity(threshold=settings["threshold"], use_onnx=True)

    raise ValueError("Unknown scanner name")


def scan(
    vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
-) -> (str,
+) -> (str, List[Dict[str, any]]):
    sanitized_prompt = text
-
-    results_score = {}
+    results = []

    status_text = "Scanning prompt..."
    if fail_fast:
@@ -507,12 +476,23 @@ def scan(
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            scanner = get_scanner(scanner_name, vault, settings[scanner_name])
+
+            start_time = time.monotonic()
            sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
-
-
+            end_time = time.monotonic()
+
+            results.append(
+                {
+                    "scanner": scanner_name,
+                    "is_valid": is_valid,
+                    "risk_score": risk_score,
+                    "took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
+                }
+            )

            if fail_fast and not is_valid:
                break
+
    status.update(label="Scanning complete", state="complete", expanded=False)

-    return sanitized_prompt,
+    return sanitized_prompt, results
requirements.txt
CHANGED
@@ -1,6 +1,5 @@
-
-llm-guard==0.3.
-pandas==2.1.
-streamlit==1.
+llm-guard==0.3.1
+llm-guard[onnxruntime]==0.3.1
+pandas==2.1.2
+streamlit==1.28.1
 streamlit-tags==1.2.8
-https://huggingface.co/beki/en_spacy_pii_fast/resolve/main/en_spacy_pii_fast-any-py3-none-any.whl
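The llm-guard[onnxruntime] pin is what supplies the ONNX backend that the use_onnx=True flags above rely on. A quick sanity check that the backend is importable (a sketch; onnxruntime is the package that extra is expected to install):

    import onnxruntime

    # Lists the execution providers available to ONNX Runtime, e.g. ['CPUExecutionProvider']
    print(onnxruntime.get_available_providers())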