Skip to content

rigging.generator

Generators produce completions for a given set of messages or text.

StopReason = t.Literal['stop', 'length', 'content_filter', 'tool_calls', 'unknown'] module-attribute #

Reporting reason for generation completing.

GenerateParams #

Bases: BaseModel

Parameters for generating text using a language model.

These are designed to generally overlap with underlying APIs like litellm, but will be extended as needed.

Note

Use the extra field to pass additional parameters to the API.

api_base: str | None = None class-attribute instance-attribute #

The base URL for the API.

extra: dict[str, t.Any] = Field(default_factory=dict) class-attribute instance-attribute #

Extra parameters to be passed to the API.

frequency_penalty: float | None = None class-attribute instance-attribute #

The frequency penalty.

max_tokens: int | None = None class-attribute instance-attribute #

The maximum number of tokens to generate.

parallel_tool_calls: bool | None = None class-attribute instance-attribute #

Whether to allow tool calls to run in parallel.

presence_penalty: float | None = None class-attribute instance-attribute #

The presence penalty.

seed: int | None = None class-attribute instance-attribute #

The random seed.

stop: list[str] | None = None class-attribute instance-attribute #

A list of stop sequences to stop generation at.

temperature: float | None = None class-attribute instance-attribute #

The sampling temperature.

timeout: int | None = None class-attribute instance-attribute #

The timeout for the API request.

tool_choice: ToolChoice | None = None class-attribute instance-attribute #

The tool choice to be used in the generation.

tools: list[ToolDefinition] | None = None class-attribute instance-attribute #

The tools to be used in the generation.

top_k: int | None = None class-attribute instance-attribute #

The top-k sampling parameter.

top_p: float | None = None class-attribute instance-attribute #

The nucleus sampling probability.

merge_with(*others: t.Optional[GenerateParams]) -> GenerateParams #

Apply a series of parameter overrides to the current instance and return a copy.

Parameters:

  • *others (Optional[GenerateParams], default: () ) –

    The parameters to be merged with the current instance's parameters. Can be multiple and overrides will be applied in order.

Returns:

Source code in rigging/generator/base.py
def merge_with(self, *others: t.Optional[GenerateParams]) -> GenerateParams:
    """
    Apply a series of parameter overrides to the current instance and return a copy.

    Only fields that were explicitly set (and are not ``None``) on each override are
    applied, so an override can never reset a field back to ``None``.

    Args:
        *others: The parameters to be merged with the current instance's parameters.
            Can be multiple and overrides will be applied in order. ``None`` entries
            are ignored.

    Returns:
        The merged parameters instance. This is always a new (shallow) copy, even when
        no overrides are given, so rebinding attributes on the result never affects
        this instance. Nested containers such as ``extra`` are still shared.
    """
    updates: dict[str, t.Any] = {}
    for other in others:
        if other is None:
            continue
        # Read the attributes rather than the dumped values so nested models keep
        # their model types instead of being converted to plain dicts.
        for name in other.model_dump(exclude_unset=True, exclude_none=True):
            updates[name] = getattr(other, name)

    # Always copy: returning ``self`` here would let callers that mutate the result
    # (e.g. clearing ``extra``) silently modify the original parameters.
    return self.model_copy(update=updates)

to_dict() -> dict[str, t.Any] #

Convert the parameters to a dictionary.

Returns:

  • dict[str, Any]

    The parameters as a dictionary.

Source code in rigging/generator/base.py
def to_dict(self) -> dict[str, t.Any]:
    """
    Serialize these parameters into a flat dictionary.

    Fields whose value is ``None`` are dropped. The contents of ``extra`` are
    hoisted to the top level instead of being nested under an ``"extra"`` key,
    and they override any same-named field.

    Returns:
        The parameters as a dictionary.
    """
    flattened = self.model_dump(exclude_none=True)
    extra_values = flattened.pop("extra", None)
    if extra_values is not None:
        flattened.update(extra_values)
    return flattened

GeneratedMessage #

Bases: BaseModel

A generated message with additional generation information.

extra: dict[str, t.Any] = Field(default_factory=dict) class-attribute instance-attribute #

Any additional information from the generation.

message: Message instance-attribute #

The generated message.

