You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					255 lines
				
				8.5 KiB
			
		
		
			
		
	
	
					255 lines
				
				8.5 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import textwrap as tr
							 | 
						||
| 
								 | 
							
								from typing import List, Optional
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import matplotlib.pyplot as plt
							 | 
						||
| 
								 | 
							
								import plotly.express as px
							 | 
						||
| 
								 | 
							
								from scipy import spatial
							 | 
						||
| 
								 | 
							
								from sklearn.decomposition import PCA
							 | 
						||
| 
								 | 
							
								from sklearn.manifold import TSNE
							 | 
						||
| 
								 | 
							
								from sklearn.metrics import average_precision_score, precision_recall_curve
							 | 
						||
| 
								 | 
							
								from tenacity import retry, stop_after_attempt, wait_random_exponential
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import openai
							 | 
						||
| 
								 | 
							
								from openai.datalib import numpy as np
							 | 
						||
| 
								 | 
							
								from openai.datalib import pandas as pd
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
								def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]:
								    """Return the embedding of a single piece of text.

								    The call is wrapped in tenacity's ``retry`` with
								    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

								    Args:
								        text: Text to embed; newlines are replaced by spaces before sending.
								        engine: Engine name forwarded to ``openai.Embedding.create``.

								    Returns:
								        The ``"embedding"`` entry of the first item in the response's ``"data"``.
								    """
								    # Newlines are flattened to spaces; the original author notes that they
								    # can negatively affect performance.
								    flattened = text.replace("\n", " ")
								    response = openai.Embedding.create(input=[flattened], engine=engine)
								    return response["data"][0]["embedding"]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
								async def aget_embedding(
								    text: str, engine="text-similarity-davinci-001"
								) -> List[float]:
								    """Asynchronously return the embedding of a single piece of text.

								    The coroutine is wrapped in tenacity's ``retry`` with
								    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

								    Args:
								        text: Text to embed; newlines are replaced by spaces before sending.
								        engine: Engine name forwarded to ``openai.Embedding.acreate``.

								    Returns:
								        The ``"embedding"`` entry of the first item in the response's ``"data"``.
								    """
								    # Newlines are flattened to spaces; the original author notes that they
								    # can negatively affect performance.
								    flattened = text.replace("\n", " ")
								    response = await openai.Embedding.acreate(input=[flattened], engine=engine)
								    return response["data"][0]["embedding"]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
								def get_embeddings(
								    list_of_text: List[str], engine="text-similarity-babbage-001"
								) -> List[List[float]]:
								    """Return the embeddings of a batch of texts, in input order.

								    The call is wrapped in tenacity's ``retry`` with
								    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

								    Args:
								        list_of_text: Texts to embed; at most 2048 items. Newlines in each
								            text are replaced by spaces before sending.
								        engine: Engine name forwarded to ``openai.Embedding.create``.

								    Returns:
								        One embedding per input text, ordered by the ``"index"`` field of the
								        response items.
								    """
								    # NOTE(review): this check runs inside the retry wrapper, so an oversized
								    # batch is presumably retried like any other failure - consider
								    # validating before the retried call.
								    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

								    # Newlines are flattened to spaces; the original author notes that they
								    # can negatively affect performance.
								    flattened = [item.replace("\n", " ") for item in list_of_text]

								    response = openai.Embedding.create(input=flattened, engine=engine)
								    # Sort by the reported index so the output order matches the input order.
								    ordered = sorted(response.data, key=lambda row: row["index"])
								    return [row["embedding"] for row in ordered]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
								async def aget_embeddings(
								    list_of_text: List[str], engine="text-similarity-babbage-001"
								) -> List[List[float]]:
								    """Asynchronously return the embeddings of a batch of texts, in input order.

								    The coroutine is wrapped in tenacity's ``retry`` with
								    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

								    Args:
								        list_of_text: Texts to embed; at most 2048 items. Newlines in each
								            text are replaced by spaces before sending.
								        engine: Engine name forwarded to ``openai.Embedding.acreate``.

								    Returns:
								        One embedding per input text, ordered by the ``"index"`` field of the
								        response items.
								    """
								    # NOTE(review): this check runs inside the retry wrapper, so an oversized
								    # batch is presumably retried like any other failure - consider
								    # validating before the retried call.
								    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

								    # Newlines are flattened to spaces; the original author notes that they
								    # can negatively affect performance.
								    flattened = [item.replace("\n", " ") for item in list_of_text]

								    response = await openai.Embedding.acreate(input=flattened, engine=engine)
								    # Sort by the reported index so the output order matches the input order.
								    ordered = sorted(response.data, key=lambda row: row["index"])
								    return [row["embedding"] for row in ordered]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def cosine_similarity(a, b):
								    """Return the cosine similarity between ``a`` and ``b``.

								    Computed as ``np.dot(a, b)`` divided by the product of the two
								    ``np.linalg.norm`` values.
								    """
								    dot_product = np.dot(a, b)
								    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
								    return dot_product / norm_product
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def plot_multiclass_precision_recall(
								    y_score, y_true_untransformed, class_list, classifier_name
								):
								    """
								    Plot precision-recall curves for a multiclass problem.

								    Draws the micro-averaged precision-recall curve, one precision-recall
								    curve per class, and reference iso-f1 contours on the current matplotlib
								    figure, and prints the micro-averaged average precision score.

								    Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
								    """
								    n_classes = len(class_list)
								    # One boolean column per class: True where the true label equals that class.
								    indicator_columns = [
								        (y_true_untransformed == class_list[i]) for i in range(n_classes)
								    ]
								    y_true = pd.concat(indicator_columns, axis=1).values

								    # Per-class curves and average precision, keyed by class position.
								    precision = {}
								    recall = {}
								    average_precision = {}
								    for i in range(n_classes):
								        truth_column = y_true[:, i]
								        score_column = y_score[:, i]
								        precision[i], recall[i], _ = precision_recall_curve(
								            truth_column, score_column
								        )
								        average_precision[i] = average_precision_score(truth_column, score_column)

								    # A "micro-average": quantifying score on all classes jointly
								    precision_micro, recall_micro, _ = precision_recall_curve(
								        y_true.ravel(), y_score.ravel()
								    )
								    average_precision_micro = average_precision_score(y_true, y_score, average="micro")
								    name_prefix = str(classifier_name)
								    summary = " - Average precision score over all classes: {0:0.2f}".format(
								        average_precision_micro
								    )
								    print(name_prefix + summary)

								    # Figure setup and the grey iso-f1 reference contours.
								    plt.figure(figsize=(9, 10))
								    x = np.linspace(0.01, 1)
								    for f_score in np.linspace(0.2, 0.8, num=4):
								        y = f_score * x / (2 * x - f_score)
								        visible = y >= 0
								        (iso_curve,) = plt.plot(x[visible], y[visible], color="gray", alpha=0.2)
								        plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

								    # Only the last iso-f1 contour gets a legend entry, standing in for all.
								    lines = [iso_curve]
								    labels = ["iso-f1 curves"]

								    (micro_curve,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
								    lines.append(micro_curve)
								    labels.append(
								        "average Precision-recall (auprc = {0:0.2f})".format(average_precision_micro)
								    )

								    for i in range(n_classes):
								        (class_curve,) = plt.plot(recall[i], precision[i], lw=2)
								        lines.append(class_curve)
								        labels.append(
								            "Precision-recall for class `{0}` (auprc = {1:0.2f})".format(
								                class_list[i], average_precision[i]
								            )
								        )

								    plt.gcf().subplots_adjust(bottom=0.25)
								    plt.xlim([0.0, 1.0])
								    plt.ylim([0.0, 1.05])
								    plt.xlabel("Recall")
								    plt.ylabel("Precision")
								    plt.title(f"{classifier_name}: Precision-Recall curve for each class")
								    plt.legend(lines, labels)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def distances_from_embeddings(
								    query_embedding: List[float],
								    embeddings: List[List[float]],
								    distance_metric="cosine",
								) -> List[float]:
								    """Return the distances between a query embedding and a list of embeddings.

								    Args:
								        query_embedding: The reference vector.
								        embeddings: Vectors to compare against ``query_embedding``.
								        distance_metric: One of ``"cosine"``, ``"L1"``, ``"L2"`` or ``"Linf"``,
								            mapped to the matching ``scipy.spatial.distance`` function.

								    Returns:
								        One distance per entry of ``embeddings``, in the same order.

								    Raises:
								        KeyError: If ``distance_metric`` is not one of the supported names.
								    """
								    distance_metrics = {
								        "cosine": spatial.distance.cosine,
								        "L1": spatial.distance.cityblock,
								        "L2": spatial.distance.euclidean,
								        "Linf": spatial.distance.chebyshev,
								    }
								    # Resolve the metric once, up front: an unknown name fails immediately
								    # (even for an empty ``embeddings``) instead of on the first element.
								    metric = distance_metrics[distance_metric]
								    return [metric(query_embedding, embedding) for embedding in embeddings]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
								    """Return the indices that sort ``distances`` in ascending order.

								    The result comes from ``np.argsort``, so the first index refers to the
								    smallest distance.
								    """
								    nearest_first = np.argsort(distances)
								    return nearest_first
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def pca_components_from_embeddings(
								    embeddings: List[List[float]], n_components=2
								) -> np.ndarray:
								    """Return the PCA components of a list of embeddings.

								    Args:
								        embeddings: The embedding vectors; converted with ``np.array`` before fitting.
								        n_components: Forwarded to ``PCA(n_components=...)``.

								    Returns:
								        The result of ``PCA.fit_transform`` on the embeddings.
								    """
								    reducer = PCA(n_components=n_components)
								    return reducer.fit_transform(np.array(embeddings))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def tsne_components_from_embeddings(
								    embeddings: List[List[float]], n_components=2, **kwargs
								) -> np.ndarray:
								    """Return the t-SNE components of a list of embeddings.

								    Args:
								        embeddings: The embedding vectors; converted with ``np.array`` before fitting.
								        n_components: Forwarded to ``TSNE(n_components=...)``.
								        **kwargs: Extra keyword arguments forwarded to ``TSNE``. ``init``
								            defaults to ``"pca"`` and ``learning_rate`` to ``"auto"`` when the
								            caller does not supply them.

								    Returns:
								        The result of ``TSNE.fit_transform`` on the embeddings.
								    """
								    # Fill in the preferred defaults only where the caller made no choice.
								    kwargs.setdefault("init", "pca")
								    kwargs.setdefault("learning_rate", "auto")
								    reducer = TSNE(n_components=n_components, **kwargs)
								    return reducer.fit_transform(np.array(embeddings))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def chart_from_components(
								    components: np.ndarray,
								    labels: Optional[List[str]] = None,
								    strings: Optional[List[str]] = None,
								    x_title: str = "Component 0",
								    y_title: str = "Component 1",
								    mark_size: int = 5,
								    **kwargs,
								):
								    """Return an interactive 2D chart of embedding components.

								    Args:
								        components: 2D array; columns 0 and 1 are plotted as x and y.
								        labels: Optional per-point labels, used for both color and symbol.
								        strings: Optional per-point texts shown on hover, wrapped at 30
								            characters with ``<br>`` line breaks.
								        x_title: Column name / axis title for component 0.
								        y_title: Column name / axis title for component 1.
								        mark_size: Marker size applied to every trace.
								        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.

								    Returns:
								        The figure produced by ``px.scatter(...).update_traces(...)``.
								    """
								    # Explicit None/length checks instead of bare truthiness: numpy arrays
								    # and pandas Series can raise on ``if labels`` because their truth value
								    # is ambiguous. Empty inputs are still treated as "not provided".
								    has_labels = labels is not None and len(labels) > 0
								    has_strings = strings is not None and len(strings) > 0

								    empty_list = ["" for _ in components]
								    if has_strings:
								        hover_text = ["<br>".join(tr.wrap(string, width=30)) for string in strings]
								    else:
								        hover_text = empty_list

								    data = pd.DataFrame(
								        {
								            x_title: components[:, 0],
								            y_title: components[:, 1],
								            # list() keeps the column positional even if labels carries its
								            # own index (e.g. a pandas Series).
								            "label": list(labels) if has_labels else empty_list,
								            "string": hover_text,
								        }
								    )
								    chart = px.scatter(
								        data,
								        x=x_title,
								        y=y_title,
								        color="label" if has_labels else None,
								        symbol="label" if has_labels else None,
								        hover_data=["string"] if has_strings else None,
								        **kwargs,
								    ).update_traces(marker=dict(size=mark_size))
								    return chart
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def chart_from_components_3D(
								    components: np.ndarray,
								    labels: Optional[List[str]] = None,
								    strings: Optional[List[str]] = None,
								    x_title: str = "Component 0",
								    y_title: str = "Component 1",
								    z_title: str = "Component 2",
								    mark_size: int = 5,
								    **kwargs,
								):
								    """Return an interactive 3D chart of embedding components.

								    Args:
								        components: 2D array; columns 0, 1 and 2 are plotted as x, y and z.
								        labels: Optional per-point labels, used for both color and symbol.
								        strings: Optional per-point texts shown on hover, wrapped at 30
								            characters with ``<br>`` line breaks.
								        x_title: Column name / axis title for component 0.
								        y_title: Column name / axis title for component 1.
								        z_title: Column name / axis title for component 2.
								        mark_size: Marker size applied to every trace.
								        **kwargs: Extra keyword arguments forwarded to ``px.scatter_3d``.

								    Returns:
								        The figure produced by ``px.scatter_3d(...).update_traces(...)``.
								    """
								    # Explicit None/length checks instead of bare truthiness: numpy arrays
								    # and pandas Series can raise on ``if labels`` because their truth value
								    # is ambiguous. Empty inputs are still treated as "not provided".
								    has_labels = labels is not None and len(labels) > 0
								    has_strings = strings is not None and len(strings) > 0

								    empty_list = ["" for _ in components]
								    if has_strings:
								        hover_text = ["<br>".join(tr.wrap(string, width=30)) for string in strings]
								    else:
								        hover_text = empty_list

								    data = pd.DataFrame(
								        {
								            x_title: components[:, 0],
								            y_title: components[:, 1],
								            z_title: components[:, 2],
								            # list() keeps the column positional even if labels carries its
								            # own index (e.g. a pandas Series).
								            "label": list(labels) if has_labels else empty_list,
								            "string": hover_text,
								        }
								    )
								    chart = px.scatter_3d(
								        data,
								        x=x_title,
								        y=y_title,
								        z=z_title,
								        color="label" if has_labels else None,
								        symbol="label" if has_labels else None,
								        hover_data=["string"] if has_strings else None,
								        **kwargs,
								    ).update_traces(marker=dict(size=mark_size))
								    return chart
							 |