Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://github.com/huan085128/Gather-and-Bind
07 February 2025, 11:51:28 UTC
  • Code
  • Branches (1)
  • Releases (0)
  • Visits
    • Branches
    • Releases
    • HEAD
    • refs/heads/master
    No releases to show
  • 8862dd3
  • /
  • ptp_utils.py
Raw File Download Save again
Take a new snapshot of a software origin

If the archived software origin currently browsed is not synchronized with its upstream version (for instance when new commits have been issued), you can explicitly request Software Heritage to take a new snapshot of it.

Use the form below to proceed. Once a request has been submitted and accepted, it will be processed as soon as possible. You can then check its processing state by visiting this dedicated page.
swh spinner

Processing "take a new snapshot" request ...

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
  • revision
  • snapshot
origin badgecontent badge
swh:1:cnt:2a4b17f5acbe80e92365cf2589e11daa84003769
origin badgedirectory badge
swh:1:dir:8862dd3652cf95ad66334b6ab2aee89b68617f27
origin badgerevision badge
swh:1:rev:de045e02ef1014e36fb9ecf4882d6efd5d266415
origin badgesnapshot badge
swh:1:snp:d0cb33e9cd58befa97b803328695a51de205e96f

This interface enables to generate software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
  • revision
  • snapshot
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
Tip revision: de045e02ef1014e36fb9ecf4882d6efd5d266415 authored by 傅欢 on 07 February 2025, 07:49:42 UTC
update readme
Tip revision: de045e0
ptp_utils.py
import abc
import torch
import cv2
import re
import pprint
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display
from typing import Dict, Union, Tuple, List
import warnings
import spacy

from diffusers.models.cross_attention import CrossAttention
from transformers import CLIPTokenizer


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02, display_image=True,
                cmap=None, save_path=None):
    """
    Displays a list of images in a grid using Matplotlib with optional color mapping.
    """
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape[:2] + (3,), dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255

    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]
            
    if cmap is not None:
        image_ = cv2.cvtColor(image_, cv2.COLOR_RGB2GRAY)
    if display_image:
        plt.imshow(image_, cmap=cmap)
        plt.axis('off')
        plt.show()
    if save_path is not None:
        plt.imsave(save_path, image_, cmap=cmap)


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0


class AUGCrossAttnProcessor:

    def __init__(self, controller, place_in_unet):
        super().__init__()
        self.controller = controller
        self.place_in_unet = place_in_unet

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        is_cross = encoder_hidden_states is not None
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # one line change
        self.controller(attention_probs, is_cross, self.place_in_unet)

        hidden_states = torch.bmm(attention_probs, value)
        # hidden_states = xformers.ops.memory_efficient_attention(
        #     query, key, value, attn_bias=attention_mask, op=self.attention_op
        # )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


def get_indices_to_alter(stable, prompt: str) -> List[int]:
    token_idx_to_word = {idx: stable.tokenizer.decode(t)
                         for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
                         if 0 <= idx < len(stable.tokenizer(prompt)['input_ids'])}
    pprint.pprint(token_idx_to_word)
    token_indices = input("Please enter the a comma-separated list indices of the tokens you wish to "
                          "alter (e.g., 2,5): ")
    token_indices = [int(i) for i in token_indices.split(",")]
    print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
    return token_indices

def get_ranges_to_alter(stable, prompt: str) -> List[Tuple[int, int]]:
    token_idx_to_word = {idx: stable.tokenizer.decode(t)
                         for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
                         if 0 <= idx < len(stable.tokenizer(prompt)['input_ids'])}
    pprint.pprint(token_idx_to_word)

    token_ranges = input(
        "Please enter token ranges to alter as '(start1, end1), (start2, end2)': ")

    # 使用正则表达式提取括号内的数字
    pattern = r'\((\d+),\s*(\d+)\)'
    matches = re.findall(pattern, token_ranges)

    selected_ranges = []
    for match in matches:
        start, end = map(int, match)
        selected_ranges.append((start, end))

    selected_tokens = [token_idx_to_word[i] for start,
                       end in selected_ranges for i in range(start, end + 1)]
    print(f"Altering tokens: {selected_tokens}")

    return selected_ranges


def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16):
    # create heatmap from mask on image

    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image = image.resize((relevnace_res ** 2, relevnace_res ** 2))
    image = np.array(image)

    image_relevance = image_relevance.reshape(
        1, 1, image_relevance.shape[-1], image_relevance.shape[-1])
    # because float16 precision interpolation is not supported on cpu
    image_relevance = image_relevance.cuda()
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=relevnace_res ** 2, mode='bilinear')
    image_relevance = image_relevance.cpu()  # send it back to cpu
    image_relevance = (image_relevance - image_relevance.min()) / \
        (image_relevance.max() - image_relevance.min())
    image_relevance = image_relevance.reshape(
        relevnace_res ** 2, relevnace_res ** 2)
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis

def extract_noun_phrases(prompt: str, parser: spacy.language.Language) -> tuple[list, list]:
    doc = parser(prompt)
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
    nouns = []

    for w in doc:
        if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers:
                subtree.append(child)
                stack.extend(child.children)

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                subtree.append(node)
                stack.extend(node.children)
        if subtree:
            subtree.append(w)
            subtrees.append(subtree)

        if w.pos_ in ["NOUN", "PROPN"]:
            nouns.append(w)

    return nouns, subtrees


def get_indices(tokenizer: CLIPTokenizer, prompt: str) -> Dict[str, int]:
    ids = tokenizer(prompt).input_ids
    token_idx_to_word = {idx: tokenizer.decode(t)
                         for idx, t in enumerate(ids)
                         if 0 <= idx < len(tokenizer(prompt)['input_ids'])}
    return token_idx_to_word


def get_word_nouns_indices(output_sublist: tuple[list, list], token_idx_to_word: Dict[str, int]) -> list:
    words_phrase_indices = []
    words_phrase = [word for word in output_sublist]
    words_phrase = [token.text if isinstance(
        token, spacy.tokens.token.Token) else token for token in words_phrase]

    for k, v in token_idx_to_word.items():
        if v in words_phrase:
            words_phrase_indices.append(k)

    return words_phrase_indices


def get_word_phrase_indices(output_sublist: tuple[list, list], token_idx_to_word: Dict[str, int]) -> list:
    words_phrase_indices = []
    words_phrase = [word for word in output_sublist]
    words_phrase = [token.text if isinstance(
        token, spacy.tokens.token.Token) else token for token in words_phrase]

    sublist_len = len(words_phrase)
    keys = list(token_idx_to_word.keys())
    values = list(token_idx_to_word.values())
    for i in range(len(values) - sublist_len + 1):
        if values[i:i+sublist_len] == words_phrase:
            return keys[i:i+sublist_len]
    return None


def match_indices(tokenizer: CLIPTokenizer, prompt: str, paired_indices: tuple[list, list]) -> tuple[list, list]:
    token_idx_to_word = get_indices(tokenizer, prompt)
    prompt_s = prompt.split()
    if len(prompt_s) != len(token_idx_to_word) - 2:
        warnings.warn("Please change the prompt to match the tokenizer", Warning)
        words_nouns_indices = get_word_nouns_indices(paired_indices[0], token_idx_to_word)
        return words_nouns_indices, []

    words_nouns_indices = get_word_nouns_indices(paired_indices[0], token_idx_to_word)

    words_phrase_indices = []
    for sublist in paired_indices[1]:
        word_indices = get_word_phrase_indices(sublist, token_idx_to_word)
        words_phrase_indices.append(word_indices)

    return words_nouns_indices, words_phrase_indices

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API