stop_reason: t.Annotated[StopReason, BeforeValidator(convert_stop_reason)] = 'unknown' class-attribute instance-attribute #

The reason for stopping generation.

usage: t.Optional[Usage] = None class-attribute instance-attribute #

The usage statistics for the generation if available.

GeneratedText #

Bases: BaseModel

A generated text with additional generation information.

extra: dict[str, t.Any] = Field(default_factory=dict) class-attribute instance-attribute #

Any additional information from the generation.

stop_reason: t.Annotated[StopReason, BeforeValidator(convert_stop_reason)] = 'unknown' class-attribute instance-attribute #

The reason for stopping generation.

text: str instance-attribute #

The generated text.

usage: t.Optional[Usage] = None class-attribute instance-attribute #

The usage statistics for the generation if available.

Generator #

Bases: BaseModel

Base class for all rigging generators.

This class provides common functionality and methods for generating completion messages.

A subclass of this can implement both or one of the following:

  • generate_messages: Process a batch of messages.
  • generate_texts: Process a batch of texts.

api_key: str | None = Field(None, exclude=True) class-attribute instance-attribute #

The API key used for authentication.

model: str instance-attribute #

The model name to be used by the generator.

params: GenerateParams instance-attribute #

The parameters used for generating completion messages.

chat(messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None) -> ChatPipeline #

Build a chat pipeline with the given messages and optional params overloads.

Parameters:

Returns:

Source code in rigging/generator/base.py
def chat(
    self,
    messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None,
    params: GenerateParams | None = None,
) -> ChatPipeline:
    """
    Create a chat pipeline bound to this generator.

    Only the registered watch callbacks that are `WatchChatCallback` instances are
    forwarded to the pipeline.

    Args:
        messages: The messages to be sent in the chat. A falsy value starts an
            empty conversation.
        params: Optional parameters for generating responses.

    Returns:
        chat pipeline to run.
    """
    from rigging.chat import ChatPipeline, WatchChatCallback

    chat_callbacks = [callback for callback in self._watch_callbacks if isinstance(callback, WatchChatCallback)]
    initial_messages = Message.fit_as_list(messages) if messages else []

    return ChatPipeline(self, initial_messages, params=params, watch_callbacks=chat_callbacks)

complete(text: str, params: GenerateParams | None = None) -> CompletionPipeline #

Build a completion pipeline of the given text with optional param overloads.

Parameters:

  • text (str) –

    The input text to be completed.

  • params (GenerateParams | None, default: None ) –

    The parameters to be used for completion.

Returns:

Source code in rigging/generator/base.py
def complete(self, text: str, params: GenerateParams | None = None) -> CompletionPipeline:
    """
    Build a completion pipeline of the given text with optional param overloads.

    Only the registered watch callbacks that are `WatchCompletionCallback` instances
    are forwarded to the pipeline.

    Args:
        text: The input text to be completed.
        params: The parameters to be used for completion.

    Returns:
        The completion pipeline to run.
    """
    from rigging.completion import CompletionPipeline, WatchCompletionCallback

    completion_watch_callbacks = [cb for cb in self._watch_callbacks if isinstance(cb, WatchCompletionCallback)]

    return CompletionPipeline(self, text, params=params, watch_callbacks=completion_watch_callbacks)

generate_messages(messages: t.Sequence[t.Sequence[Message]], params: t.Sequence[GenerateParams]) -> t.Sequence[GeneratedMessage] async #

Generate a batch of messages using the specified parameters.

Note

The length of params must be the same as the length of messages.

Parameters:

  • messages (Sequence[Sequence[Message]]) –

    A sequence of sequences of messages.

  • params (Sequence[GenerateParams]) –

    A sequence of GenerateParams objects.

Returns:

Raises:

  • NotImplementedError

    This method is not supported by this generator.

Source code in rigging/generator/base.py
async def generate_messages(
    self,
    messages: t.Sequence[t.Sequence[Message]],
    params: t.Sequence[GenerateParams],
) -> t.Sequence[GeneratedMessage]:
    """
    Generate a batch of messages using the specified parameters.

    The base implementation does not support chat generation and always raises
    ``NotImplementedError``. Generators that support it override this method.

    Note:
        The length of `params` must be the same as the length of `messages`.

    Args:
        messages: A sequence of sequences of messages.
        params: A sequence of GenerateParams objects.

    Returns:
        A sequence of generated messages.

    Raises:
        NotImplementedError: This method is not supported by this generator.
    """
    raise NotImplementedError("`generate_messages` is not supported by this generator.")

