Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ lint:
@ruff check . --fix
@ruff format .
.PHONY: lint

test:
python -m coverage run -m pytest tests/ -v && \
python -m coverage report
.PHONY: test
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
[tool.poetry.group.dev.dependencies]
ruff = "^0.11"
pytest = "^8.3"
coverage = "^7.13"

[tool.ruff]
line-length = 100
Expand Down
61 changes: 16 additions & 45 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import google.protobuf.internal.api_implementation
from google.protobuf import symbol_database as _symbol_database

from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata
from snet.sdk.registry.registry_contract import RegistryContract
from snet.sdk.registry.service_metadata import MPEServiceMetadata

with warnings.catch_warnings():
# Suppress the eth-typing package`s warnings related to some new networks
Expand All @@ -18,9 +19,6 @@
UserWarning,
)

import web3

from snet.contracts import get_contract_object
from snet.sdk.account import Account
from snet.sdk.config import config
from snet.sdk.client_lib_generator import ClientLibGenerator
Expand All @@ -34,12 +32,11 @@
PaymentStrategy,
)
from snet.sdk.service_client import ServiceClient
from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.registry.storage_provider import StorageProvider
from snet.sdk.custom_typing import ModuleName, ServiceStub
from snet.sdk.utils.utils import (
bytes32_to_str,
find_file_by_keyword,
type_converter,
get_we3_object,
)

google.protobuf.internal.api_implementation.Type = lambda: "python"
Expand All @@ -55,29 +52,13 @@ class PaymentStrategyType(Enum):


class SnetSDK:
"""Base Snet SDK"""

def __init__(self):
self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT))

mpe_contract_address = config.MPE_CONTRACT_ADDRESS
if not mpe_contract_address:
self.mpe_contract = MPEContract(self.web3)
else:
self.mpe_contract = MPEContract(self.web3, mpe_contract_address)

registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS
if registry_contract_address is None:
self.registry_contract = get_contract_object(self.web3, "Registry")
else:
self.registry_contract = get_contract_object(
self.web3, "Registry", registry_contract_address
)

self.w3 = get_we3_object()
self.mpe_contract = MPEContract()
self.registry_contract = RegistryContract()
self.metadata_provider = StorageProvider(self.registry_contract)

self.account = Account(self.web3, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract)
self.account = Account()

def create_service_client(
self,
Expand All @@ -99,7 +80,7 @@ def create_service_client(
if force_update:
lib_generator.generate_client_library()
else:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
pb_2_file_name = find_file_by_keyword(
path_to_pb_files, keyword="pb2.py", exclude=["training"]
)
Expand Down Expand Up @@ -134,16 +115,16 @@ def create_service_client(
options,
self.mpe_contract,
self.account,
self.web3,
self.w3,
pb2_module,
self.payment_channel_provider,
lib_generator.protodir,
lib_generator.proto_dir,
lib_generator.training_added(),
)
return _service_client

def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
path_to_pb_files = str(lib_generator.protodir)
path_to_pb_files = str(lib_generator.proto_dir)
module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
sys.path.append(path_to_pb_files)
try:
Expand All @@ -159,7 +140,7 @@ def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStu
raise Exception(f"Error importing module: {e}")

def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
module_name = os.path.splitext(file_name)[0]
return ModuleName(module_name)
Expand Down Expand Up @@ -190,17 +171,7 @@ def _get_service_group_details(
return self._get_group_by_group_name(service_metadata, group_name)

def get_organization_list(self) -> list:
org_list = self.registry_contract.functions.listOrganizations().call()
organization_list = []
for idx, org_id in enumerate(org_list):
organization_list.append(bytes32_to_str(org_id))
return organization_list
return self.registry_contract.list_orgs()

def get_services_list(self, org_id: str) -> list:
found, org_service_list = self.registry_contract.functions.listServicesForOrganization(
type_converter("bytes32")(org_id)
).call()
if not found:
raise Exception(f"Organization with id={org_id} doesn't exist!")
org_service_list = list(map(bytes32_to_str, org_service_list))
return org_service_list
return self.registry_contract.list_service_for_org(org_id)
53 changes: 17 additions & 36 deletions snet/sdk/account.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json

import web3

from snet.contracts import get_contract_object

from snet.sdk.config import config
from snet.sdk.mpe.mpe_contract import MPEContract
from snet.sdk.utils.utils import get_address_from_private, normalize_private_key
from snet.sdk.utils.utils import get_address_from_private, normalize_private_key, get_we3_object

DEFAULT_GAS = 300000
TRANSACTION_TIMEOUT = 500
Expand All @@ -28,17 +26,13 @@ def __str__(self):


class Account:
def __init__(self, w3: web3.Web3, mpe_contract: MPEContract):
self.web3 = w3
self.mpe_contract = mpe_contract
def __init__(self):
self.w3 = get_we3_object()
self.mpe_address = config.MPE_CONTRACT_ADDRESS

token_contract_address = config.TOKEN_CONTRACT_ADDRESS
if not token_contract_address:
self.token_contract = get_contract_object(self.web3, "FetchToken")
else:
self.token_contract = get_contract_object(
self.web3, "FetchToken", token_contract_address
)
self.token_contract = get_contract_object(
self.w3, "FetchToken", config.TOKEN_CONTRACT_ADDRESS
)

if config.PRIVATE_KEY:
self.private_key = normalize_private_key(config.PRIVATE_KEY)
Expand All @@ -52,19 +46,19 @@ def __init__(self, w3: web3.Web3, mpe_contract: MPEContract):
self.nonce = 0

def _get_nonce(self):
nonce = self.web3.eth.get_transaction_count(self.address)
nonce = self.w3.eth.get_transaction_count(self.address)
if self.nonce >= nonce:
nonce = self.nonce + 1
self.nonce = nonce
return nonce

def _get_gas_price(self):
gas_price = self.web3.eth.gas_price
gas_price = self.w3.eth.gas_price
if gas_price <= 15000000000:
gas_price += gas_price * 1 / 3
elif gas_price > 15000000000 and gas_price <= 50000000000:
elif 15000000000 < gas_price <= 50000000000:
gas_price += gas_price * 1 / 5
elif gas_price > 50000000000 and gas_price <= 150000000000:
elif 50000000000 < gas_price <= 150000000000:
gas_price += 7000000000
elif gas_price > 150000000000:
gas_price += gas_price * 1 / 10
Expand All @@ -73,44 +67,31 @@ def _get_gas_price(self):
def _send_signed_transaction(self, contract_fn, *args):
transaction = contract_fn(*args).build_transaction(
{
"chainId": int(self.web3.net.version),
"chainId": int(self.w3.net.version),
"gas": DEFAULT_GAS,
"gasPrice": self._get_gas_price(),
"nonce": self._get_nonce(),
}
)
signed_txn = self.web3.eth.account.sign_transaction(
transaction, private_key=self.private_key
)
return self.web3.to_hex(self.web3.eth.send_raw_transaction(signed_txn.raw_transaction))
signed_txn = self.w3.eth.account.sign_transaction(transaction, private_key=self.private_key)
return self.w3.to_hex(self.w3.eth.send_raw_transaction(signed_txn.raw_transaction))

def send_transaction(self, contract_fn, *args):
txn_hash = self._send_signed_transaction(contract_fn, *args)
return self.web3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT)
return self.w3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT)

def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder):
if receipt.status == 0:
raise TransactionError("Transaction failed", receipt)
else:
return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), cls=encoder)

