diff --git a/pyproject.toml b/pyproject.toml index 5c00c86..f2c6551 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "snet-sdk" -version = "5.2.0" +version = "6.0.0" description = "SingularityNET Python SDK" readme = "README.md" requires-python = ">=3.10" @@ -30,7 +30,8 @@ dependencies = [ "ipfshttpclient==0.4.13.2", "snet-contracts==1.0.1", "lighthouseweb3~=0.1.4", - "py-multihash~=3.0" + "py-multihash~=3.0", + "pydantic-settings~=2.13" ] [tool.poetry.group.dev.dependencies] diff --git a/requirements.txt b/requirements.txt index 672ba27..4bc7135 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ ipfshttpclient==0.4.13.2 snet-contracts==1.0.1 lighthouseweb3~=0.1.4 py-multihash~=3.0 +pydantic-settings~=2.13 \ No newline at end of file diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 9bf25af..e574fa0 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -5,7 +5,6 @@ from enum import Enum import google.protobuf.internal.api_implementation - from google.protobuf import symbol_database as _symbol_database from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata @@ -23,11 +22,17 @@ from snet.contracts import get_contract_object from snet.sdk.account import Account -from snet.sdk.config import Config +from snet.sdk.config import config from snet.sdk.client_lib_generator import ClientLibGenerator from snet.sdk.mpe.mpe_contract import MPEContract from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider -from snet.sdk.payment_strategies.default_payment_strategy import * +from snet.sdk.payment_strategies import ( + DefaultPaymentStrategy, + PaidCallPaymentStrategy, + PrePaidPaymentStrategy, + FreeCallPaymentStrategy, + PaymentStrategy, +) from snet.sdk.service_client import ServiceClient from snet.sdk.storage_provider.storage_provider import StorageProvider from snet.sdk.custom_typing import ModuleName, ServiceStub @@ -52,42 +57,26 @@ class PaymentStrategyType(Enum): class SnetSDK: """Base Snet SDK""" - def __init__(self, sdk_config: Config, metadata_provider=None): - self._sdk_config = sdk_config - self._metadata_provider = metadata_provider - - # Instantiate Ethereum client - eth_rpc_endpoint = self._sdk_config["eth_rpc_endpoint"] - eth_rpc_request_kwargs = self._sdk_config.get("eth_rpc_request_kwargs") - - provider = web3.HTTPProvider( - endpoint_uri=eth_rpc_endpoint, request_kwargs=eth_rpc_request_kwargs - ) - - self.web3 = web3.Web3(provider) + def __init__(self): + self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT)) - # Get MPE contract address from config if specified; - # mostly for local testing - _mpe_contract_address = self._sdk_config.get("mpe_contract_address", None) - if _mpe_contract_address is None: + mpe_contract_address = config.MPE_CONTRACT_ADDRESS + if not mpe_contract_address: self.mpe_contract = MPEContract(self.web3) else: - self.mpe_contract = MPEContract(self.web3, _mpe_contract_address) + self.mpe_contract = MPEContract(self.web3, mpe_contract_address) - # Get Registry contract address from config if specified; - # mostly for local testing - _registry_contract_address = self._sdk_config.get("registry_contract_address", None) - if _registry_contract_address is None: + registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS + if registry_contract_address is None: self.registry_contract = get_contract_object(self.web3, "Registry") else: self.registry_contract = get_contract_object( - self.web3, "Registry", _registry_contract_address + self.web3, "Registry", registry_contract_address ) - if self._metadata_provider is None: - self._metadata_provider = StorageProvider(self._sdk_config, self.registry_contract) + self.metadata_provider = StorageProvider(self.registry_contract) - self.account = Account(self.web3, sdk_config, self.mpe_contract) + self.account = Account(self.web3, self.mpe_contract) self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) def create_service_client( @@ -103,14 +92,14 @@ def create_service_client( # Create and instance of the Config object, # so we can create an instance of ClientLibGenerator - self.lib_generator = ClientLibGenerator(self._metadata_provider, org_id, service_id) + lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id) # Download the proto file and generate stubs if needed - force_update = self._sdk_config.get("force_update", False) + force_update = config.FORCE_UPDATE if force_update: - self.lib_generator.generate_client_library() + lib_generator.generate_client_library() else: - path_to_pb_files = self.lib_generator.protodir + path_to_pb_files = lib_generator.protodir pb_2_file_name = find_file_by_keyword( path_to_pb_files, keyword="pb2.py", exclude=["training"] ) @@ -119,22 +108,22 @@ def create_service_client( ) if not pb_2_file_name or not pb_2_grpc_file_name: print("Generating client library...") - self.lib_generator.generate_client_library() + lib_generator.generate_client_library() if options is None: options = dict() - options["concurrency"] = self._sdk_config.get("concurrency", True) + options["concurrency"] = config.CONCURRENCY options["concurrent_calls"] = concurrent_calls if payment_strategy is None: payment_strategy = payment_strategy_type.value() - service_metadata = self._metadata_provider.enhance_service_metadata(org_id, service_id) + service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id) group = self._get_service_group_details(service_metadata, group_name) - service_stubs = self.get_service_stub() + service_stubs = self.get_service_stub(lib_generator) - pb2_module = self.get_module_by_keyword(keyword="pb2.py") + pb2_module = self.get_module_by_keyword("pb2.py", lib_generator) _service_client = ServiceClient( org_id, service_id, @@ -148,14 +137,14 @@ def create_service_client( self.web3, pb2_module, self.payment_channel_provider, - self.lib_generator.protodir, - self.lib_generator.training_added(), + lib_generator.protodir, + lib_generator.training_added(), ) return _service_client - def get_service_stub(self) -> list[ServiceStub]: - path_to_pb_files = str(self.lib_generator.protodir) - module_name = self.get_module_by_keyword(keyword="pb2_grpc.py") + def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]: + path_to_pb_files = str(lib_generator.protodir) + module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator) sys.path.append(path_to_pb_files) try: grpc_file = importlib.import_module(module_name) @@ -169,14 +158,14 @@ def get_service_stub(self) -> list[ServiceStub]: except Exception as e: raise Exception(f"Error importing module: {e}") - def get_module_by_keyword(self, keyword: str) -> ModuleName: - path_to_pb_files = self.lib_generator.protodir + def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName: + path_to_pb_files = lib_generator.protodir file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"]) module_name = os.path.splitext(file_name)[0] return ModuleName(module_name) def get_service_metadata(self, org_id, service_id): - return self._metadata_provider.fetch_service_metadata(org_id, service_id) + return self.metadata_provider.fetch_service_metadata(org_id, service_id) def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict: return service_metadata["groups"][0] diff --git a/snet/sdk/account.py b/snet/sdk/account.py index 34aa4b0..31b404d 100644 --- a/snet/sdk/account.py +++ b/snet/sdk/account.py @@ -3,7 +3,8 @@ import web3 from snet.contracts import get_contract_object -from snet.sdk.config import Config + +from snet.sdk.config import config from snet.sdk.mpe.mpe_contract import MPEContract from snet.sdk.utils.utils import get_address_from_private, normalize_private_key @@ -27,24 +28,25 @@ def __str__(self): class Account: - def __init__(self, w3: web3.Web3, config: Config, mpe_contract: MPEContract): - self.config: Config = config - self.web3: web3.Web3 = w3 - self.mpe_contract: MPEContract = mpe_contract - _token_contract_address: str | None = self.config.get("token_contract_address", None) - if _token_contract_address is None: + def __init__(self, w3: web3.Web3, mpe_contract: MPEContract): + self.web3 = w3 + self.mpe_contract = mpe_contract + + token_contract_address = config.TOKEN_CONTRACT_ADDRESS + if not token_contract_address: self.token_contract = get_contract_object(self.web3, "FetchToken") else: self.token_contract = get_contract_object( - self.web3, "FetchToken", _token_contract_address + self.web3, "FetchToken", token_contract_address ) - if config.get("private_key") is not None: - self.private_key = normalize_private_key(config.get("private_key")) - if config.get("signer_private_key") is not None: - self.signer_private_key = normalize_private_key(config.get("signer_private_key")) + if config.PRIVATE_KEY: + self.private_key = normalize_private_key(config.PRIVATE_KEY) + if config.SIGNER_PRIVATE_KEY: + self.signer_private_key = normalize_private_key(config.SIGNER_PRIVATE_KEY) else: self.signer_private_key = self.private_key + self.address = get_address_from_private(self.private_key) self.signer_address = get_address_from_private(self.signer_private_key) self.nonce = 0 diff --git a/snet/sdk/config.py b/snet/sdk/config.py index 8b253f3..e205ad7 100644 --- a/snet/sdk/config.py +++ b/snet/sdk/config.py @@ -1,38 +1,48 @@ -class Config: - def __init__( - self, - private_key, - eth_rpc_endpoint, - wallet_index=0, - ipfs_endpoint=None, - concurrency=True, - force_update=False, - mpe_contract_address=None, - token_contract_address=None, - registry_contract_address=None, - signer_private_key=None, - ): - self.__config = { - "private_key": private_key, - "eth_rpc_endpoint": eth_rpc_endpoint, - "wallet_index": wallet_index, - "ipfs_endpoint": ( - ipfs_endpoint if ipfs_endpoint else "/dns/ipfs.singularitynet.io/tcp/80/" - ), - "concurrency": concurrency, - "force_update": force_update, - "mpe_contract_address": mpe_contract_address, - "token_contract_address": token_contract_address, - "registry_contract_address": registry_contract_address, - "signer_private_key": signer_private_key, - "lighthouse_token": " ", - } - - def __getitem__(self, key): - return self.__config[key] - - def get(self, key, default=None): - return self.__config.get(key, default) - - def get_ipfs_endpoint(self): - return self["ipfs_endpoint"] +from typing import Optional + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + PRIVATE_KEY: str = "" + SIGNER_PRIVATE_KEY: str = "" + ETH_RPC_ENDPOINT: str = "" + WALLET_INDEX: int = 0 + IPFS_ENDPOINT: str = "/dns/ipfs.singularitynet.io/tcp/80/" + CONCURRENCY: bool = True + FORCE_UPDATE: bool = False + MPE_CONTRACT_ADDRESS: str = "" + REGISTRY_CONTRACT_ADDRESS: str = "" + TOKEN_CONTRACT_ADDRESS: str = "" + LIGHTHOUSE_TOKEN: str = " " + + model_config = SettingsConfigDict( + env_prefix="SNET_", env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) + + +config = Settings() + + +def configure( + *, + private_key: Optional[str] = None, + signer_private_key: Optional[str] = None, + eth_rpc_endpoint: Optional[str] = None, + wallet_index: Optional[int] = None, + ipfs_endpoint: Optional[str] = None, + concurrency: Optional[bool] = None, + force_update: Optional[bool] = None, + mpe_contract_address: Optional[str] = None, + registry_contract_address: Optional[str] = None, + token_contract_address: Optional[str] = None, + lighthouse_token: Optional[str] = None, +): + global config + for key, value in locals().items(): + key = key.upper() + if hasattr(config, key): + if value is not None: + setattr(config, key, value) + else: + raise ValueError(f"Unknown config key: {key.lower()}") diff --git a/snet/sdk/payment_strategies/__init__.py b/snet/sdk/payment_strategies/__init__.py index e69de29..915ca31 100644 --- a/snet/sdk/payment_strategies/__init__.py +++ b/snet/sdk/payment_strategies/__init__.py @@ -0,0 +1,19 @@ +from snet.sdk.payment_strategies.freecall_payment_strategy import ( + FreeCallPaymentStrategy, +) +from snet.sdk.payment_strategies.paidcall_payment_strategy import ( + PaidCallPaymentStrategy, +) +from snet.sdk.payment_strategies.prepaid_payment_strategy import ( + PrePaidPaymentStrategy, +) +from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy +from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy + +__all__ = [ + "PaymentStrategy", + "DefaultPaymentStrategy", + "FreeCallPaymentStrategy", + "PaidCallPaymentStrategy", + "PrePaidPaymentStrategy", +] diff --git a/snet/sdk/storage_provider/storage_provider.py b/snet/sdk/storage_provider/storage_provider.py index c2196d6..edfdec1 100644 --- a/snet/sdk/storage_provider/storage_provider.py +++ b/snet/sdk/storage_provider/storage_provider.py @@ -11,13 +11,14 @@ MPEServiceMetadata, mpe_service_metadata_from_json, ) +from snet.sdk.config import config class StorageProvider(object): - def __init__(self, config, registry_contract): + def __init__(self, registry_contract): self._registry_contract = registry_contract - self._ipfs_client = get_ipfs_client(config) - self.lighthouse_client = Lighthouse(config["lighthouse_token"]) + self._ipfs_client = get_ipfs_client() + self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) def fetch_org_metadata(self, org_id): org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") diff --git a/snet/sdk/utils/ipfs_utils.py b/snet/sdk/utils/ipfs_utils.py index 9a00d64..a3e76bd 100644 --- a/snet/sdk/utils/ipfs_utils.py +++ b/snet/sdk/utils/ipfs_utils.py @@ -2,6 +2,8 @@ import multihash import hashlib +from snet.sdk.config import config + def get_from_ipfs_and_checkhash(ipfs_client, ipfs_hash_base58, validate=True): """ @@ -36,6 +38,5 @@ def get_from_ipfs_and_checkhash(ipfs_client, ipfs_hash_base58, validate=True): return data -def get_ipfs_client(config): - ipfs_endpoint = config.get_ipfs_endpoint() - return ipfshttpclient.connect(ipfs_endpoint) +def get_ipfs_client(): + return ipfshttpclient.connect(config.IPFS_ENDPOINT)