generate_texts(texts: t.Sequence[str], params: t.Sequence[GenerateParams]) -> t.Sequence[GeneratedText] async #

Generate a batch of text completions using the generator.

Note

The base implementation does not support text completion and raises NotImplementedError; generators that support it override this method.

Note

The length of params must be the same as the length of texts.

Parameters:

  • texts (Sequence[str]) –

    The input texts for generating the batch.

  • params (Sequence[GenerateParams]) –

    Additional parameters for generating each text in the batch.

Returns:

Raises:

  • NotImplementedError

    This method is not supported by this generator.

Source code in rigging/generator/base.py
async def generate_texts(
    self,
    texts: t.Sequence[str],
    params: t.Sequence[GenerateParams],
) -> t.Sequence[GeneratedText]:
    """
    Generate a batch of text completions using the generator.

    The base implementation does not support text completion and always raises
    ``NotImplementedError``. Generators that support it override this method.

    Note:
        The length of `params` must be the same as the length of `texts`.

    Args:
        texts: The input texts for generating the batch.
        params: Additional parameters for generating each text in the batch.

    Returns:
        The generated texts.

    Raises:
        NotImplementedError: This method is not supported by this generator.
    """
    raise NotImplementedError("`generate_texts` is not supported by this generator.")

load() -> Self #

If supported, trigger underlying loading and preparation of the model.

Returns:

  • Self

    The generator.

Source code in rigging/generator/base.py
def load(self) -> Self:
    """
    Prepare the underlying model ahead of the first generation, if supported.

    The base implementation has nothing to load. It hands back the generator
    unchanged so calls can be chained; subclasses may override it.

    Returns:
        The generator.
    """
    generator = self
    return generator

prompt(func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R] #

Decorator to convert a function into a prompt bound to this generator.

See rigging.prompt.prompt for more information.

Parameters:

  • func (Callable[P, Coroutine[None, None, R]]) –

    The function to be converted into a prompt.

Returns:

  • Prompt[P, R]

    The prompt.

Source code in rigging/generator/base.py
def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
    """
    Decorator that turns an async function into a prompt bound to this generator.

    This is a thin wrapper around [rigging.prompt.prompt][] that passes this
    generator along. See that function for more information.

    Args:
        func: The function to be converted into a prompt.

    Returns:
        The prompt.
    """
    from rigging.prompt import prompt as build_prompt

    bound_prompt = build_prompt(func, generator=self)
    return bound_prompt

to_identifier(params: GenerateParams | None = None) -> str #

Converts the generator instance back into a rigging identifier string.

This calls rigging.generator.get_identifier with the current instance.

Parameters:

  • params (GenerateParams | None, default: None ) –

    The generation parameters.

Returns:

  • str

    The identifier string.

Source code in rigging/generator/base.py
def to_identifier(self, params: GenerateParams | None = None) -> str:
    """
    Render this generator (plus optional parameter overrides) as a rigging identifier string.

    Delegates to [rigging.generator.get_identifier][], passing this instance.

    Args:
        params: Optional generation parameters to merge over the generator's own.

    Returns:
        The identifier string.
    """
    identifier = get_identifier(self, params)
    return identifier

unload() -> Self #

If supported, clean up resources used by the underlying model.

Returns:

  • Self

    The generator.

Source code in rigging/generator/base.py
def unload(self) -> Self:
    """
    Release resources held by the underlying model, if supported.

    The base implementation holds nothing to release. It returns the generator
    unchanged so calls can be chained; subclasses may override it.

    Returns:
        The generator.
    """
    generator = self
    return generator

watch(*callbacks: WatchCallbacks, allow_duplicates: bool = False) -> Generator #

Registers watch callbacks to be passed to any created rigging.chat.ChatPipeline or rigging.completion.CompletionPipeline.

Parameters:

  • *callbacks (WatchCallbacks, default: () ) –

    The callback functions to be executed.

  • allow_duplicates (bool, default: False ) –

    Whether to allow (seemingly) duplicate callbacks to be added.