def escrow_balance(self):
return self.mpe_contract.balance(self.address)

def deposit_to_escrow_account(self, amount_in_cogs):
already_approved = self.allowance()
if amount_in_cogs > already_approved:
self.approve_transfer(amount_in_cogs)
return self.mpe_contract.deposit(self, amount_in_cogs)

def approve_transfer(self, amount_in_cogs):
return self.send_transaction(
self.token_contract.functions.approve,
self.mpe_contract.contract.address,
self.mpe_address,
amount_in_cogs,
)

def allowance(self):
return self.token_contract.functions.allowance(
self.address, self.mpe_contract.contract.address
).call()
return self.token_contract.functions.allowance(self.address, self.mpe_address).call()
28 changes: 14 additions & 14 deletions snet/sdk/client_lib_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pathlib import Path

from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.registry.storage_provider import StorageProvider
from snet.sdk.utils.utils import compile_proto


Expand All @@ -11,41 +11,41 @@ def __init__(
metadata_provider: StorageProvider,
org_id: str,
service_id: str,
protodir: Path | None = None,
proto_dir: Path | None = None,
):
self._metadata_provider: StorageProvider = metadata_provider
self.org_id: str = org_id
self.service_id: str = service_id
self.language: str = "python"
self.protodir: Path = protodir if protodir else Path.home().joinpath(".snet")
self.proto_dir: Path = proto_dir if proto_dir else Path.home().joinpath(".snet")
self.generate_directories_by_params()

def generate_client_library(self) -> None:
try:
self.receive_proto_files()
compilation_result = compile_proto(
entry_path=self.protodir,
codegen_dir=self.protodir,
entry_path=self.proto_dir,
codegen_dir=self.proto_dir,
target_language=self.language,
add_training=self.training_added(),
)
if compilation_result:
print(
f'client libraries for service with id "{self.service_id}" '
f'in org with id "{self.org_id}" '
f"generated at {self.protodir}"
f"generated at {self.proto_dir}"
)
except Exception as e:
print(str(e))

def generate_directories_by_params(self) -> None:
if not self.protodir.is_absolute():
self.protodir = Path.cwd().joinpath(self.protodir)
if not self.proto_dir.is_absolute():
self.proto_dir = Path.cwd().joinpath(self.proto_dir)
self.create_service_client_libraries_path()

def create_service_client_libraries_path(self) -> None:
self.protodir = self.protodir.joinpath(self.org_id, self.service_id, self.language)
self.protodir.mkdir(parents=True, exist_ok=True)
self.proto_dir = self.proto_dir.joinpath(self.org_id, self.service_id, self.language)
self.proto_dir.mkdir(parents=True, exist_ok=True)

def receive_proto_files(self) -> None:
metadata = self._metadata_provider.fetch_service_metadata(
Expand All @@ -54,17 +54,17 @@ def receive_proto_files(self) -> None:
service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash")

# Receive proto files
if self.protodir.exists():
self._metadata_provider.fetch_and_extract_proto(service_api_source, self.protodir)
if self.proto_dir.exists():
self._metadata_provider.fetch_and_extract_proto(service_api_source, self.proto_dir)
else:
raise Exception("Directory for storing proto files is not found")

def training_added(self) -> bool:
files = os.listdir(self.protodir)
files = os.listdir(self.proto_dir)
for file in files:
if ".proto" not in file:
continue
with open(self.protodir.joinpath(file), "r") as f:
with open(self.proto_dir.joinpath(file), "r") as f:
proto_text = f.read()
if 'import "training.proto";' in proto_text:
return True
Expand Down
Loading
Loading