Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "snet-sdk"
version = "5.2.0"
version = "6.0.0"
description = "SingularityNET Python SDK"
readme = "README.md"
requires-python = ">=3.10"
Expand All @@ -30,7 +30,8 @@ dependencies = [
"ipfshttpclient==0.4.13.2",
"snet-contracts==1.0.1",
"lighthouseweb3~=0.1.4",
"py-multihash~=3.0"
"py-multihash~=3.0",
"pydantic-settings~=2.13"
]

[tool.poetry.group.dev.dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ ipfshttpclient==0.4.13.2
snet-contracts==1.0.1
lighthouseweb3~=0.1.4
py-multihash~=3.0
pydantic-settings~=2.13
81 changes: 35 additions & 46 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from enum import Enum

import google.protobuf.internal.api_implementation

from google.protobuf import symbol_database as _symbol_database

from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata
Expand All @@ -23,11 +22,17 @@

from snet.contracts import get_contract_object
from snet.sdk.account import Account
from snet.sdk.config import Config
from snet.sdk.config import config
from snet.sdk.client_lib_generator import ClientLibGenerator
from snet.sdk.mpe.mpe_contract import MPEContract
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
from snet.sdk.payment_strategies.default_payment_strategy import *
from snet.sdk.payment_strategies import (
DefaultPaymentStrategy,
PaidCallPaymentStrategy,
PrePaidPaymentStrategy,
FreeCallPaymentStrategy,
PaymentStrategy,
)
from snet.sdk.service_client import ServiceClient
from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.custom_typing import ModuleName, ServiceStub
Expand All @@ -52,42 +57,26 @@
class SnetSDK:
"""Base Snet SDK"""

def __init__(self, sdk_config: Config, metadata_provider=None):
self._sdk_config = sdk_config
self._metadata_provider = metadata_provider

# Instantiate Ethereum client
eth_rpc_endpoint = self._sdk_config["eth_rpc_endpoint"]
eth_rpc_request_kwargs = self._sdk_config.get("eth_rpc_request_kwargs")

provider = web3.HTTPProvider(
endpoint_uri=eth_rpc_endpoint, request_kwargs=eth_rpc_request_kwargs
)

self.web3 = web3.Web3(provider)
def __init__(self):
self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT))

# Get MPE contract address from config if specified;
# mostly for local testing
_mpe_contract_address = self._sdk_config.get("mpe_contract_address", None)
if _mpe_contract_address is None:
mpe_contract_address = config.MPE_CONTRACT_ADDRESS
if not mpe_contract_address:
self.mpe_contract = MPEContract(self.web3)
else:
self.mpe_contract = MPEContract(self.web3, _mpe_contract_address)
self.mpe_contract = MPEContract(self.web3, mpe_contract_address)

# Get Registry contract address from config if specified;
# mostly for local testing
_registry_contract_address = self._sdk_config.get("registry_contract_address", None)
if _registry_contract_address is None:
registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS
if registry_contract_address is None:
self.registry_contract = get_contract_object(self.web3, "Registry")
else:
self.registry_contract = get_contract_object(
self.web3, "Registry", _registry_contract_address
self.web3, "Registry", registry_contract_address
)

if self._metadata_provider is None:
self._metadata_provider = StorageProvider(self._sdk_config, self.registry_contract)
self.metadata_provider = StorageProvider(self.registry_contract)

self.account = Account(self.web3, sdk_config, self.mpe_contract)
self.account = Account(self.web3, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)

def create_service_client(
Expand All @@ -103,14 +92,14 @@

# Create and instance of the Config object,
# so we can create an instance of ClientLibGenerator
self.lib_generator = ClientLibGenerator(self._metadata_provider, org_id, service_id)
lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id)

# Download the proto file and generate stubs if needed
force_update = self._sdk_config.get("force_update", False)
force_update = config.FORCE_UPDATE
if force_update:
self.lib_generator.generate_client_library()
lib_generator.generate_client_library()
else:
path_to_pb_files = self.lib_generator.protodir
path_to_pb_files = lib_generator.protodir
pb_2_file_name = find_file_by_keyword(
path_to_pb_files, keyword="pb2.py", exclude=["training"]
)
Expand All @@ -119,22 +108,22 @@
)
if not pb_2_file_name or not pb_2_grpc_file_name:
print("Generating client library...")
self.lib_generator.generate_client_library()
lib_generator.generate_client_library()

if options is None:
options = dict()
options["concurrency"] = self._sdk_config.get("concurrency", True)
options["concurrency"] = config.CONCURRENCY
options["concurrent_calls"] = concurrent_calls