async def log(chats: list[Chat]) -> None:
    ...

await pipeline.watch(log).run()

Returns:

  • Generator

    The current generator instance.

Source code in rigging/generator/base.py
def watch(self, *callbacks: WatchCallbacks, allow_duplicates: bool = False) -> Self:
    """
    Registers watch callbacks to be passed to any created
    [rigging.chat.ChatPipeline][] or [rigging.completion.CompletionPipeline][].

    Callbacks are stored on the generator. `chat()` forwards only the chat
    callbacks and `complete()` only the completion callbacks.

    Args:
        *callbacks: The callback functions to be executed.
        allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
            When False, a callback that is already registered (compared by equality)
            is skipped.

    ```
    async def log(chats: list[Chat]) -> None:
        ...

    generator.watch(log)
    await generator.chat("Hello").run()
    ```

    Returns:
        The current generator instance, to allow chaining.
    """
    for callback in callbacks:
        if allow_duplicates or callback not in self._watch_callbacks:
            self._watch_callbacks.append(callback)
    return self

wrap(func: t.Callable[[CallableT], CallableT] | None) -> Self #

If supported, wrap any underlying interior framework calls with this function.

This is useful for adding things like backoff or rate limiting.

Parameters:

  • func (Callable[[CallableT], CallableT] | None) –

    The function to wrap the calls with.

Returns:

  • Self

    The generator.

Source code in rigging/generator/base.py
def wrap(self, func: t.Callable[[CallableT], CallableT] | None) -> Self:
    """
    Register a wrapper to apply around the generator's underlying framework calls, if supported.

    Typical uses are backoff or rate limiting. Passing ``None`` clears any
    previously registered wrapper.

    Args:
        func: The function to wrap the calls with.

    Returns:
        The generator.
    """
    # setattr goes through the same __setattr__ as direct assignment, and it
    # sidesteps the mypy assignment complaint.
    setattr(self, "_wrap", func)
    return self

LiteLLMGenerator #

Bases: Generator

Generator backed by the LiteLLM library.

Find more information about supported models and formats in their docs.

Note

Batching support is not performant and simply a loop over inputs.

Warning

While some providers support passing n to produce a batch of completions per request, we don't currently use this in the implementation due to its brittle requirements.

Tip

Consider setting max_connections or min_delay_between_requests if you run into API limits. You can pass these directly in the generator id:

get_generator("litellm!openai/gpt-4o,max_connections=2,min_delay_between_requests=1000")

max_connections: int = 10 class-attribute instance-attribute #

How many simultaneous requests to pool at one time. This is useful to set when you run into API limits at a provider.

Set to 0 to remove the limit.

min_delay_between_requests: float = 0.0 class-attribute instance-attribute #

Minimum time (ms) between each request. This is useful to set when you run into API limits at a provider.

Usage #

Bases: BaseModel

input_tokens: int instance-attribute #

The number of input tokens.

output_tokens: int instance-attribute #

The number of output tokens.

total_tokens: int instance-attribute #

The total number of tokens processed.

chat(generator: Generator, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None) -> ChatPipeline #

Creates a chat pipeline using the given generator, messages, and params.

Parameters:

  • generator (Generator) –

    The generator to use for creating the chat.

  • messages (Sequence[MessageDict] | Sequence[Message] | MessageDict | Message | str | None, default: None ) –

    The messages to include in the chat. Can be a single message or a sequence of messages.

  • params (GenerateParams | None, default: None ) –

    Additional parameters for generating the chat.

Returns:

Source code in rigging/generator/base.py
def chat(
    generator: Generator,
    messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None,
    params: GenerateParams | None = None,
) -> ChatPipeline:
    """
    Functional shortcut for `generator.chat(messages, params)`.

    Args:
        generator: The generator to use for creating the chat.
        messages:
            The messages to include in the chat. Can be a single message or a sequence of messages.
        params: Additional parameters for generating the chat.

    Returns:
        chat pipeline to run.
    """
    pipeline = generator.chat(messages, params)
    return pipeline

get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator #

Get a generator by an identifier string. Uses LiteLLM by default.

Identifier strings are formatted like <provider>!<model>,<**kwargs>

(provider is optional and defaults to LiteLLM if not specified)

