diff --git a/textgrad/variable.py b/textgrad/variable.py index 01b9020..91507d3 100644 --- a/textgrad/variable.py +++ b/textgrad/variable.py @@ -2,6 +2,7 @@ from textgrad.engine import EngineLM from typing import List, Set, Dict import httpx +import numpy as np from collections import defaultdict from functools import partial from .config import SingletonBackwardEngine @@ -11,7 +12,7 @@ class Variable: def __init__( self, - value: Union[str, bytes] = "", + value: Union[str, bytes, np.integer] = "", image_path: str = "", predecessors: List['Variable']=None, requires_grad: bool=True, @@ -33,14 +34,14 @@ def __init__( if predecessors is None: predecessors = [] - + _predecessor_requires_grad = [v for v in predecessors if v.requires_grad] - + if (not requires_grad) and (len(_predecessor_requires_grad) > 0): raise Exception("If the variable does not require grad, none of its predecessors should require grad." f"In this case, following predecessors require grad: {_predecessor_requires_grad}") - - assert type(value) in [str, bytes, int], "Value must be a string, int, or image (bytes). Got: {}".format(type(value)) + + assert type(value) in [str, bytes, int] or np.issubdtype(value, np.integer), "Value must be a string, int, or image (bytes). Got: {}".format(type(value)) if isinstance(value, int): value = str(value) # We'll currently let "empty variables" slide, but we'll need to handle this better in the future.