if payment_strategy is None:
payment_strategy = payment_strategy_type.value()

service_metadata = self._metadata_provider.enhance_service_metadata(org_id, service_id)
service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id)
group = self._get_service_group_details(service_metadata, group_name)

service_stubs = self.get_service_stub()
service_stubs = self.get_service_stub(lib_generator)

pb2_module = self.get_module_by_keyword(keyword="pb2.py")
pb2_module = self.get_module_by_keyword("pb2.py", lib_generator)
_service_client = ServiceClient(
org_id,
service_id,
Expand All @@ -148,14 +137,14 @@
self.web3,
pb2_module,
self.payment_channel_provider,
self.lib_generator.protodir,
self.lib_generator.training_added(),
lib_generator.protodir,
lib_generator.training_added(),
)
return _service_client

def get_service_stub(self) -> list[ServiceStub]:
path_to_pb_files = str(self.lib_generator.protodir)
module_name = self.get_module_by_keyword(keyword="pb2_grpc.py")
def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
path_to_pb_files = str(lib_generator.protodir)
module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
sys.path.append(path_to_pb_files)
try:
grpc_file = importlib.import_module(module_name)
Expand All @@ -169,14 +158,14 @@
except Exception as e:
raise Exception(f"Error importing module: {e}")

def get_module_by_keyword(self, keyword: str) -> ModuleName:
path_to_pb_files = self.lib_generator.protodir
def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
path_to_pb_files = lib_generator.protodir
file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
module_name = os.path.splitext(file_name)[0]
return ModuleName(module_name)

def get_service_metadata(self, org_id, service_id):
return self._metadata_provider.fetch_service_metadata(org_id, service_id)
return self.metadata_provider.fetch_service_metadata(org_id, service_id)

def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict:
return service_metadata["groups"][0]
Expand All @@ -193,7 +182,7 @@
self, service_metadata: MPEServiceMetadata, group_name: str
) -> dict:
if len(service_metadata["groups"]) == 0:
raise Exception("No Groups found for given service, Please add group to the service")

Check warning on line 185 in snet/sdk/__init__.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/__init__.py#L185

Replace this generic exception class with a more specific one.

if group_name is None:
return self._get_first_group(service_metadata)
Expand Down
26 changes: 14 additions & 12 deletions snet/sdk/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import web3

from snet.contracts import get_contract_object
from snet.sdk.config import Config

from snet.sdk.config import config
from snet.sdk.mpe.mpe_contract import MPEContract
from snet.sdk.utils.utils import get_address_from_private, normalize_private_key

Expand All @@ -27,24 +28,25 @@ def __str__(self):


class Account:
def __init__(self, w3: web3.Web3, config: Config, mpe_contract: MPEContract):
self.config: Config = config
self.web3: web3.Web3 = w3
self.mpe_contract: MPEContract = mpe_contract
_token_contract_address: str | None = self.config.get("token_contract_address", None)
if _token_contract_address is None:
def __init__(self, w3: web3.Web3, mpe_contract: MPEContract):
self.web3 = w3
self.mpe_contract = mpe_contract

token_contract_address = config.TOKEN_CONTRACT_ADDRESS
if not token_contract_address:
self.token_contract = get_contract_object(self.web3, "FetchToken")
else:
self.token_contract = get_contract_object(
self.web3, "FetchToken", _token_contract_address
self.web3, "FetchToken", token_contract_address
)

if config.get("private_key") is not None:
self.private_key = normalize_private_key(config.get("private_key"))
if config.get("signer_private_key") is not None:
self.signer_private_key = normalize_private_key(config.get("signer_private_key"))
if config.PRIVATE_KEY:
self.private_key = normalize_private_key(config.PRIVATE_KEY)
if config.SIGNER_PRIVATE_KEY:
self.signer_private_key = normalize_private_key(config.SIGNER_PRIVATE_KEY)
else:
self.signer_private_key = self.private_key

self.address = get_address_from_private(self.private_key)
self.signer_address = get_address_from_private(self.signer_private_key)
self.nonce = 0
Expand Down
86 changes: 48 additions & 38 deletions snet/sdk/config.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,48 @@
class Config:
def __init__(
self,
private_key,
eth_rpc_endpoint,
wallet_index=0,
ipfs_endpoint=None,
concurrency=True,
force_update=False,
mpe_contract_address=None,
token_contract_address=None,
registry_contract_address=None,
signer_private_key=None,
):
self.__config = {
"private_key": private_key,
"eth_rpc_endpoint": eth_rpc_endpoint,
"wallet_index": wallet_index,
"ipfs_endpoint": (
ipfs_endpoint if ipfs_endpoint else "/dns/ipfs.singularitynet.io/tcp/80/"
),
"concurrency": concurrency,
"force_update": force_update,
"mpe_contract_address": mpe_contract_address,
"token_contract_address": token_contract_address,
"registry_contract_address": registry_contract_address,
"signer_private_key": signer_private_key,
"lighthouse_token": " ",
}