Examples:

  • "gpt-3.5-turbo" -> LiteLLMGenerator(model="gpt-3.5-turbo")
  • "litellm!claude-2.1" -> LiteLLMGenerator(model="claude-2.1")
  • "mistral/mistral-tiny" -> LiteLLMGenerator(model="mistral/mistral-tiny")

You can also specify arguments to the generator by comma-separating them:

  • "mistral/mistral-medium,max_tokens=1024"
  • "gpt-4-0613,temperature=0.9,max_tokens=512"
  • "claude-2.1,stop=Human:;test,max_tokens=100"

(These get parsed as rigging.generator.GenerateParams)

Parameters:

  • identifier (str) –

    The identifier string to use to get a generator.

  • params (GenerateParams | None, default: None ) –

    The generation parameters to use for the generator. These will override any parameters specified in the identifier string.

Returns:

Raises:

  • InvalidModelSpecified

    If the identifier is invalid.

Source code in rigging/generator/base.py
def _coerce_identifier_value(value: t.Any) -> t.Any:
    """
    Best-effort conversion of a raw string value parsed from an identifier.

    Integers are tried before floats so that values like ``"2"`` stay ``int``
    (``float("2")`` would otherwise succeed first and yield ``2.0``). The strings
    ``"true"``/``"false"`` (any case) become booleans. Anything else is returned
    unchanged.
    """
    if not isinstance(value, str):
        return value

    try:
        return int(value)
    except ValueError:
        pass

    try:
        return float(value)
    except ValueError:
        pass

    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    return value


def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
    """
    Get a generator by an identifier string. Uses LiteLLM by default.

    Identifier strings are formatted like `<provider>!<model>,<**kwargs>`

    (provider is optional and defaults to LiteLLM if not specified)

    Examples:

    - "gpt-3.5-turbo" -> `LiteLLMGenerator(model="gpt-3.5-turbo")`
    - "litellm!claude-2.1" -> `LiteLLMGenerator(model="claude-2.1")`
    - "mistral/mistral-tiny" -> `LiteLLMGenerator(model="mistral/mistral-tiny")`

    You can also specify arguments to the generator by comma-separating them:

    - "mistral/mistral-medium,max_tokens=1024"
    - "gpt-4-0613,temperature=0.9,max_tokens=512"
    - "claude-2.1,stop=Human:;test,max_tokens=100"

    Keyword arguments that match the generator class's constructor are passed to it
    (with best-effort int/float/bool conversion). All others are parsed as
    [rigging.generator.GenerateParams][].

    Args:
        identifier: The identifier string to use to get a generator.
        params: The generation parameters to use for the generator.
            These will override any parameters specified in the identifier string.

    Returns:
        The generator object.

    Raises:
        InvalidModelSpecified: If the identifier is invalid.
    """

    provider: str = list(g_providers.keys())[0]
    model: str = identifier

    # Split provider, model, and kwargs. Only a "!" before the first "," marks the
    # provider, so kwarg values (e.g. stop sequences) may themselves contain "!".
    model_part = identifier.split(",", 1)[0]
    if "!" in model_part:
        if model_part.count("!") > 1:
            raise InvalidModelSpecifiedError(identifier)
        provider, model = identifier.split("!", 1)

    if provider not in g_providers:
        raise InvalidModelSpecifiedError(identifier)

    # Lazily-registered providers are factories: resolve once and cache the class.
    if not isinstance(g_providers[provider], type):
        lazy_generator = t.cast(LazyGenerator, g_providers[provider])
        g_providers[provider] = lazy_generator()

    generator_cls = t.cast(type[Generator], g_providers[provider])

    kwargs: dict[str, str] = {}
    if "," in model:
        try:
            model, kwargs_str = model.split(",", 1)
            # Split each pair on the first "=" only so values may contain "=".
            kwargs = dict(arg.split("=", 1) for arg in kwargs_str.split(","))
        except ValueError as e:
            raise InvalidModelSpecifiedError(identifier) from e

    # Route kwargs that name a constructor parameter to the generator class itself.
    # The remaining kwargs are treated as generation params.
    init_signature = inspect.signature(generator_cls)
    init_kwargs: dict[str, t.Any] = {
        key: _coerce_identifier_value(kwargs.pop(key)) for key in list(kwargs) if key in init_signature.parameters
    }

    try:
        merged_params = GenerateParams(**kwargs).merge_with(params)
    except Exception as e:
        raise InvalidModelSpecifiedError(identifier) from e

    return generator_cls(model=model, params=merged_params, **init_kwargs)

get_identifier(generator: Generator, params: GenerateParams | None = None) -> str #

Converts the generator instance back into a rigging identifier string.

Warning

The extra parameter field is not currently supported in identifiers.

Parameters:

  • generator (Generator) –

    The generator object.

  • params (GenerateParams | None, default: None ) –

    The generation parameters.

Returns:

  • str

    The identifier string for the generator.

Source code in rigging/generator/base.py
def get_identifier(generator: Generator, params: GenerateParams | None = None) -> str:
    """
    Converts the generator instance back into a rigging identifier string.

    The identifier has the form `<provider>!<model>[,<ctor_arg>=<value>...][,<param>=<value>...]`.
    Only constructor arguments that were explicitly set are included, and `stop`
    sequences are joined with ";".

    Warning:
        The `extra` parameter field is not currently supported in identifiers.

    Args:
        generator: The generator object.
        params: The generation parameters, merged over the generator's own params.

    Returns:
        The identifier string for the generator.

    Raises:
        ValueError: If no registered (already-resolved) provider class matches the generator.
    """

    # Lazy (not yet resolved) provider entries are skipped, since they are not classes yet.
    provider = next(
        (name for name, klass in g_providers.items() if isinstance(klass, type) and isinstance(generator, klass)),
        None,
    )
    if provider is None:
        raise ValueError(f"No registered provider matches generator type {type(generator).__name__!r}")

    identifier = f"{provider}!{generator.model}"

    extra_cls_args = generator.model_dump(exclude_unset=True, exclude={"model", "api_key", "params"})
    if extra_cls_args:
        identifier += "," + ",".join(f"{k}={v}" for k, v in extra_cls_args.items())

    merged_params = generator.params.merge_with(params)
    if merged_params.extra:
        logger.debug("Extra parameters are not supported in identifiers.")
        # merge_with() may hand back generator.params itself, so copy instead of
        # mutating. Otherwise the generator's own extra params would be wiped.
        merged_params = merged_params.model_copy(update={"extra": {}})

    params_dict = merged_params.to_dict()
    if params_dict:
        if "stop" in params_dict:
            params_dict["stop"] = ";".join(params_dict["stop"])
        identifier += "," + ",".join(f"{k}={v}" for k, v in params_dict.items())

    return identifier

register_generator(provider: str, generator_cls: type[Generator] | LazyGenerator) -> None #

Register a generator class for a provider id.

This lets you use rigging.generator.get_generator with a custom generator class.

Parameters:

  • provider (str) –

    The name of the provider.

  • generator_cls (type[Generator] | LazyGenerator) –

    The generator class to register.

Source code in rigging/generator/base.py
def register_generator(provider: str, generator_cls: type[Generator] | LazyGenerator) -> None:
    """
    Register a generator class (or lazy factory) under a provider id.

    Once registered, the provider id can be used as the `<provider>!` prefix with
    [rigging.generator.get_generator][]. Registering an existing id replaces the
    previous entry.

    Args:
        provider: The name of the provider.
        generator_cls: The generator class to register, or a lazy factory that
            returns it.
    """
    # Item assignment mutates the module-level registry in place, so no
    # ``global`` declaration is required.
    g_providers[provider] = generator_cls

VLLMGenerator #

Bases: Generator

Generator backed by the vLLM library for local model loading.

Find more information about supported models and formats in their docs.

Warning

The use of VLLM requires the vllm package to be installed directly or by installing rigging as rigging[all].

Note

This generator doesn't leverage any async capabilities.

Note

The model is loaded into memory lazily when the first generation is requested. If you want to force this to happen earlier, use the .load() method.

To unload, call .unload().

dtype: str = 'auto' class-attribute instance-attribute #

Tensor dtype passed to vllm.LLM

enforce_eager: bool = False class-attribute instance-attribute #

Eager enforcement passed to vllm.LLM

gpu_memory_utilization: float = 0.9 class-attribute instance-attribute #

Memory utilization passed to vllm.LLM

llm: vllm.LLM property #