def __getitem__(self, key):
return self.__config[key]

def get(self, key, default=None):
return self.__config.get(key, default)

def get_ipfs_endpoint(self):
return self["ipfs_endpoint"]
from typing import Optional

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
PRIVATE_KEY: str = ""
SIGNER_PRIVATE_KEY: str = ""
ETH_RPC_ENDPOINT: str = ""
WALLET_INDEX: int = 0
IPFS_ENDPOINT: str = "/dns/ipfs.singularitynet.io/tcp/80/"
CONCURRENCY: bool = True
FORCE_UPDATE: bool = False
MPE_CONTRACT_ADDRESS: str = ""
REGISTRY_CONTRACT_ADDRESS: str = ""
TOKEN_CONTRACT_ADDRESS: str = ""
LIGHTHOUSE_TOKEN: str = " "

model_config = SettingsConfigDict(
env_prefix="SNET_", env_file=".env", env_file_encoding="utf-8", extra="ignore"
)


config = Settings()


def configure(
*,
private_key: Optional[str] = None,
signer_private_key: Optional[str] = None,
eth_rpc_endpoint: Optional[str] = None,
wallet_index: Optional[int] = None,
ipfs_endpoint: Optional[str] = None,
concurrency: Optional[bool] = None,
force_update: Optional[bool] = None,
mpe_contract_address: Optional[str] = None,
registry_contract_address: Optional[str] = None,
token_contract_address: Optional[str] = None,
lighthouse_token: Optional[str] = None,
):
global config
for key, value in locals().items():
key = key.upper()
if hasattr(config, key):
if value is not None:
setattr(config, key, value)
else:
raise ValueError(f"Unknown config key: {key.lower()}")
19 changes: 19 additions & 0 deletions snet/sdk/payment_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from snet.sdk.payment_strategies.freecall_payment_strategy import (
FreeCallPaymentStrategy,
)
from snet.sdk.payment_strategies.paidcall_payment_strategy import (
PaidCallPaymentStrategy,
)
from snet.sdk.payment_strategies.prepaid_payment_strategy import (
PrePaidPaymentStrategy,
)
from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy
from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy

__all__ = [
"PaymentStrategy",
"DefaultPaymentStrategy",
"FreeCallPaymentStrategy",
"PaidCallPaymentStrategy",
"PrePaidPaymentStrategy",
]
7 changes: 4 additions & 3 deletions snet/sdk/storage_provider/storage_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
MPEServiceMetadata,
mpe_service_metadata_from_json,
)
from snet.sdk.config import config


class StorageProvider(object):
def __init__(self, config, registry_contract):
def __init__(self, registry_contract):
self._registry_contract = registry_contract
self._ipfs_client = get_ipfs_client(config)
self.lighthouse_client = Lighthouse(config["lighthouse_token"])
self._ipfs_client = get_ipfs_client()
self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN)

def fetch_org_metadata(self, org_id):
org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0")
Expand Down Expand Up @@ -46,7 +47,7 @@
self._registry_contract.functions.getServiceRegistrationById(org, service).call()
)
if found is not True:
raise Exception(f"No service '{service_id}' found in organization '{org_id}'")

Check warning on line 50 in snet/sdk/storage_provider/storage_provider.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/storage_provider/storage_provider.py#L50

Replace this generic exception class with a more specific one.

service_provider_type, service_metadata_hash = bytesuri_to_hash(s=service_metadata_uri)

Expand Down
7 changes: 4 additions & 3 deletions snet/sdk/utils/ipfs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import multihash
import hashlib

from snet.sdk.config import config


def get_from_ipfs_and_checkhash(ipfs_client, ipfs_hash_base58, validate=True):
"""
Expand Down Expand Up @@ -36,6 +38,5 @@ def get_from_ipfs_and_checkhash(ipfs_client, ipfs_hash_base58, validate=True):
return data


def get_ipfs_client(config):
ipfs_endpoint = config.get_ipfs_endpoint()
return ipfshttpclient.connect(ipfs_endpoint)
def get_ipfs_client():
return ipfshttpclient.connect(config.IPFS_ENDPOINT)
Loading