The underlying vLLM model instance.

quantization: str | None = None class-attribute instance-attribute #

Quantization passed to vllm.LLM

trust_remote_code: bool = False class-attribute instance-attribute #

Trust remote code passed to vllm.LLM

from_obj(model: str, llm: vllm.LLM, *, params: GenerateParams | None = None) -> VLLMGenerator classmethod #

Create a generator from an existing vLLM instance.

Parameters:

  • llm (LLM) –

    The vLLM instance to create the generator from.

Returns:

Source code in rigging/generator/vllm_.py
@classmethod
def from_obj(cls, model: str, llm: vllm.LLM, *, params: GenerateParams | None = None) -> VLLMGenerator:
    """Wrap an already-constructed vLLM instance in a generator.

    Args:
        model: The model name recorded on the generator.
        llm: The vLLM instance to create the generator from.
        params: Optional generation parameters. A default `GenerateParams()` is
            used when omitted.

    Returns:
        The VLLMGenerator instance.
    """
    effective_params = params if params else GenerateParams()
    instance = cls(model=model, params=effective_params)
    instance._llm = llm
    return instance

DEFAULT_MAX_TOKENS = 1024 module-attribute #

Lifting the default max tokens from transformers

TransformersGenerator #

Bases: Generator

Generator backed by the Transformers library for local model loading.

Warning

The use of Transformers requires the transformers package to be installed directly or by installing rigging as rigging[all].

Warning

The transformers library is expansive with many different models, tokenizers, options, constructors, etc. We do our best to implement a consistent interface, but there may be limitations. Where needed, use .from_obj().

Note

This generator doesn't leverage any async capabilities.

Note

The model is loaded into memory lazily when the first generation is requested. If you want to force this to happen earlier, use the .load() method.

To unload, call .unload().

device_map: str = 'auto' class-attribute instance-attribute #

llm: AutoModelForCausalLM property #

The underlying AutoModelForCausalLM instance.

load_in_4bit: bool = False class-attribute instance-attribute #

Load in 4 bit passed to AutoModelForCausalLM.from_pretrained

load_in_8bit: bool = False class-attribute instance-attribute #

Load in 8 bit passed to AutoModelForCausalLM.from_pretrained

pipeline: TextGenerationPipeline property #

The underlying TextGenerationPipeline instance.

tokenizer: PreTrainedTokenizer property #

The underlying AutoTokenizer instance.

torch_dtype: str = 'auto' class-attribute instance-attribute #

Torch dtype passed to AutoModelForCausalLM.from_pretrained

trust_remote_code: bool = False class-attribute instance-attribute #

Trust remote code passed to AutoModelForCausalLM.from_pretrained

from_obj(model: str, llm: AutoModelForCausalLM, tokenizer: PreTrainedTokenizer, *, pipeline: TextGenerationPipeline | None = None, params: GenerateParams | None = None) -> TransformersGenerator classmethod #

Create a new instance of TransformersGenerator from an already loaded model and tokenizer.

Parameters:

  • model (str) –

    The model name recorded on the generator.

  • llm (AutoModelForCausalLM) –

    The loaded model for text generation.

  • tokenizer

    The tokenizer associated with the model.

  • pipeline (TextGenerationPipeline | None, default: None ) –

    The text generation pipeline. Defaults to None.

Returns:

Source code in rigging/generator/transformers_.py
@classmethod
def from_obj(
    cls,
    model: str,
    llm: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizer,
    *,
    pipeline: TextGenerationPipeline | None = None,
    params: GenerateParams | None = None,
) -> TransformersGenerator:
    """
    Create a new instance of TransformersGenerator from an already loaded model and tokenizer.

    Args:
        model: The model name recorded on the generator.
        llm: The loaded model for text generation.
        tokenizer: The tokenizer associated with the model.
        pipeline: The text generation pipeline. Defaults to None.
        params: Optional generation parameters. A default `GenerateParams()` is
            used when omitted.

    Returns:
        The TransformersGenerator instance.
    """
    instance = cls(model=model, params=params or GenerateParams())
    # Store the loaded model object. Storing `model` here would keep only the
    # name string and drop the provided model.
    instance._llm = llm
    instance._tokenizer = tokenizer
    instance._pipeline = pipeline
    return instance